"""
Security middleware to handle malformed requests and protocol attacks.
"""
import logging
from django.http import HttpResponseBadRequest
from django.core.cache import cache

# Module-level logger, named after this module so handlers can target it.
logger = logging.getLogger(name=__name__)


class SecurityFilterMiddleware:
    """
    Middleware to detect and block malformed requests, protocol confusion attacks,
    and other low-level security threats before they reach the application.

    An offending request receives a plain-text 400 response. It is also logged at
    WARNING level, and a short-lived ``security_block_<ip>`` cache entry is written.

    NOTE(review): this module only *writes* the ``security_block_<ip>`` key and never
    reads it. Enforcement of the "block" presumably lives in another component — confirm.
    """

    def __init__(self, get_response):
        self.get_response = get_response

        # Suspicious request patterns that should be blocked immediately.
        # They are matched against both the percent-encoded get_full_path() and the
        # decoded request.path (see is_malformed_request).
        self.malicious_patterns = [
            b'\x00',  # Null bytes - definitely malicious
            b'\r\n\r\n',  # HTTP response injection - definitely malicious
            b'%00',  # Null byte URL encoding (or double-encoded "%2500" once decoded)
        ]

        # Less aggressive patterns - only flag if combined with other suspicious indicators
        # Note: Removed %0a, %0d, \\x as they have legitimate uses in web applications
        # Example: API parameters, form submissions, and normal URL encoding
        self.suspicious_patterns = []

    def __call__(self, request):
        # Run the cheap, self-contained checks first so bad requests never reach the view.
        if self.is_malformed_request(request):
            return self._reject(request, "MALFORMED REQUEST", "Malformed request")

        if self.is_protocol_confusion(request):
            return self._reject(request, "PROTOCOL CONFUSION", "Protocol error")

        return self.get_response(request)

    def _reject(self, request, label, message):
        """Log a rejected request, record a temporary block, and return a 400.

        Args:
            request: The incoming request.
            label: Log prefix identifying which check fired.
            message: Plain-text body of the 400 response.
        """
        ip_address = self.get_client_ip(request)
        # The request was rejected precisely because it may contain CR/LF/NUL, and the
        # IP may come from a client-supplied header. Escape both so they cannot forge log lines.
        logger.warning(
            "%s from %s: %s %s",
            label,
            self._log_safe(ip_address),
            self._log_safe(request.method),
            self._log_safe(request.path),
        )

        # Immediately block this IP for a short period.
        self.temporary_block_ip(ip_address)

        return HttpResponseBadRequest(message, content_type="text/plain")

    @staticmethod
    def _log_safe(value):
        """Return ``str(value)`` with every non-printable character backslash-escaped."""
        return ''.join(
            ch if ch.isprintable() else ascii(ch)[1:-1]
            for ch in str(value)
        )

    @staticmethod
    def _has_binary_header(meta):
        """Return True if an ``HTTP_*`` header has more than 3 control characters.

        Tab, LF and CR are not counted. Only ``str`` values are inspected.
        """
        for key, value in meta.items():
            if key.startswith('HTTP_') and isinstance(value, str):
                control_chars = sum(1 for c in value if ord(c) < 32 and c not in '\t\n\r')
                if control_chars > 3:  # Allow some control characters
                    return True
        return False

    def is_malformed_request(self, request):
        """Detect malformed requests.

        Returns True when the full or decoded path contains one of
        ``self.malicious_patterns``. Also returns True when at least two weaker
        indicators are present:

        - a suspicious pattern,
        - a full path longer than 4000 characters,
        - a User-Agent longer than 2000 characters,
        - an HTTP header with more than three control characters.

        Any exception during analysis is logged and treated as "not malformed".
        """
        try:
            full_path = request.get_full_path()
            # Django builds get_full_path() with percent-encoding, so raw bytes such as
            # NUL or CR/LF only survive in the decoded request.path. Inspect both forms.
            haystacks = [full_path.encode('utf-8', 'replace')]
            decoded_path = getattr(request, 'path', '') or ''
            if decoded_path:
                haystacks.append(decoded_path.encode('utf-8', 'replace'))

            # Definitely malicious patterns short-circuit the scoring below.
            if any(pattern in hay for pattern in self.malicious_patterns for hay in haystacks):
                return True

            # Each weaker indicator adds at most one point.
            suspicious_count = sum(
                1 for pattern in self.suspicious_patterns
                if any(pattern in hay for hay in haystacks)
            )

            if len(full_path) > 4000:  # Excessively long URL
                suspicious_count += 1

            user_agent = request.META.get('HTTP_USER_AGENT', '')
            if len(user_agent) > 2000:  # Suspiciously long user agent
                suspicious_count += 1

            if self._has_binary_header(request.META):
                suspicious_count += 1

            # Only flag as malformed if we have multiple suspicious indicators
            return suspicious_count >= 2

        except Exception as e:
            # Best-effort check: log and let the request through rather than fail closed.
            logger.error("Error analyzing request: %s", e)
            return False

    def is_protocol_confusion(self, request):
        """Detect protocol confusion attacks.

        Returns True in either of these cases:

        - the path carries a TLS handshake record header (characters 0x16 0x03, or
          their legacy escaped-text form);
        - a POST/PUT/PATCH request declares ``application/octet-stream`` content
          outside the ``/upload/`` and ``/api/upload/`` endpoints.

        Any exception during analysis is logged at DEBUG and treated as "no attack".
        """
        try:
            full_path = request.get_full_path()
            decoded_path = getattr(request, 'path', '') or ''

            # Legacy check, kept for backward compatibility. It matches the escaped
            # *text* "\x16\x03" (backslash, 'x', '1', '6', ...), not the bytes.
            if '\\x16\\x03' in full_path:
                return True

            # Actual TLS record header: content type 0x16 (handshake) followed by
            # major version 0x03. get_full_path() percent-encodes control
            # characters, so the raw bytes are only visible in the decoded path.
            if any('\x16\x03' in text for text in (decoded_path, full_path)):
                return True

            # Check for suspicious binary content in request bodies
            if request.method in ('POST', 'PUT', 'PATCH'):
                content_type = (request.META.get('CONTENT_TYPE') or '').lower()
                if ('application/octet-stream' in content_type
                        and request.path not in ('/upload/', '/api/upload/')):
                    return True

            return False

        except Exception as e:
            # Log the exception but don't automatically flag as malicious
            logger.debug("Exception in protocol confusion check: %s", e)
            return False

    def get_client_ip(self, request):
        """Get the client IP address used for logging and temporary blocks.

        Returns the first value found, in this order:

        1. the first entry of X-Forwarded-For (whitespace-stripped),
        2. X-Real-IP,
        3. CF-Connecting-IP,
        4. REMOTE_ADDR, or '' if none are set.

        NOTE(review): the forwarding headers are client-controlled unless a trusted
        proxy strips or overwrites them. Because this value keys the block entry, a
        spoofed header could attribute a block to someone else's address. Confirm the
        proxy setup before relying on it.
        """
        x_forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if x_forwarded_for:
            return x_forwarded_for.split(',')[0].strip()
        return (request.META.get('HTTP_X_REAL_IP') or
                request.META.get('HTTP_CF_CONNECTING_IP') or
                request.META.get('REMOTE_ADDR', ''))

    def temporary_block_ip(self, ip_address, duration=300):  # 5 minute block
        """Temporarily block an IP for malformed requests.

        Writes ``security_block_<ip_address>`` = True to the cache for ``duration``
        seconds. This is best-effort: ``ip_address`` may come from a client-supplied
        header, and some cache backends (e.g. memcached) reject keys containing spaces
        or control characters. A cache failure is therefore logged instead of turning
        the 400 response into a 500.
        """
        block_key = f"security_block_{ip_address}"
        try:
            cache.set(block_key, True, duration)
        except Exception:
            logger.warning(
                "Could not record temporary block for IP %s",
                self._log_safe(ip_address),
                exc_info=True,
            )
            return
        logger.info(
            "Temporarily blocked IP %s for %s seconds due to malformed request",
            self._log_safe(ip_address),
            duration,
        )