Browse Source

add ability to use production configuration

And replace relevant bits with regex instead of using templates
Fabrizio Furnari 1 month ago
parent
commit
fab6c20162
4 changed files with 540 additions and 7 deletions
  1. 68 0
      example_config/haproxy.cfg
  2. 80 0
      example_tests/10-production_config.py
  3. 6 0
      httphound/main.py
  4. 386 7
      httphound/proxy.py

+ 68 - 0
example_config/haproxy.cfg

@@ -0,0 +1,68 @@
+# Production HAProxy Configuration Example
+# This can be used with HttpHound's production mode
+
+global
+    daemon
+    maxconn 4096
+    log /dev/log local0
+    log /dev/log local1 notice
+    # chroot /var/lib/haproxy
+    # stats socket /run/haproxy/admin.sock mode 660 level admin
+    stats timeout 30s
+
+defaults
+    log     global
+    mode    http
+    option  httplog
+    option  dontlognull
+    timeout connect 5000
+    timeout client  50000
+    timeout server  50000
+    
+# Frontend - receives client requests
+frontend web_frontend
+    bind *:80
+    #bind *:443 ssl crt /etc/ssl/certs/haproxy.pem
+    
+    # ACLs
+    acl is_api path_beg /api
+    acl is_static path_beg /static /images /css /js
+    
+    # Routing
+    use_backend api_backend if is_api
+    use_backend static_backend if is_static
+    default_backend app_backend
+
+# Backend - application servers
+backend app_backend
+    balance roundrobin
+    option httpchk GET /health
+    http-check expect status 200
+    
+    # These will be patched by HttpHound in production mode
+    server app1 10.0.1.10:8080
+    server app2 10.0.1.11:8080
+    server app3 10.0.1.12:8080
+
+# API backend
+backend api_backend
+    balance leastconn
+    option httpchk GET /api/health
+    
+    server api1 10.0.2.10:9000 check
+    server api2 10.0.2.11:9000 check
+
+# Static content backend
+backend static_backend
+    balance roundrobin
+    
+    server static1 10.0.3.10:80 check
+    server static2 10.0.3.11:80 check
+
+# Stats page
+listen stats
+    bind *:8404
+    stats enable
+    stats uri /stats
+    stats refresh 30s
+    stats admin if TRUE

+ 80 - 0
example_tests/10-production_config.py

@@ -0,0 +1,80 @@
+"""
+Example test using production (static) HAProxy configuration
+"""
+
+from pathlib import Path
+from httphound.main import BaseProxyTest, BackendConfig, ProxyConfig
+
+class ProductionConfigBasicTest(BaseProxyTest):
+    """Test using a production haproxy config with automatic
+    backend patching
+    """
+    def __init__(self):
+        super().__init__()
+        self.description = "Production config with backend patching"
+
+        # backend will run on port 9999
+        self.backend_config = BackendConfig(
+            host='127.0.0.1',
+            port=9999,
+            response_status=200,
+            response_body='Production test response',
+            response_headers={"X-Test-Mode":"production"}
+        )
+
+        # configure haproxy to use production config
+        self.proxy_config = ProxyConfig(
+            binary_path=Path.home() / "bin/haproxy",
+            config_mode="production",
+            production_config_file_path="example_config/haproxy.cfg",
+            backend_name_to_patch="app_backend",
+            bind_address_override="*:4242",
+            skip_backend_injection=False,
+        )
+
+        self.url = "http://127.0.0.1:4242"
+        self.expected_status = 200
+        self.expected_headers = {
+            "x-test-mode": "production",
+        }
+
+    async def run_test(self):
+        """Run the test"""
+        await self.make_request()
+
+        assert self.backend.request_count == 1
+        assert "Host" in self.backend.received_headers
+
+        return True
+
+class ProductionConfigWithExtraArgsTest(BaseProxyTest):
+    """Test with custom HAProxy command-line arguments"""
+    
+    def __init__(self):
+        super().__init__()
+        self.description = "Production config with extra HAProxy args"
+        
+        self.backend_config = BackendConfig(
+            host="127.0.0.1",
+            port=9999,
+        )
+        
+        self.proxy_config = ProxyConfig(
+            binary_path=Path.home() / "bin/haproxy",
+            config_mode="production",
+            production_config_file_path="example_config/haproxy.cfg",
+            backend_name_to_patch="app_backend",
+            bind_address_override="*:4242",
+            
+            # Add custom HAProxy arguments
+            extra_args=[
+                "-dM",
+            ]
+        )
+        
+        self.url = "http://127.0.0.1:4242/"
+        self.expected_status = 200
+    
+    async def run_test(self):
+        await self.make_request()
+        return True

