|
|
@@ -1,7 +1,6 @@
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
import argparse
|
|
|
-
|
|
|
import glob
|
|
|
import importlib.util
|
|
|
import logging
|
|
|
@@ -44,49 +43,29 @@ class TestResult:
|
|
|
|
|
|
|
|
|
class BaseProxyTest(ABC):
|
|
|
- """Base class for all HTTP tests
|
|
|
-
|
|
|
- All test classes must inherit from this class and eventually
|
|
|
- override the required parameters.
|
|
|
-
|
|
|
+    """Base class for all HTTP tests (async version).
|
|
|
+
|
|
|
+ All test classes must inherit from this class and implement the async run_test() method.
|
|
|
+
|
|
|
Attributes:
|
|
|
description (str): Description for the test that will be printed in overall summary
|
|
|
url (str): URL to run the test against (default: 'http://localhost:4242/')
|
|
|
method (str): HTTP method to use (default: GET)
|
|
|
headers (dict): A dictionary of headers that will be sent by the client
|
|
|
- while performing the request.
|
|
|
body (str): A string that will be sent as request body (if applicable)
|
|
|
- expected_status (int): The HTTP status that will be compared with the received
|
|
|
- one. The test fails if doesnt' match (default: 200)
|
|
|
- expected_headers (dict): A dictionary (header_name: header_value) of headers
|
|
|
- that needs to be present in the response. The test fails if the header_name
|
|
|
- is not present in the response or if header_value doesn't match.
|
|
|
- expected_body_pattern (str): A regex used to match the body content. The test
|
|
|
- fails if the match isn't found.
|
|
|
- expected_header_patterns (dict): a dictionary (header_name: header_value_pattern)
|
|
|
- that will be checked against received headers. If header_name is not present
|
|
|
- in the response headers or header_value_pattern doesn't match the header
|
|
|
- content, the test will fail.
|
|
|
- forbidden_client_headers (list): A list of header names that needs to be *absent*
|
|
|
- from the header list sent to the client. If any of the header name is found in
|
|
|
- the response the test will fail.
|
|
|
- expected_backend_headers (list): A list of headers that needs to be present in the
|
|
|
- request sent to the backend. This is used to check if the reverse proxy deletes
|
|
|
- one or more headers before forwarding the request to the backend. The test will
|
|
|
- fail if one or more header are not present.
|
|
|
- forbidden_backend_headers (list): Similar to the above parameter but the test will
|
|
|
- fail instead if any of the header is found among the headers received by the
|
|
|
- backend.
|
|
|
- backend_header_patterns (dict): A dictionary (header_name: header_value_pattern)
|
|
|
- that will be checked against the headers received by the backend. If the header
|
|
|
- name is not present in the backend headers or header_value_pattern doesn't match
|
|
|
- the header value received by the backend, the test will fail.
|
|
|
+ expected_status (int): The HTTP status that will be compared with the received one
|
|
|
+ expected_headers (dict): Headers that must be present in the response
|
|
|
+ expected_body_pattern (str): A regex used to match the body content
|
|
|
+ expected_header_patterns (dict): Header name to regex pattern mapping
|
|
|
+ forbidden_client_headers (list): Headers that must NOT be in the response
|
|
|
+ expected_backend_headers (list): Headers that must be present in backend request
|
|
|
+ forbidden_backend_headers (list): Headers that must NOT be in backend request
|
|
|
+ backend_header_patterns (dict): Backend header name to regex pattern mapping
|
|
|
"""
|
|
|
|
|
|
def __init__(self):
|
|
|
self.test_id = self.__class__.__name__
|
|
|
- self.description = getattr(self, 'description',
|
|
|
- 'No description provided')
|
|
|
+ self.description = getattr(self, 'description', 'No description provided')
|
|
|
self.url = getattr(self, 'url', 'http://localhost:4242/')
|
|
|
self.method = getattr(self, 'method', 'GET')
|
|
|
self.headers = getattr(self, 'headers', {})
|
|
|
@@ -113,51 +92,68 @@ class BaseProxyTest(ABC):
|
|
|
self.response_headers = {}
|
|
|
self.response_body = ""
|
|
|
|
|
|
- def setup(self):
|
|
|
+ async def setup(self):
|
|
|
+ """Async setup - start backend and proxy"""
|
|
|
logger.debug("Setting up test environment")
|
|
|
+
|
|
|
# Start backend
|
|
|
logger.debug(f"Instantiating backend with configuration: {self.backend_config}")
|
|
|
self.backend = DummyBackend(self.backend_config)
|
|
|
logger.debug("Starting backend")
|
|
|
- self.backend.start()
|
|
|
- time.sleep(999)
|
|
|
+ await self.backend.start()
|
|
|
+
|
|
|
+ # Wait for backend to be ready
|
|
|
+ await self.backend.wait_until_ready()
|
|
|
|
|
|
# Start proxy
|
|
|
logger.debug(f"Instantiating reverse proxy with configuration: {self.proxy_config}")
|
|
|
self.proxy = ProxyManager(self.proxy_config)
|
|
|
logger.debug(f"Starting reverse proxy with configuration: {self.backend_config}")
|
|
|
self.proxy.start(self.backend_config)
|
|
|
- logger.debug("Sleeping for 0.1s before proceeding")
|
|
|
- time.sleep(0.1)
|
|
|
-
|
|
|
- def teardown(self):
|
|
|
+
|
|
|
+ # Wait for proxy to be ready
|
|
|
+ logger.debug("Waiting for proxy to be ready")
|
|
|
+ await self.proxy.wait_until_ready(self.url)
|
|
|
+ logger.debug("Setup complete")
|
|
|
+
|
|
|
+ async def teardown(self):
|
|
|
+ """Async teardown - stop backend and proxy"""
|
|
|
logger.debug("Cleaning up test environment")
|
|
|
if self.backend:
|
|
|
logger.debug("Stopping backend")
|
|
|
- asyncio.run(self.backend.stop())
|
|
|
+ await self.backend.stop()
|
|
|
if self.proxy:
|
|
|
logger.debug("Stopping reverse proxy")
|
|
|
self.proxy.stop()
|
|
|
|
|
|
- def make_request(self):
|
|
|
- """Make HTTP request through the proxy"""
|
|
|
-
|
|
|
- request = httpx.Request(
|
|
|
- method=self.method,
|
|
|
- url=self.url,
|
|
|
- headers=self.headers,
|
|
|
- content=self.body)
|
|
|
- logger.debug(f"Performing HTTP request: {request}")
|
|
|
- with httpx.Client(http2=True) as client:
|
|
|
- response = client.send(request=request)
|
|
|
- logger.debug(f"Response: {response}")
|
|
|
+ async def make_request(self):
|
|
|
+ """Make async HTTP request through the proxy"""
|
|
|
+ logger.debug(f"Making {self.method} request to {self.url}")
|
|
|
+ logger.debug(f"Request headers: {self.headers}")
|
|
|
+ if self.body:
|
|
|
+ logger.debug(f"Request body: {self.body}")
|
|
|
+
|
|
|
+ async with httpx.AsyncClient(http2=True, timeout=10.0) as client:
|
|
|
+ response = await client.request(
|
|
|
+ method=self.method,
|
|
|
+ url=self.url,
|
|
|
+ headers=self.headers,
|
|
|
+ content=self.body
|
|
|
+ )
|
|
|
+
|
|
|
+ logger.debug(f"Response status: {response.status_code}")
|
|
|
+ logger.debug(f"Response headers: {dict(response.headers)}")
|
|
|
+
|
|
|
self.response = response
|
|
|
-
|
|
|
- return response
|
|
|
+ self.response_headers = dict(response.headers)
|
|
|
+ self.response_body = response.text
|
|
|
+
|
|
|
+ return response
|
|
|
|
|
|
def validate_response(self) -> Tuple[TestStatus, Optional[List[str]]]:
|
|
|
"""Validate the HTTP response"""
|
|
|
validation_errors = []
|
|
|
+
|
|
|
# Check status code
|
|
|
if self.response.status_code != self.expected_status:
|
|
|
logger.info(f"Expected status {self.expected_status}, got {self.response.status_code}")
|
|
|
@@ -182,8 +178,7 @@ class BaseProxyTest(ABC):
|
|
|
if header not in self.response_headers:
|
|
|
logger.info(f"Header '{header}' for pattern matching not found")
|
|
|
validation_errors.append(f"Header '{header}' for pattern matching not found")
|
|
|
-
|
|
|
- if not re.match(pattern, self.response_headers[header]):
|
|
|
+ elif not re.match(pattern, self.response_headers[header]):
|
|
|
logger.info(f"Header '{header}' value doesn't match pattern '{pattern}'")
|
|
|
validation_errors.append(
|
|
|
f"Header '{header}' value doesn't match pattern '{pattern}'")
|
|
|
@@ -218,8 +213,7 @@ class BaseProxyTest(ABC):
|
|
|
if header not in self.backend.received_headers:
|
|
|
logger.info(f"Expected backend header '{header}' not found")
|
|
|
validation_errors.append(f"Expected backend header '{header}' not found")
|
|
|
-
|
|
|
- if not re.match(pattern, self.backend.received_headers[header]):
|
|
|
+ elif not re.match(pattern, self.backend.received_headers[header]):
|
|
|
logger.info(f"Backend header '{header}' doesn't match pattern '{pattern}'")
|
|
|
validation_errors.append(
|
|
|
f"Backend header '{header}' doesn't match pattern '{pattern}'")
|
|
|
@@ -229,19 +223,22 @@ class BaseProxyTest(ABC):
|
|
|
return TestStatus.FAIL, validation_errors
|
|
|
return TestStatus.PASS, None
|
|
|
|
|
|
- #@abstractmethod
|
|
|
- def run_test(self) -> bool:
|
|
|
- """Run the actual test logic. Must be implemented by subclasses"""
|
|
|
- self.make_request()
|
|
|
+ @abstractmethod
|
|
|
+ async def run_test(self) -> bool:
|
|
|
+ """Run the actual test logic. Must be implemented by subclasses
|
|
|
+
|
|
|
+ This method MUST be async and should use await for any async operations.
|
|
|
+ """
|
|
|
+ await self.make_request()
|
|
|
return True
|
|
|
|
|
|
- def execute(self) -> TestResult:
|
|
|
- """Execute the complete test"""
|
|
|
+ async def execute(self) -> TestResult:
|
|
|
+ """Execute the complete test asynchronously"""
|
|
|
start_time = time.time()
|
|
|
|
|
|
try:
|
|
|
- self.setup()
|
|
|
- success = self.run_test()
|
|
|
+ await self.setup()
|
|
|
+ success = await self.run_test()
|
|
|
|
|
|
if success:
|
|
|
logger.debug("Validating response...")
|
|
|
@@ -255,6 +252,14 @@ class BaseProxyTest(ABC):
|
|
|
time.time() - start_time,
|
|
|
errors
|
|
|
)
|
|
|
+ else:
|
|
|
+ return TestResult(
|
|
|
+ self.test_id,
|
|
|
+ self.description,
|
|
|
+ TestStatus.PASS,
|
|
|
+ time.time() - start_time,
|
|
|
+ None
|
|
|
+ )
|
|
|
else:
|
|
|
return TestResult(
|
|
|
self.test_id,
|
|
|
@@ -274,55 +279,104 @@ class BaseProxyTest(ABC):
|
|
|
[str(e)]
|
|
|
)
|
|
|
finally:
|
|
|
- self.teardown()
|
|
|
+ await self.teardown()
|
|
|
|
|
|
|
|
|
class TestRunner:
|
|
|
- """Main test runner"""
|
|
|
+ """Async test runner"""
|
|
|
|
|
|
def __init__(self):
|
|
|
self.results = []
|
|
|
|
|
|
- def run_all_tests(self, tests: List[BaseProxyTest]) -> TestResult:
|
|
|
- """Run all tests"""
|
|
|
+ async def run_all_tests(self, tests: List[BaseProxyTest]):
|
|
|
+ """Run all tests sequentially (async)"""
|
|
|
logger.debug("Running all tests")
|
|
|
for test in tests:
|
|
|
logger.info(colored(f"Running test: {test.test_id}", "blue"))
|
|
|
- result = test.execute()
|
|
|
+ result = await test.execute()
|
|
|
logger.debug(f"Test result for {test.test_id}: {result}")
|
|
|
self.results.append(result)
|
|
|
- return result
|
|
|
|
|
|
- def print_summary(self):
|
|
|
+ async def run_all_tests_parallel(
|
|
|
+ self,
|
|
|
+ tests: List[BaseProxyTest],
|
|
|
+ max_concurrent: int = 3
|
|
|
+ ):
|
|
|
+        """Run tests in parallel with a concurrency limit (NOTE(review): tests sharing the default localhost:4242 port may conflict — confirm per-test ports before enabling)
|
|
|
+
|
|
|
+ Args:
|
|
|
+ tests: List of tests to run
|
|
|
+ max_concurrent: Maximum number of tests to run concurrently
|
|
|
+ """
|
|
|
+ logger.debug(f"Running tests in parallel (max {max_concurrent} concurrent)")
|
|
|
+
|
|
|
+ semaphore = asyncio.Semaphore(max_concurrent)
|
|
|
+
|
|
|
+ async def run_with_semaphore(test):
|
|
|
+ async with semaphore:
|
|
|
+ logger.info(colored(f"Running test: {test.test_id}", "blue"))
|
|
|
+ result = await test.execute()
|
|
|
+ logger.debug(f"Test result for {test.test_id}: {result}")
|
|
|
+ return result
|
|
|
+
|
|
|
+ # Run all tests concurrently with limit
|
|
|
+ results = await asyncio.gather(
|
|
|
+ *[run_with_semaphore(test) for test in tests],
|
|
|
+ return_exceptions=False
|
|
|
+ )
|
|
|
+
|
|
|
+ self.results.extend(results)
|
|
|
|
|
|
+ def print_summary(self):
|
|
|
+ """Print test summary"""
|
|
|
total_tests = len(self.results)
|
|
|
passed_tests = sum(1 for r in self.results if r.status == TestStatus.PASS)
|
|
|
failed_tests = total_tests - passed_tests
|
|
|
total_time = sum(r.duration for r in self.results)
|
|
|
|
|
|
- print("TEST EXECUTION SUMMARY")
|
|
|
- print()
|
|
|
+ print("\nTEST EXECUTION SUMMARY")
|
|
|
+ print("=" * 80)
|
|
|
+
|
|
|
+ # Print failures first
|
|
|
for result in self.results:
|
|
|
if result.status != TestStatus.PASS:
|
|
|
- print(f"Test {colored(result.test_id, 'blue')} failed with following errors:")
|
|
|
- for msg in result.error_messages:
|
|
|
- print(colored(f"\t{msg}", result.status.value))
|
|
|
- print()
|
|
|
+ print(f"\nTest {colored(result.test_id, 'blue')} - {result.description}")
|
|
|
+ print(f"Status: {colored(result.status.name, result.status.value)}")
|
|
|
+ if result.error_messages:
|
|
|
+ print("Errors:")
|
|
|
+ for msg in result.error_messages:
|
|
|
+ print(colored(f" {msg}", result.status.value))
|
|
|
|
|
|
- result_table = [["Status", "Name", "Description", "Duration"]]
|
|
|
+ # Print results table
|
|
|
+ print("\n" + "=" * 80)
|
|
|
+ result_table = [["Status", "Name", "Description", "Duration (s)"]]
|
|
|
|
|
|
for result in self.results:
|
|
|
status_color = colored(result.status.name, result.status.value)
|
|
|
-
|
|
|
- result_table.append([status_color,
|
|
|
- result.test_id,
|
|
|
- result.description,
|
|
|
- result.duration])
|
|
|
+ result_table.append([
|
|
|
+ status_color,
|
|
|
+ result.test_id,
|
|
|
+ result.description,
|
|
|
+ f"{result.duration:.3f}"
|
|
|
+ ])
|
|
|
|
|
|
print(tabulate(result_table, headers="firstrow"))
|
|
|
- print()
|
|
|
- print(tabulate([["Total tests", "Passed", "Failed", "Total duration"],
|
|
|
- [total_tests, passed_tests, failed_tests, f"{total_time:.2f}"]]))
|
|
|
+
|
|
|
+ # Print summary stats
|
|
|
+ print("\n" + "=" * 80)
|
|
|
+ summary_table = [[
|
|
|
+ "Total Tests",
|
|
|
+ "Passed",
|
|
|
+ "Failed",
|
|
|
+ "Total Duration (s)"
|
|
|
+ ], [
|
|
|
+ total_tests,
|
|
|
+ colored(str(passed_tests), "green"),
|
|
|
+ colored(str(failed_tests), "red" if failed_tests > 0 else "green"),
|
|
|
+ f"{total_time:.3f}"
|
|
|
+ ]]
|
|
|
+ print(tabulate(summary_table, headers="firstrow"))
|
|
|
+ print("=" * 80 + "\n")
|
|
|
|
|
|
|
|
|
def discover_tests(paths: List[str]) -> List[BaseProxyTest]:
|
|
|
@@ -336,7 +390,7 @@ def discover_tests(paths: List[str]) -> List[BaseProxyTest]:
|
|
|
test_path = Path(p)
|
|
|
|
|
|
if not test_path.exists():
|
|
|
- logger.error(f"{test_path} does not exists")
|
|
|
+ logger.error(f"{test_path} does not exist")
|
|
|
continue
|
|
|
|
|
|
logger.debug(f"Searching for test files in {test_path}")
|
|
|
@@ -345,18 +399,18 @@ def discover_tests(paths: List[str]) -> List[BaseProxyTest]:
|
|
|
elif test_path.is_file():
|
|
|
if test_path.suffix == ".py":
|
|
|
all_files.append(test_path)
|
|
|
- elif '*' in test_path or '?' in test_path or '[' in test_path:
|
|
|
- globbed_files = glob.glob(test_path)
|
|
|
+ elif '*' in str(test_path) or '?' in str(test_path) or '[' in str(test_path):
|
|
|
+ globbed_files = glob.glob(str(test_path))
|
|
|
for f in globbed_files:
|
|
|
if os.path.isfile(f) and f.endswith(".py"):
|
|
|
- all_files.append(f)
|
|
|
+ all_files.append(Path(f))
|
|
|
else:
|
|
|
logger.error(f"Cannot find test files in {test_path}")
|
|
|
raise RuntimeError(f"Cannot find test files in {test_path}")
|
|
|
|
|
|
if not all_files:
|
|
|
logger.error(f"No test file to import from {paths}")
|
|
|
- raise RuntimeError(f"Not test files to import from {paths}")
|
|
|
+ raise RuntimeError(f"No test files to import from {paths}")
|
|
|
|
|
|
for test_file in all_files:
|
|
|
try:
|
|
|
@@ -372,11 +426,12 @@ def discover_tests(paths: List[str]) -> List[BaseProxyTest]:
|
|
|
attr = getattr(module, attr_name)
|
|
|
if (isinstance(attr, type) and
|
|
|
issubclass(attr, BaseProxyTest) and
|
|
|
- attr != BaseProxyTest):
|
|
|
+ attr != BaseProxyTest):
|
|
|
tests.append(attr())
|
|
|
- logger.debug(f"Test classes: {tests}")
|
|
|
+ logger.debug(f"Loaded tests from {test_file}: {[t.test_id for t in tests]}")
|
|
|
except Exception as e:
|
|
|
logger.error(f"Failed to load test file {test_file}: {e}")
|
|
|
+ raise
|
|
|
|
|
|
return tests
|
|
|
|
|
|
@@ -388,59 +443,95 @@ def setup_logging(level):
|
|
|
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
|
|
|
)
|
|
|
|
|
|
-def main():
|
|
|
- parser = argparse.ArgumentParser(description="Reverse roxy test tool")
|
|
|
- parser.add_argument(
|
|
|
- '--log-level', '-l',
|
|
|
- type=str.upper,
|
|
|
- choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
|
|
|
- default='INFO', help='Set logging level (Default: INFO)')
|
|
|
- parser.add_argument(
|
|
|
- 'paths',
|
|
|
- nargs='+',
|
|
|
- help='Directories, file patterns, or individual Python files (.py)')
|
|
|
- args = parser.parse_args()
|
|
|
-
|
|
|
- log_level = getattr(logging, args.log_level.upper())
|
|
|
- setup_logging(log_level)
|
|
|
- logger = logging.getLogger("httphound")
|
|
|
-
|
|
|
- logger.info("Starting httphound")
|
|
|
- # a little Ascii art only in debug mode
|
|
|
|
|
|
- hound = '''
|
|
|
+async def async_main(args):
|
|
|
+ """Async main function - runs all tests"""
|
|
|
+ logger.info("Starting httphound (async version)")
|
|
|
+
|
|
|
+ # ASCII art only in debug mode
|
|
|
+ if args.log_level == 'DEBUG':
|
|
|
+ hound = '''
|
|
|
__
|
|
|
\ ______/ V`-,
|
|
|
} /~~
|
|
|
/_)^ --,r'
|
|
|
|b |b
|
|
|
'''
|
|
|
- logger.debug(hound)
|
|
|
+ logger.debug(hound)
|
|
|
|
|
|
try:
|
|
|
+ # Discover tests (sync operation)
|
|
|
tests = discover_tests(args.paths)
|
|
|
logger.info(f"Discovered {len(tests)} tests in {args.paths}")
|
|
|
|
|
|
if not tests:
|
|
|
logger.info("No tests found")
|
|
|
- sys.exit(0)
|
|
|
+ return 0
|
|
|
|
|
|
# Run tests
|
|
|
runner = TestRunner()
|
|
|
- runner.run_all_tests(tests)
|
|
|
+
|
|
|
+ if args.parallel:
|
|
|
+ logger.info(f"Running tests in parallel (max {args.parallel} concurrent)")
|
|
|
+ await runner.run_all_tests_parallel(tests, max_concurrent=args.parallel)
|
|
|
+ else:
|
|
|
+ logger.info("Running tests sequentially")
|
|
|
+ await runner.run_all_tests(tests)
|
|
|
|
|
|
# Print summary
|
|
|
runner.print_summary()
|
|
|
|
|
|
# Exit with error code if any tests failed
|
|
|
- failed_count = sum(1 for r in runner.results if not r.status == TestStatus.PASS)
|
|
|
+ failed_count = sum(1 for r in runner.results if r.status != TestStatus.PASS)
|
|
|
if failed_count > 0:
|
|
|
- raise RuntimeError(f"{failed_count} over {len(runner.results)} test failed")
|
|
|
+ logger.error(f"{failed_count} out of {len(runner.results)} tests failed")
|
|
|
+ return 1
|
|
|
+
|
|
|
+ logger.info("All tests passed!")
|
|
|
+ return 0
|
|
|
|
|
|
except Exception as e:
|
|
|
logger.error(f"Test execution failed: {e}")
|
|
|
- sys.exit(1)
|
|
|
+ logger.exception("Full traceback:")
|
|
|
+ return 1
|
|
|
+
|
|
|
+
|
|
|
+def main():
|
|
|
+ """Synchronous entry point that launches async main"""
|
|
|
+ parser = argparse.ArgumentParser(description="Reverse proxy test tool (async)")
|
|
|
+ parser.add_argument(
|
|
|
+ '--log-level', '-l',
|
|
|
+ type=str.upper,
|
|
|
+ choices=['DEBUG', 'INFO', 'WARNING', 'ERROR'],
|
|
|
+ default='INFO',
|
|
|
+ help='Set logging level (Default: INFO)'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ '--parallel', '-p',
|
|
|
+ type=int,
|
|
|
+ default=None,
|
|
|
+ metavar='N',
|
|
|
+ help='Run tests in parallel with max N concurrent tests'
|
|
|
+ )
|
|
|
+ parser.add_argument(
|
|
|
+ 'paths',
|
|
|
+ nargs='+',
|
|
|
+ help='Directories, file patterns, or individual Python files (.py)'
|
|
|
+ )
|
|
|
+ args = parser.parse_args()
|
|
|
+
|
|
|
+ log_level = getattr(logging, args.log_level.upper())
|
|
|
+ setup_logging(log_level)
|
|
|
+
|
|
|
+ # Run async main
|
|
|
+ try:
|
|
|
+ exit_code = asyncio.run(async_main(args))
|
|
|
+ sys.exit(exit_code)
|
|
|
+ except KeyboardInterrupt:
|
|
|
+ logger.info("\nInterrupted by user")
|
|
|
+ sys.exit(130)
|
|
|
|
|
|
|
|
|
def start():
|
|
|
+ """Console script entry point"""
|
|
|
main()
|