main.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543
  1. #!/usr/bin/env python3
  2. import argparse
  3. import glob
  4. import importlib.util
  5. import logging
  6. import re
  7. import sys
  8. import time
  9. import os
  10. import asyncio
  11. from enum import StrEnum
  12. from pathlib import Path
  13. from typing import List, Optional, Tuple
  14. from abc import ABC, abstractmethod
  15. from dataclasses import dataclass, field
  16. import httpx
  17. from tabulate import tabulate
  18. from termcolor import colored
  19. from .backend import BackendConfig, DummyBackend
  20. from .proxy import ProxyConfig, ProxyManager
  # Module-level logger shared by every class and function in this file.
  logger = logging.getLogger(__name__)
  22. class TestStatus(StrEnum):
  23. PASS = "green"
  24. FAIL = "yellow"
  25. ERROR = "red"
  @dataclass
  class TestResult:
      """Record describing how one test execution ended.

      Attributes:
          test_id: Identifier of the test (the test class name).
          description: Human-readable description shown in the summary.
          status: Final outcome of the test.
          duration: Elapsed wall time of the execution, in seconds.
          error_messages: Messages explaining a FAIL/ERROR outcome. Defaults
              to a fresh empty list; callers may also pass ``None``.
      """

      test_id: str
      description: str
      status: TestStatus
      duration: float
      # default_factory gives every instance its own list
      error_messages: Optional[List[str]] = field(default_factory=list)
  class BaseProxyTest(ABC):
      """Base class for all HTTP tests - ASYNC VERSION

      All test classes must inherit from this class and implement the async
      run_test() method. A test is configured through class attributes; any
      attribute a subclass does not define falls back to the default below.

      HTTP header names are case-insensitive, so every header rule is matched
      against the received header names without regard to letter case.

      Attributes:
          description (str): Description for the test that will be printed in
              overall summary
          url (str): URL to run the test against
              (default: 'http://localhost:4242/')
          method (str): HTTP method to use (default: GET)
          headers (dict): A dictionary of headers that will be sent by the client
          body (str): A string that will be sent as request body (if applicable)
          backend_config: Configuration handed to the dummy backend
              (default: ``BackendConfig()``)
          proxy_config: Configuration handed to the proxy manager
              (default: ``ProxyConfig()``)
          expected_status (int): The HTTP status that will be compared with the
              received one (default: 200)
          expected_headers (dict): Headers that must be present in the response,
              with exactly the given values
          expected_body_pattern (str): A regex searched for in the body content
          expected_header_patterns (dict): Header name to regex pattern mapping;
              the pattern is matched from the start of the header value
          forbidden_client_headers (list): Headers that must NOT be in the response
          expected_backend_headers (list): Headers that must be present in
              backend request
          forbidden_backend_headers (list): Headers that must NOT be in backend
              request
          backend_header_patterns (dict): Backend header name to regex pattern
              mapping; the pattern is matched from the start of the value
      """

      def __init__(self):
          self.test_id = self.__class__.__name__
          # getattr() picks up a class-level attribute when the subclass
          # defines one and otherwise creates a fresh per-instance default.
          self.description = getattr(self, 'description', 'No description provided')
          self.url = getattr(self, 'url', 'http://localhost:4242/')
          self.method = getattr(self, 'method', 'GET')
          self.headers = getattr(self, 'headers', {})
          self.body = getattr(self, 'body', '')

          # Test configuration
          self.backend_config = getattr(self, 'backend_config', BackendConfig())
          self.proxy_config = getattr(self, 'proxy_config', ProxyConfig())

          # Validation rules
          self.expected_status = getattr(self, 'expected_status', 200)
          self.expected_headers = getattr(self, 'expected_headers', {})
          self.expected_body_pattern = getattr(self, 'expected_body_pattern', None)
          self.expected_header_patterns = getattr(self, 'expected_header_patterns', {})
          self.forbidden_client_headers = getattr(self, 'forbidden_client_headers', [])
          self.expected_backend_headers = getattr(self, 'expected_backend_headers', [])
          self.forbidden_backend_headers = getattr(self, 'forbidden_backend_headers', [])
          self.backend_header_patterns = getattr(self, 'backend_header_patterns', {})

          # Runtime data
          self.backend = None
          self.proxy = None
          self.response = None
          self.response_headers = {}
          self.response_body = ""

      @staticmethod
      def _lowercase_keys(headers) -> dict:
          """Return a copy of *headers* keyed by lower-cased header name."""
          return {str(name).lower(): value for name, value in dict(headers or {}).items()}

      async def setup(self):
          """Async setup - start backend and proxy"""
          logger.debug("Setting up test environment")

          # Start backend
          logger.debug(f"Instantiating backend with configuration: {self.backend_config}")
          self.backend = DummyBackend(self.backend_config)
          logger.debug("Starting backend")
          await self.backend.start()

          # Wait for backend to be ready
          await self.backend.wait_until_ready()

          # Start proxy
          logger.debug(f"Instantiating reverse proxy with configuration: {self.proxy_config}")
          self.proxy = ProxyManager(self.proxy_config)
          logger.debug(f"Starting reverse proxy with configuration: {self.backend_config}")
          self.proxy.start(self.backend_config)

          # Wait for proxy to be ready
          logger.debug("Waiting for proxy to be ready")
          await self.proxy.wait_until_ready(self.url)

          # Reset backend counters after health check
          self.backend.request_count = 0
          self.backend.received_headers = {}
          self.backend.received_body = ""
          logger.debug("Setup complete")

      async def teardown(self):
          """Async teardown - stop backend and proxy

          The proxy is stopped even when stopping the backend raises, so a
          failing backend shutdown cannot leave the proxy running.
          """
          logger.debug("Cleaning up test environment")
          try:
              if self.backend is not None:
                  logger.debug("Stopping backend")
                  await self.backend.stop()
          finally:
              if self.proxy is not None:
                  logger.debug("Stopping reverse proxy")
                  self.proxy.stop()

      async def make_request(self):
          """Make async HTTP request through the proxy

          Sends the configured method/url/headers/body with an HTTP/2-enabled
          client and a 10 second timeout, stores the response, its headers and
          its text body on the instance, and returns the response.
          """
          logger.debug(f"Making {self.method} request to {self.url}")
          logger.debug(f"Request headers: {self.headers}")
          if self.body:
              logger.debug(f"Request body: {self.body}")

          async with httpx.AsyncClient(http2=True, timeout=10.0) as client:
              response = await client.request(
                  method=self.method,
                  url=self.url,
                  headers=self.headers,
                  content=self.body
              )

          logger.debug(f"Response status: {response.status_code}")
          logger.debug(f"Response headers: {dict(response.headers)}")
          self.response = response
          self.response_headers = dict(response.headers)
          self.response_body = response.text
          return response

      def validate_response(self) -> Tuple[TestStatus, Optional[List[str]]]:
          """Validate the HTTP response

          Checks the stored response (and the headers the backend received)
          against every configured validation rule.

          Returns:
              ``(TestStatus.PASS, None)`` when every rule holds, otherwise
              ``(TestStatus.FAIL, messages)`` with one message per broken rule.
          """
          validation_errors = []

          def record(message):
              # Every broken rule is both logged and reported in the result.
              logger.info(message)
              validation_errors.append(message)

          # Header names are case-insensitive and dict(httpx headers) yields
          # lower-cased keys, so all lookups go through lower-cased names.
          response_headers = self._lowercase_keys(self.response_headers)

          # Check status code
          if self.response.status_code != self.expected_status:
              record(f"Expected status {self.expected_status}, got {self.response.status_code}")

          # Check expected headers
          for header, expected_value in self.expected_headers.items():
              key = header.lower()
              if key not in response_headers:
                  record(f"Expected header '{header}' not found in response")
                  continue
              received_value = response_headers[key]
              if received_value != expected_value:
                  record(f"Header '{header}' expected '{expected_value}', got '{received_value}'")

          # Check header patterns
          for header, pattern in self.expected_header_patterns.items():
              key = header.lower()
              if key not in response_headers:
                  record(f"Header '{header}' for pattern matching not found")
              elif not re.match(pattern, response_headers[key]):
                  record(f"Header '{header}' value doesn't match pattern '{pattern}'")

          # Check forbidden client headers
          for header in self.forbidden_client_headers:
              if header.lower() in response_headers:
                  record(f"Forbidden header '{header}' found in client response")

          # Check body pattern
          if self.expected_body_pattern and not re.search(
                  self.expected_body_pattern, self.response_body):
              record(f"Response body doesn't match pattern '{self.expected_body_pattern}'")

          # The backend is only consulted when a backend rule is configured.
          if (self.expected_backend_headers or self.forbidden_backend_headers
                  or self.backend_header_patterns):
              backend_headers = self._lowercase_keys(self.backend.received_headers)
          else:
              backend_headers = {}

          # Check backend headers
          for header in self.expected_backend_headers:
              if header.lower() not in backend_headers:
                  record(f"Expected header {header} in backend headers")

          # Check backend headers that shouldn't be there
          for header in self.forbidden_backend_headers:
              if header.lower() in backend_headers:
                  record(f"Forbidden header '{header}' found in backend request")

          # Check backend header patterns
          for header, pattern in self.backend_header_patterns.items():
              key = header.lower()
              if key not in backend_headers:
                  record(f"Expected backend header '{header}' not found")
              elif not re.match(pattern, backend_headers[key]):
                  record(f"Backend header '{header}' doesn't match pattern '{pattern}'")

          logger.debug("All checks done")
          if validation_errors:
              return TestStatus.FAIL, validation_errors
          return TestStatus.PASS, None

      @abstractmethod
      async def run_test(self) -> bool:
          """Run the actual test logic. Must be implemented by subclasses

          This method MUST be async and should use await for any async
          operations. The default body, reachable through ``super()``, makes
          the configured request and reports success.

          Returns:
              A truthy value when the test logic succeeded; a falsy value
              makes execute() report the test as ERROR.
          """
          await self.make_request()
          return True

      async def execute(self) -> TestResult:
          """Execute the complete test asynchronously

          Runs setup, the test logic, response validation and teardown, and
          converts any exception raised along the way (including one raised
          by teardown) into an ERROR result instead of letting it escape.

          Returns:
              The TestResult describing the outcome.
          """
          # perf_counter is monotonic, so durations are immune to clock changes
          started = time.perf_counter()

          def build(status, errors):
              return TestResult(
                  self.test_id,
                  self.description,
                  status,
                  time.perf_counter() - started,
                  errors,
              )

          outcome = None
          try:
              try:
                  await self.setup()
                  if not await self.run_test():
                      outcome = build(TestStatus.ERROR, ["Test logic failed"])
                  else:
                      logger.debug("Validating response...")
                      # Validate response if request was made
                      if self.response is not None:
                          outcome = build(*self.validate_response())
                      else:
                          outcome = build(TestStatus.PASS, None)
              except Exception as e:
                  logger.exception(f"Test {self.test_id} failed with exception")
                  outcome = build(TestStatus.ERROR, [str(e)])
          finally:
              # Teardown always runs; a failure here must not abort the whole
              # run, so it is folded into this test's result as an ERROR.
              try:
                  await self.teardown()
              except Exception as e:
                  logger.exception(f"Teardown of test {self.test_id} failed")
                  earlier = []
                  if outcome is not None and outcome.error_messages:
                      earlier = list(outcome.error_messages)
                  outcome = build(TestStatus.ERROR, [*earlier, f"Teardown failed: {e}"])
          return outcome
  class TestRunner:
      """Async test runner

      Executes tests sequentially or with bounded concurrency, collects their
      results in ``self.results`` and prints a summary of them.
      """

      def __init__(self):
          # Results accumulate across run_* calls in completion/submission order.
          self.results = []

      async def run_all_tests(self, tests: List[BaseProxyTest]):
          """Run all tests sequentially (async)

          Args:
              tests: List of tests to run, executed one after another
          """
          logger.debug("Running all tests")
          for test in tests:
              logger.info(colored(f"Running test: {test.test_id}", "blue"))
              result = await test.execute()
              logger.debug(f"Test result for {test.test_id}: {result}")
              self.results.append(result)

      async def run_all_tests_parallel(
          self,
          tests: List[BaseProxyTest],
          max_concurrent: int = 3
      ):
          """Run tests in parallel with concurrency limit

          Args:
              tests: List of tests to run
              max_concurrent: Maximum number of tests to run concurrently

          Raises:
              ValueError: If max_concurrent is smaller than 1.
          """
          # A semaphore created with 0 permits would block every test forever.
          if max_concurrent < 1:
              raise ValueError(f"max_concurrent must be at least 1, got {max_concurrent}")

          logger.debug(f"Running tests in parallel (max {max_concurrent} concurrent)")
          semaphore = asyncio.Semaphore(max_concurrent)

          async def run_with_semaphore(test):
              async with semaphore:
                  logger.info(colored(f"Running test: {test.test_id}", "blue"))
                  result = await test.execute()
                  logger.debug(f"Test result for {test.test_id}: {result}")
                  return result

          # Run all tests concurrently with limit; gather keeps input order
          results = await asyncio.gather(
              *[run_with_semaphore(test) for test in tests],
              return_exceptions=False
          )
          self.results.extend(results)

      def print_summary(self):
          """Print test summary

          Prints the details of every non-passing test first, then a table
          with one row per test, then aggregate counts and total duration.
          """
          divider = "=" * 80
          total_tests = len(self.results)
          passed_tests = sum(1 for r in self.results if r.status == TestStatus.PASS)
          # FAIL and ERROR are both counted as "failed" in the totals
          failed_tests = total_tests - passed_tests
          total_time = sum(r.duration for r in self.results)

          print("\nTEST EXECUTION SUMMARY")
          print(divider)

          # Print failures first
          for result in self.results:
              if result.status == TestStatus.PASS:
                  continue
              print(f"\nTest {colored(result.test_id, 'blue')} - {result.description}")
              print(f"Status: {colored(result.status.name, result.status.value)}")
              if result.error_messages:
                  print("Errors:")
                  for msg in result.error_messages:
                      print(colored(f" {msg}", result.status.value))

          # Print results table
          print("\n" + divider)
          result_table = [["Status", "Name", "Description", "Duration (s)"]]
          for result in self.results:
              status_color = colored(result.status.name, result.status.value)
              result_table.append([
                  status_color,
                  result.test_id,
                  result.description,
                  f"{result.duration:.3f}"
              ])
          print(tabulate(result_table, headers="firstrow"))

          # Print summary stats
          print("\n" + divider)
          summary_table = [[
              "Total Tests",
              "Passed",
              "Failed",
              "Total Duration (s)"
          ], [
              total_tests,
              colored(str(passed_tests), "green"),
              colored(str(failed_tests), "red" if failed_tests > 0 else "green"),
              f"{total_time:.3f}"
          ]]
          print(tabulate(summary_table, headers="firstrow"))
          print(divider + "\n")
  def discover_tests(paths: List[str]) -> List[BaseProxyTest]:
      """Discover all test classes in the passed paths and returns
      a list of tests ready to be run

      Each entry of *paths* may be a directory (its ``*.py`` files are used),
      a single ``.py`` file, or a glob pattern that the shell did not expand.
      Every class found in those files that derives from BaseProxyTest is
      instantiated.

      Args:
          paths: Directories, files or glob patterns to search.

      Returns:
          One instance per discovered test class.

      Raises:
          RuntimeError: If a path exists but is neither a file nor a
              directory, or if no test file was found at all.
          Exception: Whatever importing or instantiating a test file raises
              is logged and re-raised unchanged.
      """
      test_files = []
      for raw_path in paths:
          test_path = Path(raw_path)
          if test_path.exists():
              logger.debug(f"Searching for test files in {test_path}")
              if test_path.is_dir():
                  # sorted() makes the execution order deterministic
                  test_files.extend(sorted(test_path.glob("*.py")))
              elif test_path.is_file():
                  # existing non-.py files are ignored
                  if test_path.suffix == ".py":
                      test_files.append(test_path)
              else:
                  logger.error(f"Cannot find test files in {test_path}")
                  raise RuntimeError(f"Cannot find test files in {test_path}")
          elif any(wildcard in str(test_path) for wildcard in "*?["):
              # An unexpanded pattern never "exists", so it has to be handled
              # before the path is reported as missing.
              logger.debug(f"Searching for test files in {test_path}")
              for match in sorted(glob.glob(str(test_path))):
                  if os.path.isfile(match) and match.endswith(".py"):
                      test_files.append(Path(match))
          else:
              logger.error(f"{test_path} does not exist")

      if not test_files:
          logger.error(f"No test file to import from {paths}")
          raise RuntimeError(f"No test files to import from {paths}")

      tests = []
      for test_file in test_files:
          try:
              spec = importlib.util.spec_from_file_location(
                  f"test_{test_file.stem}",
                  test_file,
              )
              module = importlib.util.module_from_spec(spec)
              spec.loader.exec_module(module)

              # Find all test classes defined in (or imported into) the module
              loaded = []
              for attr_name in dir(module):
                  attr = getattr(module, attr_name)
                  if (isinstance(attr, type) and
                          issubclass(attr, BaseProxyTest) and
                          attr is not BaseProxyTest):
                      loaded.append(attr())
              tests.extend(loaded)
              logger.debug(f"Loaded tests from {test_file}: {[t.test_id for t in loaded]}")
          except Exception as e:
              logger.error(f"Failed to load test file {test_file}: {e}")
              raise
      return tests
  def setup_logging(level):
      """Configure logging

      Calls ``logging.basicConfig`` with the given level and a record layout
      of timestamp, logger name, level name and message.

      Args:
          level: Logging level passed through to ``logging.basicConfig``.
      """
      record_layout = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
      logging.basicConfig(format=record_layout, level=level)
  async def async_main(args):
      """Async main function - runs all tests

      Discovers the tests named by ``args.paths``, runs them (in parallel
      when ``args.parallel`` is set, sequentially otherwise) and prints the
      summary.

      Args:
          args: Parsed command-line namespace; ``log_level``, ``paths`` and
              ``parallel`` are read.

      Returns:
          0 when every test passed or no test was found, 1 when at least one
          test did not pass or an exception was raised.
      """
      logger.info("Starting httphound (async version)")

      # ASCII art only in debug mode
      if args.log_level == 'DEBUG':
          hound = r'''
             __
        \ ______/ V`-,
         }        /~~
        /_)^ --,r'
       |b      |b
      '''
          logger.debug(hound)

      try:
          # Discover tests (sync operation)
          tests = discover_tests(args.paths)
          logger.info(f"Discovered {len(tests)} tests in {args.paths}")
          if not tests:
              logger.info("No tests found")
              return 0

          # Run tests
          runner = TestRunner()
          if not args.parallel:
              logger.info("Running tests sequentially")
              await runner.run_all_tests(tests)
          else:
              logger.info(f"Running tests in parallel (max {args.parallel} concurrent)")
              await runner.run_all_tests_parallel(tests, max_concurrent=args.parallel)

          # Print summary
          runner.print_summary()

          # Exit with error code if any tests failed
          not_passed = [r for r in runner.results if r.status != TestStatus.PASS]
          failed_count = len(not_passed)
          if failed_count == 0:
              logger.info("All tests passed!")
              return 0
          logger.error(f"{failed_count} out of {len(runner.results)} tests failed")
          return 1
      except Exception as e:
          # Top-level boundary: report the failure and signal it via exit code
          logger.error(f"Test execution failed: {e}")
          logger.exception("Full traceback:")
          return 1
  def _positive_int(value):
      """argparse ``type`` callable that accepts only integers >= 1."""
      try:
          number = int(value)
      except ValueError:
          raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
      if number < 1:
          raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
      return number


  def main():
      """Synchronous entry point that launches async main

      Parses the command line, configures logging, runs ``async_main`` on a
      new event loop and terminates the process with its return value as the
      exit code (130 when interrupted from the keyboard). Invalid arguments,
      including a ``--parallel`` value below 1, end the process through
      argparse's usual usage error.
      """
      parser = argparse.ArgumentParser(description="Reverse proxy test tool (async)")
      parser.add_argument(
          '--log-level', '-l',
          type=str.upper,
          choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
          default='INFO',
          help='Set logging level (Default: INFO)'
      )
      parser.add_argument(
          '--parallel', '-p',
          # 0 would silently fall back to sequential mode and a negative
          # value would fail later, so both are rejected at parse time.
          type=_positive_int,
          default=None,
          metavar='N',
          help='Run tests in parallel with max N concurrent tests'
      )
      parser.add_argument(
          'paths',
          nargs='+',
          help='Directories, file patterns, or individual Python files (.py)'
      )
      args = parser.parse_args()

      # type=str.upper plus choices guarantees a valid, upper-case level name
      log_level = getattr(logging, args.log_level)
      setup_logging(log_level)

      # Run async main
      try:
          exit_code = asyncio.run(async_main(args))
          sys.exit(exit_code)
      except KeyboardInterrupt:
          logger.info("\nInterrupted by user")
          sys.exit(130)
  def start():
      """Console script entry point

      Takes no arguments and simply delegates to main().
      """
      main()