| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543 |
- #!/usr/bin/env python3
import argparse
import asyncio
import glob
import importlib.util
import inspect
import logging
import os
import re
import sys
import time
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import StrEnum
from pathlib import Path
from typing import List, Optional, Tuple

import httpx
from tabulate import tabulate
from termcolor import colored

from .backend import BackendConfig, DummyBackend
from .proxy import ProxyConfig, ProxyManager
# Module-wide logger; its name follows the module's import path.
logger: logging.Logger = logging.getLogger(__name__)
- class TestStatus(StrEnum):
- PASS = "green"
- FAIL = "yellow"
- ERROR = "red"
@dataclass
class TestResult:
    """Result of a single test execution, as reported in the runner summary."""

    # Identifier of the test (BaseProxyTest uses the test class name).
    test_id: str
    # Human-readable description printed next to the result.
    description: str
    # PASS / FAIL / ERROR outcome.
    status: TestStatus
    # Wall-clock seconds spent running the test.
    duration: float
    # Messages explaining a non-PASS outcome; may be None when there are none.
    error_messages: Optional[List[str]] = field(default_factory=list)
class BaseProxyTest(ABC):
    """Base class for all HTTP tests - ASYNC VERSION

    All test classes must inherit from this class and implement the async run_test() method.
    Configuration is read from class attributes defined by the subclass; any attribute
    that is not defined falls back to the default listed below.

    Attributes:
        description (str): Description for the test that will be printed in overall summary
        url (str): URL to run the test against (default: 'http://localhost:4242/')
        method (str): HTTP method to use (default: GET)
        headers (dict): A dictionary of headers that will be sent by the client
        body (str): A string that will be sent as request body (if applicable)
        backend_config (BackendConfig): Dummy backend configuration (default: BackendConfig())
        proxy_config (ProxyConfig): Proxy configuration (default: ProxyConfig())
        expected_status (int): The HTTP status that will be compared with the received one
            (default: 200)
        expected_headers (dict): Headers that must be present in the response, with exact values
        expected_body_pattern (str): A regex searched for anywhere in the body content
        expected_header_patterns (dict): Header name to regex pattern mapping; the pattern
            must match at the start of the header value
        forbidden_client_headers (list): Headers that must NOT be in the response
        expected_backend_headers (list): Headers that must be present in backend request
        forbidden_backend_headers (list): Headers that must NOT be in backend request
        backend_header_patterns (dict): Backend header name to regex pattern mapping; the
            pattern must match at the start of the header value

    Header names in every validation rule are compared case-insensitively, as
    HTTP header names are case-insensitive.
    """

    def __init__(self):
        self.test_id = self.__class__.__name__
        # getattr() picks up class attributes set by the subclass, else the default.
        self.description = getattr(self, 'description', 'No description provided')
        self.url = getattr(self, 'url', 'http://localhost:4242/')
        self.method = getattr(self, 'method', 'GET')
        self.headers = getattr(self, 'headers', {})
        self.body = getattr(self, 'body', '')
        # Test configuration
        self.backend_config = getattr(self, 'backend_config', BackendConfig())
        self.proxy_config = getattr(self, 'proxy_config', ProxyConfig())
        # Validation rules
        self.expected_status = getattr(self, 'expected_status', 200)
        self.expected_headers = getattr(self, 'expected_headers', {})
        self.expected_body_pattern = getattr(self, 'expected_body_pattern', None)
        self.expected_header_patterns = getattr(self, 'expected_header_patterns', {})
        self.forbidden_client_headers = getattr(self, 'forbidden_client_headers', [])
        self.expected_backend_headers = getattr(self, 'expected_backend_headers', [])
        self.forbidden_backend_headers = getattr(self, 'forbidden_backend_headers', [])
        self.backend_header_patterns = getattr(self, 'backend_header_patterns', {})
        # Runtime data
        self.backend = None
        self.proxy = None
        self.response = None
        self.response_headers = {}
        self.response_body = ""

    async def setup(self):
        """Async setup - start backend and proxy, then wait until both are ready."""
        logger.debug("Setting up test environment")

        # Start backend
        logger.debug(f"Instantiating backend with configuration: {self.backend_config}")
        self.backend = DummyBackend(self.backend_config)
        logger.debug("Starting backend")
        await self.backend.start()

        # Wait for backend to be ready
        await self.backend.wait_until_ready()

        # Start proxy in front of the backend
        logger.debug(f"Instantiating reverse proxy with configuration: {self.proxy_config}")
        self.proxy = ProxyManager(self.proxy_config)
        logger.debug(f"Starting reverse proxy in front of backend: {self.backend_config}")
        self.proxy.start(self.backend_config)

        # Wait for proxy to be ready
        logger.debug("Waiting for proxy to be ready")
        await self.proxy.wait_until_ready(self.url)

        # Reset backend counters after health check so only the test's own
        # request is visible to the validation step.
        self.backend.request_count = 0
        self.backend.received_headers = {}
        self.backend.received_body = ""

        logger.debug("Setup complete")

    async def teardown(self):
        """Async teardown - stop backend and proxy.

        The proxy is stopped even when stopping the backend raises, so a
        failing backend shutdown cannot leave the proxy running.
        """
        logger.debug("Cleaning up test environment")
        try:
            if self.backend:
                logger.debug("Stopping backend")
                await self.backend.stop()
        finally:
            if self.proxy:
                logger.debug("Stopping reverse proxy")
                self.proxy.stop()

    async def make_request(self):
        """Make async HTTP request through the proxy.

        Sends ``self.method`` to ``self.url`` with ``self.headers`` and
        ``self.body`` using an HTTP/2-enabled httpx client (10 s timeout),
        stores the response, its headers and its text on the instance, and
        returns the response.
        """
        logger.debug(f"Making {self.method} request to {self.url}")
        logger.debug(f"Request headers: {self.headers}")
        if self.body:
            logger.debug(f"Request body: {self.body}")

        async with httpx.AsyncClient(http2=True, timeout=10.0) as client:
            response = await client.request(
                method=self.method,
                url=self.url,
                headers=self.headers,
                content=self.body
            )

        logger.debug(f"Response status: {response.status_code}")
        logger.debug(f"Response headers: {dict(response.headers)}")

        self.response = response
        self.response_headers = dict(response.headers)
        self.response_body = response.text

        return response

    @staticmethod
    def _lowercase_keys(headers) -> dict:
        """Return a copy of *headers* keyed by lower-cased header name (None -> {})."""
        return {str(name).lower(): value for name, value in (headers or {}).items()}

    def validate_response(self) -> Tuple[TestStatus, Optional[List[str]]]:
        """Validate the HTTP response and the request received by the backend.

        Returns:
            ``(TestStatus.PASS, None)`` when every rule holds, otherwise
            ``(TestStatus.FAIL, errors)`` with one message per violated rule.
        """
        validation_errors: List[str] = []

        def fail(message: str) -> None:
            # Every violation is logged and collected for the summary.
            logger.info(message)
            validation_errors.append(message)

        # Header names are case-insensitive in HTTP, and the received dicts do
        # not necessarily use the casing the rules were written with, so both
        # sides are lower-cased before any lookup.
        client_headers = self._lowercase_keys(self.response_headers)
        backend_headers = self._lowercase_keys(
            self.backend.received_headers if self.backend is not None else None)

        # Check status code
        if self.response.status_code != self.expected_status:
            fail(f"Expected status {self.expected_status}, got {self.response.status_code}")

        # Check expected headers (exact value match)
        for header, expected_value in self.expected_headers.items():
            key = header.lower()
            if key not in client_headers:
                fail(f"Expected header '{header}' not found in response")
            elif client_headers[key] != expected_value:
                received_value = client_headers[key]
                fail(f"Header '{header}' expected '{expected_value}', got '{received_value}'")

        # Check header patterns (re.match: anchored at the start of the value)
        for header, pattern in self.expected_header_patterns.items():
            key = header.lower()
            if key not in client_headers:
                fail(f"Header '{header}' for pattern matching not found")
            elif not re.match(pattern, client_headers[key]):
                fail(f"Header '{header}' value doesn't match pattern '{pattern}'")

        # Check forbidden client headers
        for header in self.forbidden_client_headers:
            if header.lower() in client_headers:
                fail(f"Forbidden header '{header}' found in client response")

        # Check body pattern (re.search: anywhere in the body)
        if self.expected_body_pattern and not re.search(
                self.expected_body_pattern, self.response_body):
            fail(f"Response body doesn't match pattern '{self.expected_body_pattern}'")

        # Check backend headers
        for header in self.expected_backend_headers:
            if header.lower() not in backend_headers:
                fail(f"Expected header {header} in backend headers")

        # Check backend headers that shouldn't be there
        for header in self.forbidden_backend_headers:
            if header.lower() in backend_headers:
                fail(f"Forbidden header '{header}' found in backend request")

        # Check backend header patterns
        for header, pattern in self.backend_header_patterns.items():
            key = header.lower()
            if key not in backend_headers:
                fail(f"Expected backend header '{header}' not found")
            elif not re.match(pattern, backend_headers[key]):
                fail(f"Backend header '{header}' doesn't match pattern '{pattern}'")

        logger.debug("All checks done")
        if validation_errors:
            return TestStatus.FAIL, validation_errors
        return TestStatus.PASS, None

    @abstractmethod
    async def run_test(self) -> bool:
        """Run the actual test logic. Must be implemented by subclasses

        This method MUST be async and should use await for any async operations.
        The default body (reachable via ``super().run_test()``) just sends the
        configured request. Returning False marks the test as ERROR.
        """
        await self.make_request()
        return True

    async def _run_and_validate(self) -> Tuple[TestStatus, Optional[List[str]]]:
        """Set up the environment, run the test logic and validate the outcome."""
        await self.setup()
        if not await self.run_test():
            return TestStatus.ERROR, ["Test logic failed"]
        logger.debug("Validating response...")
        # Validate response only if a request was actually made
        if self.response is None:
            return TestStatus.PASS, None
        return self.validate_response()

    async def _teardown_safely(self) -> Optional[Exception]:
        """Run teardown(); log and return any Exception instead of raising it."""
        try:
            await self.teardown()
        except Exception as e:
            logger.exception(f"Teardown of test {self.test_id} failed")
            return e
        return None

    async def execute(self) -> TestResult:
        """Execute the complete test asynchronously.

        Runs setup, run_test() and validation, then always tears the
        environment down. An exception from setup/run_test/validation becomes
        an ERROR result. A failing teardown is logged and also turns the
        result into ERROR instead of propagating to the runner.

        Returns:
            The TestResult for this test; ``duration`` covers setup, the test
            and validation, but not teardown.
        """
        start_time = time.perf_counter()  # monotonic: immune to clock changes
        try:
            try:
                status, errors = await self._run_and_validate()
            except Exception as e:
                logger.exception(f"Test {self.test_id} failed with exception")
                status, errors = TestStatus.ERROR, [str(e)]
            duration = time.perf_counter() - start_time
        finally:
            # Runs on every exit path, including cancellation.
            teardown_error = await self._teardown_safely()

        if teardown_error is not None:
            status = TestStatus.ERROR
            errors = [*(errors or []), f"Teardown failed: {teardown_error}"]
        return TestResult(self.test_id, self.description, status, duration, errors)