+ 6 - 0
httphound/main.py

@@ -114,6 +114,12 @@ class BaseProxyTest(ABC):
         # Wait for proxy to be ready
         # Wait for proxy to be ready
         logger.debug("Waiting for proxy to be ready")
         logger.debug("Waiting for proxy to be ready")
         await self.proxy.wait_until_ready(self.url)
         await self.proxy.wait_until_ready(self.url)
+
+        # Reset backend counters after health check
+        self.backend.request_count = 0
+        self.backend.received_headers = {}
+        self.backend.received_body = ""
+        
         logger.debug("Setup complete")
         logger.debug("Setup complete")
 
 
     async def teardown(self):
     async def teardown(self):

+ 386 - 7
httphound/proxy.py

@@ -3,7 +3,8 @@ import os
 import subprocess
 import subprocess
 import asyncio
 import asyncio
 
 
-from typing import Dict, Any
+from pathlib import Path
+from typing import Dict, Any, Optional, List
 from dataclasses import dataclass, field
 from dataclasses import dataclass, field
 from jinja2 import Template
 from jinja2 import Template
 
 
@@ -14,14 +15,34 @@ logger = logging.getLogger(__name__)
 
 
 @dataclass
 @dataclass
 class ProxyConfig:
 class ProxyConfig:
-    """Reverse proxy configuration"""
+    """Reverse proxy configuration
+
+    Supports two modes:
+    - template: Use jinja2 template with variable substitution (default)
+    - production: Use existing HAProxy config file(s) as-is or with minimal
+                  patching (with regexes)
+    """
     binary_path: str = "/usr/sbin/haproxy"
     binary_path: str = "/usr/sbin/haproxy"
-    template_path: str = "haproxy.cfg.tpl"
     working_dir: str = "/tmp/httphound"
     working_dir: str = "/tmp/httphound"
+    
+    config_mode: str = "template" # "template" or "production"
+
+    # template mode
+    template_path: Optional[str] = "haproxy.cfg.tpl"
     listen_addr: str = "*"
     listen_addr: str = "*"
     listen_port: int = 4242
     listen_port: int = 4242
     template_vars: Dict[str, Any] = field(default_factory=dict)
     template_vars: Dict[str, Any] = field(default_factory=dict)
 
 
+    # production mode
+    production_config_file_path: Optional[str] = None # main config file path
+    production_config_base_dir: Optional[str] = None # directory containing config files
+    backend_name_to_patch: str = "default_backend" # which backend section to patch
+    bind_address_override: Optional[str] = None # Override bind address (ex. "*:4242")
+    skip_backend_injection: bool = False # if True, don't patch backend servers
+
+    # common settings
+    extra_args: List[str] = field(default_factory=list) # additional HAProxy CLI arguments
+
 
 
 class ProxyManager:
 class ProxyManager:
     """Manages reverse proxy"""
     """Manages reverse proxy"""
@@ -54,8 +75,18 @@ class ProxyManager:
         return rendered_template
         return rendered_template
 
 
     def start(self, backend_config: BackendConfig):
     def start(self, backend_config: BackendConfig):
-        """Start the reverse proxy"""
+        """Start the reverse proxy"""        
+        if self.config.config_mode == "template":
+            self._start_template_mode(backend_config)
+        elif self.config.config_mode == "production":
+            self._start_production_mode(backend_config)
+        else:
+            raise ValueError(
+                f"Unknown config_mode: {self.config.config_mode}. Must be 'template' or 'production'"
+            )
 
 
+    def _start_template_mode(self, backend_config):
+        logger.debug("Starting proxy in template mode")
         # Create working directory
         # Create working directory
         logger.debug(
         logger.debug(
             f"Creating reverse proxy working dir: {self.config.working_dir}")
             f"Creating reverse proxy working dir: {self.config.working_dir}")
