"""
Enhanced rate limiting middleware with bot detection and malicious request filtering.
"""
import hashlib
import logging
import re
import time
from collections import defaultdict
from typing import Dict, List, Tuple

from django.conf import settings
from django.core.cache import cache
from django.http import HttpResponse
from django.template.loader import render_to_string
from django.utils.deprecation import MiddlewareMixin

# Module-level logger used for all blocked / rate-limited / suspicious request events.
logger = logging.getLogger(__name__)


class EnhancedRateLimitMiddleware:
    """
    Enhanced middleware to rate limit requests and detect malicious bot behavior.

    Features:
    - IP-based rate limiting with configurable thresholds
    - Bot detection based on user agents and request patterns
    - Malicious request pattern detection
    - Different rate limits for suspicious vs normal traffic
    - Enhanced logging for security monitoring

    Each unauthenticated, non-exempt request is classified as ``'normal'``,
    ``'bot'`` or ``'suspicious'`` (see ``classify_request``). Every type has
    its own per-IP request counter (a fixed window of ``RATE_LIMIT_WINDOW``
    seconds, anchored at the first request) and its own block flag in the
    Django cache. An IP that exceeds the limit for a type is blocked for that
    type's block duration. While any block is active, every request from that
    IP is rejected: normal/bot requests get a 429 (``429.html`` or a plain-text
    fallback) and suspicious requests get a plain 400.
    """

    # Fragments that may be embedded in a cache key verbatim. Memcached rejects
    # keys containing whitespace/control characters or longer than 250
    # characters, and the client IP may come from a client-controlled header,
    # so anything outside this pattern is hashed (see ``_cache_key_part``).
    _CACHE_KEY_SAFE_RE = re.compile(r'[0-9A-Za-z.:_\-]{0,64}')

    # Every request type produced by ``classify_request``.
    _REQUEST_TYPES = ('normal', 'bot', 'suspicious')

    def __init__(self, get_response):
        self.get_response = get_response

        # Standard rate limiting configuration
        self.requests_limit = getattr(settings, 'RATE_LIMIT_REQUESTS', 100)
        self.time_window = getattr(settings, 'RATE_LIMIT_WINDOW', 3600)  # 1 hour
        self.block_duration = getattr(settings, 'RATE_LIMIT_BLOCK_DURATION', 1800)  # 30 minutes
        self.enabled = getattr(settings, 'RATE_LIMIT_ENABLED', True)

        # Stricter limits for bot-like and suspicious traffic. Both types share
        # the bot block duration (see get_limits_for_type).
        self.bot_requests_limit = getattr(settings, 'RATE_LIMIT_BOT_REQUESTS', 20)
        self.bot_block_duration = getattr(settings, 'RATE_LIMIT_BOT_BLOCK_DURATION', 3600)  # 1 hour
        self.suspicious_requests_limit = getattr(settings, 'RATE_LIMIT_SUSPICIOUS_REQUESTS', 10)

        # Path prefixes that are never rate limited. Trailing slashes are
        # stripped on both sides before a plain prefix comparison (see __call__).
        self.exempt_paths = getattr(settings, 'RATE_LIMIT_EXEMPT_PATHS', [
            '/static/',
            '/media/',
            '/favicon.ico',
            '/robots.txt',
            '/sitemap.xml',
            '/admin/jsi18n/',
            '/inbox/notifications/api/unread_list/',
            '/inbox/notifications/api/unread_list',
            '/jsi18n/',
            '/reload-messages',
        ])

        # Known malicious bot / scanner user agents (only clearly malicious ones).
        # They are consulted only for user agents that did NOT match the
        # whitelist below. Any UA matching 'headless.*chrome' also contains
        # 'chrome', which is whitelisted, so that entry cannot fire as written.
        self.bot_user_agents = [
            r'masscan', r'nmap', r'zmap', r'shodan', r'censys',
            r'sqlmap', r'nikto', r'acunetix', r'burp',
            r'headless.*chrome', r'phantomjs',  # Headless browsers
            r'scrapybot', r'scrapy', r'nutch',
        ]

        # User agents treated as ordinary clients. A match classifies the
        # request as 'normal' immediately, skipping every bot/suspicious
        # heuristic. The normal rate limit still applies. These substrings are
        # trivially spoofable by any client.
        self.legitimate_user_agents_whitelist = [
            r'curl', r'wget', r'python-requests', r'python-httpx',
            r'java', r'go-http', r'ruby', r'node',
            r'axios', r'fetch', r'postman', r'insomnia',
            r'mobile.*safari', r'iphone', r'android',
            r'chrome', r'firefox', r'safari', r'edge',
        ]

        # Suspicious request patterns (matched against path and UA of POSTs only)
        self.suspicious_patterns = [
            r'\.php$', r'\.asp$', r'\.jsp$', r'\.cgi$',  # Common web vulnerabilities
            r'/admin', r'/wp-admin', r'/administrator',  # Admin interfaces
            r'/login', r'/signin', r'/auth',  # Auth endpoints when accessed rapidly
            r'\.env$', r'\.config$', r'\.backup$',  # Config files
            r'/api/v1', r'/api/v2', r'/rest/',  # API endpoints
            r'\.\./\.\./\.\./\.\./\.\./\.\./\.\./\.\./\.\./\.\.',  # Path traversal
            r'<script', r'javascript:', r'vbscript:',  # XSS attempts
            r'union.*select', r'drop.*table', r'insert.*into',  # SQL injection
        ]

        # Compile every pattern list once instead of on each request.
        self.bot_user_agent_regex = re.compile('|'.join(self.bot_user_agents), re.IGNORECASE)
        self.suspicious_pattern_regex = re.compile('|'.join(self.suspicious_patterns), re.IGNORECASE)
        self.legitimate_user_agent_regex = re.compile(
            '|'.join(self.legitimate_user_agents_whitelist), re.IGNORECASE
        )

    def __call__(self, request):
        """Apply rate limiting to ``request`` and return the response."""
        if not self.enabled:
            return self.get_response(request)

        # Skip rate limiting for exempt paths (trailing slashes ignored on both sides).
        current_path = request.path.rstrip('/')
        if any(current_path.startswith(path.rstrip('/')) for path in self.exempt_paths):
            return self.get_response(request)

        # Notification API endpoints are exempt wherever they appear in the path.
        if '/inbox/notifications/api/' in request.path:
            return self.get_response(request)

        # Authenticated users are never rate limited. request.user only exists
        # if an authentication middleware ran before this one.
        if hasattr(request, 'user') and request.user.is_authenticated:
            return self.get_response(request)

        ip_address = self.get_client_ip(request)
        request_type = self.classify_request(request)

        if self.is_blocked(ip_address, request_type):
            self.log_blocked_request(ip_address, request, request_type)
            return self.create_rate_limit_response(request_type)

        if self.is_rate_limited(ip_address, request_type):
            # Block the IP for this type's block duration.
            self.block_ip(ip_address, request_type)
            self.log_rate_limit_exceeded(ip_address, request, request_type)
            return self.create_rate_limit_response(request_type)

        self.record_request(ip_address, request_type)

        # Log suspicious requests for monitoring
        if request_type in ('bot', 'suspicious'):
            self.log_suspicious_request(ip_address, request, request_type)

        return self.get_response(request)

    def classify_request(self, request):
        """Classify a request as ``'normal'``, ``'bot'`` or ``'suspicious'``.

        Checks run in this order and the first match wins:

        1. User agent matches the whitelist -> ``'normal'``.
        2. POST whose path or user agent matches a suspicious pattern
           -> ``'suspicious'``.
        3. More than 3 requests within one second from this IP
           -> ``'suspicious'``. This check also records the request timestamp.
        4. Empty user agent, or a known malicious bot user agent -> ``'bot'``.
        5. Path starts with a well-known scanner target -> ``'suspicious'``.
        6. Otherwise ``'normal'``.
        """
        user_agent = request.META.get('HTTP_USER_AGENT', '').lower()
        path = request.path.lower()
        ip_address = self.get_client_ip(request)

        # Whitelisted clients skip every other heuristic.
        if user_agent and self.legitimate_user_agent_regex.search(user_agent):
            return 'normal'

        # Attack-looking patterns only count for POST requests.
        if request.method == 'POST' and (
            self.suspicious_pattern_regex.search(path)
            or self.suspicious_pattern_regex.search(user_agent)
        ):
            return 'suspicious'

        if self.is_high_frequency_request(ip_address):
            return 'suspicious'

        # No user agent at all, or a known scanner / malicious bot.
        if not user_agent or self.bot_user_agent_regex.search(user_agent):
            return 'bot'

        if self.is_scanning_behavior(request):
            return 'suspicious'

        return 'normal'

    def is_scanning_behavior(self, request):
        """Return True if the path starts with a commonly probed location.

        This is a simple prefix check on the case-sensitive request path.
        """
        scanner_targets = ('/admin/', '/wp-admin/', '/phpmyadmin/', '/.env', '/config.php')
        return request.path.startswith(scanner_targets)

    def is_high_frequency_request(self, ip_address):
        """Record this request and return True if the IP made >3 requests in 1s.

        Keeps up to the 10 most recent timestamps per IP in the cache with a
        2-second TTL.
        """
        frequency_key = f"rate_limit_frequency_{self._cache_key_part(ip_address)}"
        current_time = time.time()

        # Keep only timestamps from the last second, then add this request.
        request_times = cache.get(frequency_key, [])
        recent_times = [t for t in request_times if current_time - t <= 1.0]
        recent_times.append(current_time)

        # Cap the stored list so a busy IP cannot grow it without bound.
        cache.set(frequency_key, recent_times[-10:], 2)

        return len(recent_times) > 3

    def get_limits_for_type(self, request_type):
        """Return ``(requests_limit, block_duration)`` for ``request_type``.

        Unknown types fall back to the normal limits.
        """
        if request_type == 'suspicious':
            return self.suspicious_requests_limit, self.bot_block_duration
        if request_type == 'bot':
            return self.bot_requests_limit, self.bot_block_duration
        return self.requests_limit, self.block_duration

    def get_client_ip(self, request):
        """Return the client IP address used to key rate limits.

        Uses the first ``X-Forwarded-For`` entry if that header is present.
        Otherwise it uses ``X-Real-IP``, then ``CF-Connecting-IP`` (Cloudflare),
        then ``REMOTE_ADDR``, and ``''`` if none is set.

        NOTE(security): the first ``X-Forwarded-For`` entry, ``X-Real-IP`` and
        ``CF-Connecting-IP`` can be chosen by the client itself unless a proxy
        in front of the app replaces them. If the app is reachable without
        such a proxy, a client can present a new "IP" on every request and
        evade the limits. The value is returned unvalidated; cache keys are
        sanitised separately (see ``get_cache_key``).
        """
        forwarded_for = request.META.get('HTTP_X_FORWARDED_FOR')
        if forwarded_for:
            return forwarded_for.split(',')[0].strip()
        return (
            request.META.get('HTTP_X_REAL_IP')
            or request.META.get('HTTP_CF_CONNECTING_IP')
            or request.META.get('REMOTE_ADDR', '')
        )

    def _cache_key_part(self, value):
        """Return ``value`` if it is safe inside a cache key, else a stable hash.

        Well-formed IPv4/IPv6 addresses pass through unchanged, so their keys
        are identical to before. Anything containing whitespace, control or
        other unusual characters, or longer than 64 characters, becomes
        ``sha256-<hexdigest>``.
        """
        value = str(value)
        if self._CACHE_KEY_SAFE_RE.fullmatch(value):
            return value
        digest = hashlib.sha256(value.encode('utf-8', 'backslashreplace')).hexdigest()
        return f"sha256-{digest}"

    def get_cache_key(self, ip_address, key_type, request_type='normal'):
        """Return the cache key for one IP, key type and request type."""
        return f"rate_limit_{key_type}_{request_type}_{self._cache_key_part(ip_address)}"

    def is_blocked(self, ip_address, request_type):
        """Return True if ``ip_address`` is currently blocked.

        Blocks are stored per request type (see ``block_ip``), but a block of
        any type rejects requests of every type. Otherwise an IP blocked for
        normal traffic could keep going just by sending requests that are
        classified as bot or suspicious, which have their own counters.
        """
        block_types = list(self._REQUEST_TYPES)
        if request_type not in block_types:
            block_types.append(request_type)
        block_keys = [self.get_cache_key(ip_address, "block", t) for t in block_types]
        return any(cache.get_many(block_keys).values())

    def block_ip(self, ip_address, request_type):
        """Block ``ip_address`` under ``request_type`` and reset that type's counter."""
        _, block_duration = self.get_limits_for_type(request_type)
        block_key = self.get_cache_key(ip_address, "block", request_type)
        cache.set(block_key, True, block_duration)

        count_key = self.get_cache_key(ip_address, "count", request_type)
        cache.delete(count_key)

    def is_rate_limited(self, ip_address, request_type):
        """Return True if the IP already used its quota in the current window."""
        requests_limit, _ = self.get_limits_for_type(request_type)
        count_key = self.get_cache_key(ip_address, "count", request_type)
        request_data = cache.get(count_key, {'count': 0, 'window_start': time.time()})

        # An expired window will be reset by record_request, so it is not limited.
        if time.time() - request_data['window_start'] > self.time_window:
            return False

        return request_data['count'] >= requests_limit

    def record_request(self, ip_address, request_type):
        """Count one request for the IP/type, starting a new window if expired."""
        count_key = self.get_cache_key(ip_address, "count", request_type)
        request_data = cache.get(count_key, {'count': 0, 'window_start': time.time()})

        current_time = time.time()
        if current_time - request_data['window_start'] > self.time_window:
            request_data = {'count': 1, 'window_start': current_time}
        else:
            request_data['count'] = request_data['count'] + 1

        # Keep the entry 5 minutes past the window so the expiry check above sees it.
        cache.set(count_key, request_data, self.time_window + 300)

    def create_rate_limit_response(self, request_type):
        """Build the rejection response.

        Suspicious requests get a plain 400. Other types get a 429 rendered
        from ``429.html``, or plain text if that template cannot be rendered.
        """
        if request_type == 'suspicious':
            return HttpResponse("Bad Request", status=400, content_type="text/plain")
        try:
            content = render_to_string('429.html')
        except Exception:
            # A missing or broken template must not turn a 429 into a 500.
            logger.debug("Could not render 429.html; using plain-text 429", exc_info=True)
            return HttpResponse(
                "Rate limit exceeded. Please try again later.",
                status=429,
                content_type="text/plain"
            )
        return HttpResponse(content, status=429, content_type="text/html")

    def log_blocked_request(self, ip_address, request, request_type):
        """Log a request rejected because its IP is blocked."""
        user_agent = request.META.get('HTTP_USER_AGENT', 'Unknown')
        logger.warning(
            "BLOCKED %s request from %s: method=%s path=%s ua='%s'",
            request_type.upper(), ip_address, request.method, request.path,
            user_agent[:100],
        )

    def log_rate_limit_exceeded(self, ip_address, request, request_type):
        """Log the request that pushed an IP over its limit."""
        user_agent = request.META.get('HTTP_USER_AGENT', 'Unknown')
        limit, duration = self.get_limits_for_type(request_type)
        logger.warning(
            "RATE LIMIT EXCEEDED for %s from %s: method=%s path=%s "
            "limit=%s/%ss blocked_for=%ss ua='%s'",
            request_type.upper(), ip_address, request.method, request.path,
            limit, self.time_window, duration, user_agent[:100],
        )

    def log_suspicious_request(self, ip_address, request, request_type):
        """Log an allowed bot/suspicious request for monitoring."""
        user_agent = request.META.get('HTTP_USER_AGENT', 'Unknown')
        logger.info(
            "SUSPICIOUS %s request from %s: method=%s path=%s ua='%s'",
            request_type.upper(), ip_address, request.method, request.path,
            user_agent[:100],
        )


# Backward-compatible alias: RateLimitMiddleware is the very same class, so
# existing references to the old name (e.g. in the MIDDLEWARE setting) keep working.
RateLimitMiddleware = EnhancedRateLimitMiddleware