class TestRunner:
    """Async test runner.

    Every TestResult produced by ``run_all_tests`` / ``run_all_tests_parallel``
    is appended to ``self.results``; ``print_summary`` reports on that list.
    """

    def __init__(self):
        self.results = []

    async def _run_single(self, test: BaseProxyTest):
        """Execute one test, logging its start and its result, and return the result."""
        logger.info(colored(f"Running test: {test.test_id}", "blue"))
        outcome = await test.execute()
        logger.debug(f"Test result for {test.test_id}: {outcome}")
        return outcome

    async def run_all_tests(self, tests: List[BaseProxyTest]):
        """Run all tests one after another, recording each result as it finishes."""
        logger.debug("Running all tests")
        for current in tests:
            outcome = await self._run_single(current)
            self.results.append(outcome)

    async def run_all_tests_parallel(
        self,
        tests: List[BaseProxyTest],
        max_concurrent: int = 3
    ):
        """Run tests in parallel with concurrency limit

        Results are recorded in the same order as *tests*.

        Args:
            tests: List of tests to run
            max_concurrent: Maximum number of tests to run concurrently
        """
        logger.debug(f"Running tests in parallel (max {max_concurrent} concurrent)")

        gate = asyncio.Semaphore(max_concurrent)

        async def guarded(current):
            # At most max_concurrent tests hold the semaphore at once.
            async with gate:
                return await self._run_single(current)

        outcomes = await asyncio.gather(*map(guarded, tests), return_exceptions=False)
        self.results.extend(outcomes)

    def print_summary(self):
        """Print failures in detail, then a per-test table and overall statistics."""
        results = self.results
        rule = "=" * 80
        total_tests = len(results)
        passed_tests = len([r for r in results if r.status == TestStatus.PASS])
        failed_tests = total_tests - passed_tests
        total_time = sum(r.duration for r in results)

        print("\nTEST EXECUTION SUMMARY")
        print(rule)

        # Details for every non-passing test come first
        for res in (r for r in results if r.status != TestStatus.PASS):
            tint = res.status.value
            print(f"\nTest {colored(res.test_id, 'blue')} - {res.description}")
            print(f"Status: {colored(res.status.name, tint)}")
            if res.error_messages:
                print("Errors:")
                for msg in res.error_messages:
                    print(colored(f" {msg}", tint))

        # One row per test
        print("\n" + rule)
        rows = [["Status", "Name", "Description", "Duration (s)"]]
        rows.extend(
            [colored(r.status.name, r.status.value), r.test_id, r.description,
             f"{r.duration:.3f}"]
            for r in results
        )
        print(tabulate(rows, headers="firstrow"))

        # Aggregate statistics
        print("\n" + rule)
        failed_tint = "red" if failed_tests > 0 else "green"
        stats = [
            ["Total Tests", "Passed", "Failed", "Total Duration (s)"],
            [
                total_tests,
                colored(str(passed_tests), "green"),
                colored(str(failed_tests), failed_tint),
                f"{total_time:.3f}",
            ],
        ]
        print(tabulate(stats, headers="firstrow"))
        print(rule + "\n")