@@ -73,15 +104,363 @@ class ProxyManager:
         with open(self.config_file, 'w', encoding="utf-8") as f:
         with open(self.config_file, 'w', encoding="utf-8") as f:
             f.write(config_content)
             f.write(config_content)
 
 
-        # Start proxy process
+        # Validate configuration
+        self._validate_config(self.config_file)
+
+        # start haproxy
+        self._start_haproxy_process()
+
+    def _start_production_mode(self, backend_config: BackendConfig):
+        """Start proxy using production configuration
+        """
+
+        logger.debug("Starting proxy in configuration mode")
+
+        if not self.config.production_config_file_path:
+            raise ValueError(
+                "production_config_file_path must be set when using config_mode='production'"
+            )
+
+        # Create working directory
+        os.makedirs(self.config.working_dir, exist_ok=True)
+
+        # Determine what to copy
+        config_source = Path(self.config.production_config_file_path)
+
+        if not config_source.exists():
+            raise FileNotFoundError(
+                f"Production config not found: {self.config.production_config_file_path}"
+            )
+
+        # If base_dir specified, copy entire directory tree, otherwise
+        # just copy the single file
+        if self.config.production_config_base_dir:
+            self._copy_config_directory(backend_config)
+        else:
+            self._copy_single_config(backend_config)
+
+        # Validate configuration
+        self._validate_config(self.config_file)
+
+        # Start haproxy
+        self._start_haproxy_process()
+
+
+    def _copy_single_config(self, backend_config: BackendConfig):
+        """Copy and optionally patch a single production config file
+
+        Args:
+            backend_config: Backend configuration for patching
+        """
+
+        config_source = Path(self.config.production_config_file_path)
+
+        logger.debug(f"Copying production config from {config_source}")
+
+        # Read original config
+        with open(config_source, 'r', encoding='utf-8') as f:
+            config_content = f.read()
+
+        # apply patches if needed
+        if not self.config.skip_backend_injection:
+            logger.debug("Patching backend server addresses")
+            config_content = self._patch_backend_servers(config_content, backend_config)
+
+        if self.config.bind_address_override:
+            logger.debug(f"Overriding bind addresses to {self.config.bind_address_override}")
+            config_content = self._patch_bind_addresses(config_content)
+
+        # write to working directory
+        self.config_file = os.path.join(self.config.working_dir, "haproxy.cfg")
+        logger.debug(f"Writing patched config to {self.config_file}")
+
+        with open(self.config_file, 'w', encoding='utf-8') as f:
+            f.write(config_content)
+
+        logger.debug(f"Config content: \n\n{config_content}\n\n")
+
+    def _copy_config_directory(self, backend_config: BackendConfig):
+        """Copy entire production config directory, preserving structure
+
+        Args:
+            backend_config: backend configuration for patching
+        """
+        import shutil
+
+        base_dir = Path(self.config.production_config_base_dir)
+        config_file = Path(self.config.production_config_path)
+
+        if not base_dir.exists():
+            raise FileNotFounderror(
+                f"Production config base directory not found: {base_dir}"
+            )
+
+        logger.debug(f"Copying production config directory from {base_dir}")
+
+        # Create subdirectory in working dir to preserve structure
+        work_config_dir = Path(self.config.working_dir) / "haproxy_config"
+
+        # remove if exists then copy
+        if work_config_dir.exists():
+            shutil.rmtree(work_config_dir)
+
+        shutil.copytree(base_dir, work_config_dir, symlinks=False)
+        logger.debug(f"Copied config directory to {work_config_dir}")
+
+        # determine main config file in copied tree
+        relative_config = config_file.relative_to(base_dir)
+        self.config_file = str(work_config_dir / relative_config)
+
+        logger.debug(f"Main configuration file: {self.config_file}")
+
+        # apply patches to all .cfg files in the tree if needed
+        if not self.config.skip_backend_injection or self.config.bind_address_override:
+            self._patch_config_tree(work_config_dir, backend_config)
+
+    def _patch_config_tree(self, config_dir: Path, backend_config: BackendConfig):
+        """Recursively patch all .cfg files in directory tree
+
+        Args:
+            config_dir: root directory containing config files
+            backend_config: backend configuration for patching
+        """
+        import glob
+
+        # find all .cfg files recursively
+        cfg_files = glob.glob(str(config_dir / "**" / "*.cfg"), recursive=True)
+
+        logger.debug(f"found {len(cfg_files)} config files to patch")
+
+        for cfg_file in cfg_files:
+            logger.debug(f"patching {cfg_file}")
+
+            with open(cfg_file, 'r', encoding='utf-8') as f:
+                content = f.read()
+
+            if not self.config.skip_backend_injection:
+                content = self._patch_backend_servers(content, backend_config)
+
+            if self.config.bind_address_override:
+                content = self._patch_bind_addresses(content)
+
+            with open(cfg_file, 'w', encoding='utf-8') as f:
+                f.write(content)
+
+    def _patch_backend_servers(self, config_content: str, backend_config: BackendConfig):
+        """Patch backend server addresses using regex
+
+        Find 'server' lines in the specified backend section and replaces the address:port
+        with the test backend's address.
+
+        Args:
+            config_content: original haproxy config content
+            backend_config: test backend configuraiton
+
+        Returns:
+            patched config content
+        """
+        import re
+
+        backend_name = self.config.backend_name_to_patch
+        test_backend_addr = f"{backend_config.host}:{backend_config.port}"
+
+        logger.debug(f"patching backend '{backend_name}' to point to {test_backend_addr}")
+
+        # pattern to find backend section start
+        backend_pattern = rf'^(\s*backend\s+{re.escape(backend_name)}\s*)$'
+
+        # pattern to match server lines: server <name> <address>:<port> [options...]
+        # replace only <address>:<port> and keep everything else
+        server_pattern = r'^(\s*server\s+\S+\s+)(\S+:\d+)(\s+.*|)$'
+
+        lines = config_content.split('\n')
+        patched_lines = []
+        in_target_backend = False
+        patched_count = 0
+
+        for line in lines:
+            # check if entering target backedn section
+            if re.match(backend_pattern, line, re.IGNORECASE):
+                in_target_backend = True
+                logger.debug(f"found target backend section: {backend_name}")
+                patched_lines.append(line)
+                continue
+
+            # check if leaving backend section (new section starts)
+            if in_target_backend and re.match(r'^\s*(backend|frontend|listen|defaults|global)\s+', line, re.IGNORECASE):
+                in_target_backend = False
+                logger.debug(f"Left backend section, patched {patched_count} server(s)")
+
+            # if in target backend and line is a server directive
+            if in_target_backend:
+                match = re.match(server_pattern, line)
+                if match:
+                    # reconstruct the line with replaced address
+                    prefix = match.group(1)
+                    old_addr = match.group(2)
+                    suffix = match.group(3)
+                    patched_line = f"{prefix}{test_backend_addr}{suffix}"
+                    logger.debug(f"  patched: {line.strip()} -> {patched_line.strip()}")
+                    patched_lines.append(patched_line)
+                    patched_count += 1
+                else:
+                    # not a server line, keep it as-is
+                    patched_lines.append(line)
+            else:
+                # not in target backend
+                patched_lines.append(line)
+
+        if patched_count == 0:
+            logger.warning(
+                f"No server lines patched in backend '{backend_name}'. Check that backend_name_to_patch is correct"
+            )
+
+        return '\n'.join(patched_lines)
+
+    def _patch_bind_addresses(self, config_content: str) -> str:
+        """Override bind/listen addresses using regex
+    
+        Replaces address:port in 'bind' and 'listen' directives while
+        preserving all other options (ssl, crt, alpn, etc.)
+    
+        Args:
+            config_content: Original HAProxy config content
+        
+        Returns:
+            Patched config content
+        """
+        import re
+    
+        override_addr = self.config.bind_address_override
+        logger.debug(f"overriding all bind/listen addresses to {override_addr}")
+
+        # Pattern for bind directive: bind [<address>]:<port> [options...]
+        # replace [<address>]:<port> but keep options
+        #
+        # Examples:
+        #   bind :80
+        #   bind *:443 ssl crt /path/to/cert
+        #   bind 192.168.1.1:8080
+        #   bind [::]:80
+        #   bind /var/run/haproxy.sock
+        bind_pattern = r'^(\s*bind\s+)(?:\S+?)(?:\s+(.*))?$'
+    
+        # Pattern for 'listen' directive (has name before address)
+        listen_pattern = r'^(\s*listen\s+\S+\s+)(?:\S+?)(?:\s+(.*))?$'
+    
+        lines = config_content.split('\n')
+        patched_lines = []
+        bind_count = 0
+        listen_count = 0
+    
+        for line in lines:
+            # Check for bind directive
+            bind_match = re.match(bind_pattern, line, re.IGNORECASE)
+            if bind_match:
+                prefix = bind_match.group(1)
+                options = bind_match.group(2) or ''
+            
+                # Skip unix socket binds
+                if '/' in line and 'sock' in line.lower():
+                    patched_lines.append(line)
+                    continue
+            
+                patched_line = f"{prefix}{override_addr}"
+                if options:
+                    patched_line += f" {options}"
+            
+                logger.debug(f"  Patched bindings: {line.strip()} -> {patched_line.strip()}")
+                patched_lines.append(patched_line)
+                bind_count += 1
+                continue
+        
+            # Check for listen directive
+            listen_match = re.match(listen_pattern, line, re.IGNORECASE)
+            if listen_match:
+                prefix = listen_match.group(1)
+                options = listen_match.group(2) or ''
+            
+                patched_line = f"{prefix}{override_addr}"
+                if options:
+                    patched_line += f" {options}"
+            
+                logger.debug(f"  Patched listen: {line.strip()} -> {patched_line.strip()}")
+                patched_lines.append(patched_line)
+                listen_count += 1
+                continue
+        
+            # No match, keep original
+            patched_lines.append(line)
+    
+        logger.debug(f"Patched {bind_count} bind directive(s) and {listen_count} listen directive(s)")
+        return '\n'.join(patched_lines)
+
+    def _validate_config(self, config_path: str):
+        """Validate HAProxy configuration using haproxy -c
+
+        Args:
+            config_path: path to config file to validate
+
+        Raises:
+            RuntimeError: if configuration is invalid
+        """
+        logger.debug(f"validating haproxy configuration: {config_path}")
+
+        cmd = [
+            str(self.config.binary_path),
+            '-c',
+            '-f',
+            config_path
+        ]
+
+        logger.debug(f"running validation: {' '.join(cmd)}")
+
+        try:
+            result = subprocess.run(
+                cmd,
+                capture_output=True,
+                text=True,
+                timeout=5,
+            )
+
+            if result.returncode == 0:
+                logger.info("HAProxy configuration is valid")
+                if result.stdout:
+                    logger.debug(f"validation output: {result.stdout}")
+            else:
+                logger.error("HAProxy configuration validation failed")
+                logger.error(f"stdout: {result.stdout}")
+                logger.error(f"stderr: {result.stderr}")
+                raise RuntimeError(
+                    f"HAProxy configuration validaiton failed:\n{result.stderr}"
+                )
+        except subprocess.TimeoutExpired:
+            raise RuntimeError("HAProxy configuration validation timeout")
+        except FileNotFoundError:
+            raise RuntimeError(
+                f"HAProxy binary not found: {self.config.binary_path}"
+            )
+        
+    def _start_haproxy_process(self):
+        """Start HAProxy process with the prepared configuration
+
+        This is common to both template and production mode
+        """
+
+        # build commmand 
         cmd = [
         cmd = [
             str(self.config.binary_path),
             str(self.config.binary_path),
             '-V',
             '-V',
             '-db',
             '-db',
-            '-f', self.config_file,
+            '-f',
+            self.config_file,
         ]
         ]
-        logger.debug(f"Running proxy cmd: {cmd}")
 
 
+        if self.config.extra_args:
+            cmd.extend(self.config.extra_args)
+
+        logger.debug(f"Running proxy cmd: {' '.join(cmd)}")
+        
         try:
         try:
             self.process = subprocess.Popen(
             self.process = subprocess.Popen(
                 cmd,
                 cmd,