def discover_tests(paths: List[str]) -> List[BaseProxyTest]:
    """Discover all test classes in the passed paths and return instances ready to run.

    Each entry of *paths* may be:

    * a directory: every ``*.py`` file directly inside it (not recursive);
    * a ``.py`` file: used as-is (other existing files are ignored);
    * a glob pattern containing ``*``, ``?`` or ``[``: expanded, keeping ``.py`` files.

    Every concrete (non-abstract) subclass of BaseProxyTest found in a loaded
    module's namespace is instantiated once for that module.

    Args:
        paths: Directories, glob patterns or individual Python files.

    Returns:
        Test instances, ordered by file and then alphabetically by attribute name.

    Raises:
        RuntimeError: If a path exists but is neither a file nor a directory,
            or if no test file is found at all.
        Exception: Anything raised while importing a test file or instantiating
            one of its test classes is logged and re-raised.
    """
    all_files: List[Path] = []
    for p in paths:
        test_path = Path(p)
        logger.debug(f"Searching for test files in {test_path}")
        if test_path.is_dir():
            # Sorted so the run order does not depend on filesystem ordering.
            all_files.extend(sorted(test_path.glob("*.py")))
        elif test_path.is_file():
            if test_path.suffix == ".py":
                all_files.append(test_path)
        elif any(ch in str(test_path) for ch in "*?["):
            # Patterns must be expanded before any existence check: a pattern
            # such as "tests/*.py" never exists as a literal path.
            matches = sorted(
                Path(f) for f in glob.glob(str(test_path))
                if os.path.isfile(f) and f.endswith(".py")
            )
            if not matches:
                logger.error(f"No test files match {test_path}")
            all_files.extend(matches)
        elif not test_path.exists():
            logger.error(f"{test_path} does not exist")
        else:
            logger.error(f"Cannot find test files in {test_path}")
            raise RuntimeError(f"Cannot find test files in {test_path}")

    if not all_files:
        logger.error(f"No test file to import from {paths}")
        raise RuntimeError(f"No test files to import from {paths}")

    tests: List[BaseProxyTest] = []
    for test_file in all_files:
        try:
            spec = importlib.util.spec_from_file_location(
                f"test_{test_file.stem}",
                test_file,
            )
            if spec is None or spec.loader is None:
                raise ImportError(f"Cannot create an import spec for {test_file}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
            # Find all concrete test classes. Abstract intermediate bases are
            # skipped because instantiating them raises TypeError.
            file_tests = []
            for attr_name in dir(module):
                attr = getattr(module, attr_name)
                if (isinstance(attr, type)
                        and issubclass(attr, BaseProxyTest)
                        and attr is not BaseProxyTest
                        and not inspect.isabstract(attr)):
                    file_tests.append(attr())
        except Exception as e:
            logger.error(f"Failed to load test file {test_file}: {e}")
            raise
        logger.debug(f"Loaded tests from {test_file}: {[t.test_id for t in file_tests]}")
        tests.extend(file_tests)
    return tests
def setup_logging(level):
    """Configure root logging for command-line use.

    Args:
        level: Level for the root logger, e.g. ``logging.INFO``.

    Like ``logging.basicConfig`` itself, this does nothing when the root
    logger already has handlers.
    """
    line_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    logging.basicConfig(format=line_format, level=level)
# ASCII-art banner emitted when running with --log-level DEBUG.
_HOUND_ART = r'''
__
\ ______/ V`-,
} /~~
/_)^ --,r'
|b |b
'''


async def async_main(args):
    """Discover and run all tests, print the summary and return an exit code.

    Args:
        args: Parsed CLI namespace providing ``log_level``, ``parallel`` and ``paths``.

    Returns:
        0 when every test passed (or none were found), 1 when any test did not
        pass or an unexpected exception was raised.
    """
    logger.info("Starting httphound (async version)")

    # ASCII art only in debug mode
    if args.log_level == 'DEBUG':
        logger.debug(_HOUND_ART)

    try:
        # Discovery is synchronous (file system + imports)
        tests = discover_tests(args.paths)
        logger.info(f"Discovered {len(tests)} tests in {args.paths}")
        if not tests:
            logger.info("No tests found")
            return 0

        runner = TestRunner()
        if not args.parallel:
            logger.info("Running tests sequentially")
            await runner.run_all_tests(tests)
        else:
            logger.info(f"Running tests in parallel (max {args.parallel} concurrent)")
            await runner.run_all_tests_parallel(tests, max_concurrent=args.parallel)

        runner.print_summary()

        failures = [r for r in runner.results if r.status != TestStatus.PASS]
        if not failures:
            logger.info("All tests passed!")
            return 0
        logger.error(f"{len(failures)} out of {len(runner.results)} tests failed")
        return 1
    except Exception as e:
        logger.error(f"Test execution failed: {e}")
        logger.exception("Full traceback:")
        return 1
def main():
    """Synchronous entry point that launches async main.

    Parses the command line, configures logging, runs :func:`async_main` in a
    fresh event loop and exits the process with its return code (130 when
    interrupted with Ctrl-C, 2 for invalid arguments).
    """
    parser = argparse.ArgumentParser(description="Reverse proxy test tool (async)")
    parser.add_argument(
        '--log-level', '-l',
        type=str.upper,  # accept the level name in any case
        choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
        default='INFO',
        help='Set logging level (Default: INFO)'
    )
    parser.add_argument(
        '--parallel', '-p',
        type=int,
        default=None,
        metavar='N',
        help='Run tests in parallel with max N concurrent tests'
    )
    parser.add_argument(
        'paths',
        nargs='+',
        help='Directories, file patterns, or individual Python files (.py)'
    )
    args = parser.parse_args()
    # A concurrency limit below 1 is meaningless: 0 would silently fall back to
    # sequential mode and a negative value makes asyncio.Semaphore raise.
    if args.parallel is not None and args.parallel < 1:
        parser.error("--parallel must be a positive integer")

    # str.upper already normalised the name, so it maps directly to a level.
    log_level = getattr(logging, args.log_level)
    setup_logging(log_level)

    # Run async main
    try:
        exit_code = asyncio.run(async_main(args))
        sys.exit(exit_code)
    except KeyboardInterrupt:
        logger.info("\nInterrupted by user")
        sys.exit(130)
def start():
    """Console script entry point; delegates to main(), which exits the process."""
    main()


# Allows `python -m <package>.<module>`; importing the module has no side effects.
if __name__ == "__main__":
    start()
|