"""
FLHIP Scraper v3 - Database Handler
Saves extracted leads to MySQL - restaurant_leads_usa2
Updated: 2026-01-21 - Added business+city+state duplicate checking
"""

import mysql.connector
from mysql.connector import Error
from datetime import datetime
import json
import re
from config import DB_CONFIG, RSS_FEEDS


# Table that receives newly scraped leads.
TABLE_NAME = "restaurant_leads_usa2"
# Archive table; also consulted for de-duplication and ID continuity.
ARCHIVE_TABLE = "restaurant_leads_usa2_archive"

# feed name -> feed URL, built from the configured RSS feeds (a later entry wins on a repeated name)
FEED_URL_LOOKUP = dict((feed['name'], feed['url']) for feed in RSS_FEEDS)


def get_feed_url(feed_name: str) -> str:
    """
    Return the RSS URL configured for `feed_name`.

    A falsy name, or a name that is not in FEED_URL_LOOKUP, yields ''.
    """
    return FEED_URL_LOOKUP.get(feed_name, '') if feed_name else ''


def get_connection():
    """
    Open a new MySQL connection from DB_CONFIG.

    Returns the connection object. If mysql.connector raises an Error, the
    error is printed and None is returned.
    """
    connection = None
    try:
        connection = mysql.connector.connect(**DB_CONFIG)
    except Error as e:
        print(f"Database connection error: {e}")
    return connection


def normalize_business_name(name: str) -> str:
    """
    Normalize a business name for duplicate comparison.

    Steps:
    1. Lowercase the name and remove punctuation.
    2. Collapse runs of whitespace.
    3. Repeatedly strip trailing generic words until none remain at the end.
       These are venue types such as "restaurant" or "grill" and corporate
       forms such as "inc" or "llc".

    Punctuation is removed *before* suffix stripping, so forms like "Inc."
    or "Pizza, LLC" are still recognised. Stripping repeats, so the result
    does not depend on the order of the suffix list. For example,
    "Smith's Grill Inc" and "Smith's Grill" both become "smiths".

    Args:
        name: Raw business name; None or "" is allowed.

    Returns:
        The normalized name, or "" for empty input.
    """
    if not name:
        return ""

    suffixes = (
        ' restaurant', ' restaurants', ' bar', ' grill', ' cafe', ' coffee',
        ' pizza', ' pizzeria', ' kitchen', ' eatery', ' bistro', ' tavern',
        ' pub', ' brewing', ' brewery', ' bbq', ' barbecue',
        ' inc', ' llc', ' corp', ' co', ' company',
    )

    # Lowercase, drop punctuation and collapse whitespace first, so that
    # trailing punctuation cannot hide a suffix.
    name = name.lower()
    name = re.sub(r'[^\w\s]', '', name)
    name = re.sub(r'\s+', ' ', name).strip()

    # Strip suffixes until nothing more matches. Every suffix starts with a
    # space and `name` has no leading space, so at least one word survives.
    stripped = True
    while stripped:
        stripped = False
        for suffix in suffixes:
            if name.endswith(suffix):
                name = name[:-len(suffix)]
                stripped = True
                break

    return name


def check_duplicate_url(conn, url: str, table: str) -> bool:
    """
    Return True when `url` is already stored in the `info_source` column of `table`.

    A falsy URL is never a duplicate, and no query is issued for it.
    `table` is interpolated into the SQL, so it must be a trusted constant.
    """
    if not url:
        return False
    cur = conn.cursor()
    try:
        query = f"SELECT id FROM {table} WHERE info_source = %s LIMIT 1"
        cur.execute(query, (url,))
        found = cur.fetchone()
    finally:
        cur.close()
    return found is not None

def check_duplicate_business(conn, business: str, city: str, state: str, table: str) -> tuple[bool, int | None]:
    """
    Check whether a business already exists in `table` for the same city and state.

    Every row with an equal `city` and `state` is fetched. Its `business`
    value is compared with the incoming name after both pass through
    normalize_business_name. A row is a duplicate when:
    - the normalized names are equal, or
    - the shorter normalized name has at least 3 characters and appears in
      the longer one as a run of whole words (e.g. "7 brew" inside
      "7 brew drive thru").

    Whole-word containment avoids the false positives of plain substring
    matching, such as "pho" matching "phoenix grill". Names that normalize
    to "" are never considered duplicates.

    Args:
        conn: Open connection; a cursor is opened and closed here.
        business: Incoming business name. If empty, no query is run.
        city: Incoming city. If empty, no query is run.
        state: Incoming state. If empty, no query is run.
        table: Table name interpolated into the SQL; must be a trusted constant.

    Returns:
        (True, existing_id) for the first matching row, otherwise (False, None).
    """
    if not business or not city or not state:
        return False, None

    normalized_incoming = normalize_business_name(business)
    # An empty normalized name (e.g. punctuation only) cannot be compared meaningfully.
    if not normalized_incoming:
        return False, None

    cursor = conn.cursor()
    try:
        # Candidate rows: every stored business in the same city and state.
        cursor.execute(
            f"SELECT id, business FROM {table} WHERE city = %s AND state = %s",
            (city, state)
        )
        results = cursor.fetchall()
    finally:
        cursor.close()

    for existing_id, existing_business in results:
        normalized_existing = normalize_business_name(existing_business)
        if not normalized_existing:
            continue

        if normalized_incoming == normalized_existing:
            return True, existing_id

        shorter, longer = sorted((normalized_incoming, normalized_existing), key=len)
        # Pad with spaces so containment only succeeds on word boundaries;
        # the 3-character minimum avoids matching trivially short names.
        if len(shorter) >= 3 and f" {shorter} " in f" {longer} ":
            return True, existing_id

    return False, None


def check_all_duplicates(conn, url: str, business: str, city: str, state: str) -> tuple[bool, str | None, int | None]:
    """
    Check a lead against the current table and the archive table.

    Checks run in this order and stop at the first hit:
    1. URL in the current table
    2. URL in the archive table
    3. business+city+state in the current table
    4. business+city+state in the archive table

    Returns:
        (is_duplicate, location, existing_id).
        location is one of 'current_url', 'archive_url', 'current_business'
        or 'archive_business', or None when no duplicate was found.
        existing_id is the matching row's id for business matches. It is None
        for URL matches, because the URL check does not return an id.
    """
    # URL checks first: a single indexed-style lookup per table.
    if check_duplicate_url(conn, url, TABLE_NAME):
        return True, 'current_url', None
    if check_duplicate_url(conn, url, ARCHIVE_TABLE):
        return True, 'archive_url', None

    # Then normalized business name within the same city and state.
    is_dup, existing_id = check_duplicate_business(conn, business, city, state, TABLE_NAME)
    if is_dup:
        return True, 'current_business', existing_id

    is_dup, existing_id = check_duplicate_business(conn, business, city, state, ARCHIVE_TABLE)
    if is_dup:
        return True, 'archive_business', existing_id

    return False, None, None


def _lead_stage_for(opening_date) -> str | None:
    """
    Map an opening date to the lead_stage value stored by insert_lead.

    Returns:
        None when `opening_date` is falsy.
        "LATE" when it is at most 30 days away, including dates already past.
        "MID" when it is at most 90 days away.
        "EARLY" beyond that, or when the value is not a "%Y-%m-%d" string.
    """
    if not opening_date:
        return None
    try:
        open_dt = datetime.strptime(opening_date, "%Y-%m-%d")
    except (TypeError, ValueError):
        # Unparseable or non-string values fall back to EARLY.
        return "EARLY"
    days_out = (open_dt - datetime.now()).days
    if days_out <= 30:
        return "LATE"
    if days_out <= 90:
        return "MID"
    return "EARLY"


def _clip_text(value, limit: int, default=None):
    """Return str(value) cut to `limit` characters, or `default` when `value` is falsy."""
    if not value:
        return default
    # str() so non-string extractions (e.g. an int zip) do not crash slicing.
    return str(value)[:limit]


def insert_lead(conn, lead: dict) -> int | None:
    """
    Insert one lead into TABLE_NAME unless it is already known.

    The lead is first checked with check_all_duplicates, which looks for URL
    and business+city+state matches in both the current and the archive
    table. Duplicates are printed and skipped. Otherwise one row is inserted
    and committed immediately.

    Args:
        conn: Open connection. It is not closed here.
        lead: Extracted lead fields.
            - Keys read: source_url, business_name, city, state, feed_name,
              article_title, article_date, address, zip, opening_date,
              phone, contact_name, contact_email, confidence,
              future_signals, notes.
            - Missing or None values become "" for title, business, city and
              state, "medium" for confidence, and NULL for the other text
              columns.
            - Text values are truncated to the column limits used below.

    Returns:
        The new row id. None for a duplicate, or for a database error, in
        which case the transaction is rolled back.
    """
    source_url = lead.get("source_url", "")
    business_name = lead.get("business_name", "")
    city = lead.get("city", "")
    state = lead.get("state", "")

    # Check both current table and archive for duplicates (URL and business+location)
    is_dup, dup_location, _ = check_all_duplicates(conn, source_url, business_name, city, state)
    if is_dup:
        dup_type = "URL" if "url" in dup_location else "business+location"
        table_loc = "archive" if "archive" in dup_location else "current"
        print(f"  Duplicate ({table_loc}, {dup_type}): {business_name} in {city}, {state}")
        return None

    sql = f"""
    INSERT INTO {TABLE_NAME} (
        rss_title, feed_name, info_source, internet_info, article_date,
        business, address, city, state, zip,
        opening_date, standardized_opening_date,
        phone, contact_name, contact_email,
        confidence, future_signals, notes,
        lead_stage, extraction_method
    ) VALUES (
        %s, %s, %s, %s, %s, %s, %s, %s, %s, %s,
        %s, %s, %s, %s, %s, %s, %s, %s, %s, %s
    )
    """

    opening_date = lead.get("opening_date")
    lead_stage = _lead_stage_for(opening_date)

    # future_signals is stored as a JSON string, or NULL when empty.
    future_signals = lead.get("future_signals", [])
    future_signals_json = json.dumps(future_signals) if future_signals else None

    # internet_info stores the human-readable feed name (e.g. "Google Alerts - Texas Restaurants")
    feed_name = lead.get("feed_name", "")

    confidence = lead.get("confidence", "medium")
    if confidence is None:
        confidence = "medium"

    values = (
        _clip_text(lead.get("article_title"), 500, ""),
        _clip_text(feed_name, 100),
        _clip_text(source_url, 500),
        _clip_text(feed_name, 500),  # internet_info = feed name, not RSS URL
        lead.get("article_date"),
        _clip_text(business_name, 255, ""),
        _clip_text(lead.get("address"), 255),
        _clip_text(city, 100, ""),
        _clip_text(state, 10, ""),
        _clip_text(lead.get("zip"), 10),
        opening_date,
        opening_date,
        _clip_text(lead.get("phone"), 20),
        _clip_text(lead.get("contact_name"), 100),
        _clip_text(lead.get("contact_email"), 255),
        _clip_text(confidence, 20, ""),
        future_signals_json,
        _clip_text(lead.get("notes"), 1000),
        lead_stage,
        "claude_v3"
    )

    cursor = conn.cursor()
    try:
        cursor.execute(sql, values)
        conn.commit()

        inserted_id = cursor.lastrowid
        print(f"  Inserted ID {inserted_id}: {lead.get('business_name')} in {city}, {state}")
        return inserted_id

    except Error as e:
        print(f"  Insert error: {e}")
        conn.rollback()
        return None
    finally:
        cursor.close()


def insert_leads_batch(leads: list[dict]) -> dict:
    """
    Insert a list of leads over one shared connection and report counts.

    Before a connection is opened, sync_auto_increment is called. Each lead
    is then checked with check_all_duplicates. Duplicates are counted by
    table and match type, printed, and skipped. The remaining leads go
    through insert_lead, and a falsy return from it counts as an error.
    A summary is printed at the end.

    Returns:
        A stats dict with the keys "total", "inserted", "duplicates_url",
        "duplicates_business", "archive_duplicates_url",
        "archive_duplicates_business", "errors" and "inserted_ids".
        - An empty `leads` list returns zeroed stats without touching the database.
        - If no connection can be opened, every lead is counted as an error.
    """
    stats = {
        "total": len(leads),
        "inserted": 0,
        "duplicates_url": 0,
        "duplicates_business": 0,
        "archive_duplicates_url": 0,
        "archive_duplicates_business": 0,
        "errors": 0,
        "inserted_ids": [],
    }
    if not leads:
        return stats

    # Sync auto-increment before inserting (ensures IDs continue from archive)
    sync_auto_increment()

    conn = get_connection()
    if not conn:
        stats["errors"] = len(leads)
        return stats

    # duplicate location -> (stats counter to bump, log line template)
    dup_reports = {
        'archive_url': (
            "archive_duplicates_url",
            "  Duplicate (archive, URL): {name}",
        ),
        'archive_business': (
            "archive_duplicates_business",
            "  Duplicate (archive, business): {name} in {city}, {state}",
        ),
        'current_url': (
            "duplicates_url",
            "  Duplicate (current, URL): {name}",
        ),
        'current_business': (
            "duplicates_business",
            "  Duplicate (current, business): {name} in {city}, {state}",
        ),
    }

    try:
        for lead in leads:
            name = lead.get("business_name", "")
            city = lead.get("city", "")
            state = lead.get("state", "")

            # Pre-check duplicates here so the stats can say where they matched.
            is_dup, where, _ = check_all_duplicates(
                conn, lead.get("source_url", ""), name, city, state
            )
            if is_dup:
                report = dup_reports.get(where)
                if report is not None:
                    counter, template = report
                    stats[counter] += 1
                    print(template.format(name=name, city=city, state=state))
                continue

            new_id = insert_lead(conn, lead)
            if not new_id:
                stats["errors"] += 1
            else:
                stats["inserted"] += 1
                stats["inserted_ids"].append(new_id)
    finally:
        conn.close()

    total_dups = sum(
        stats[key]
        for key in (
            "duplicates_url",
            "duplicates_business",
            "archive_duplicates_url",
            "archive_duplicates_business",
        )
    )

    summary_lines = [
        "\n[DB] Insertion Summary:",
        f"     Total processed: {stats['total']}",
        f"     Inserted: {stats['inserted']}",
        "     Duplicates (current table):",
        f"       - URL matches: {stats['duplicates_url']}",
        f"       - Business+location matches: {stats['duplicates_business']}",
        "     Duplicates (archive table):",
        f"       - URL matches: {stats['archive_duplicates_url']}",
        f"       - Business+location matches: {stats['archive_duplicates_business']}",
        f"     Total duplicates skipped: {total_dups}",
        f"     Errors: {stats['errors']}",
    ]
    for line in summary_lines:
        print(line)

    return stats


def get_recent_leads(days: int = 7) -> list[dict]:
    """
    Return rows from TABLE_NAME whose created_at falls within the last `days` days.

    Rows come from a dictionary cursor, newest created_at first. Returns an
    empty list when no connection can be opened.
    """
    conn = get_connection()
    if not conn:
        return []
    try:
        # The cursor is created inside the outer try so the connection is
        # still closed if cursor creation itself fails.
        cursor = conn.cursor(dictionary=True)
        try:
            sql = f"""
            SELECT * FROM {TABLE_NAME}
            WHERE created_at >= DATE_SUB(NOW(), INTERVAL %s DAY)
            ORDER BY created_at DESC
            """
            cursor.execute(sql, (days,))
            return cursor.fetchall()
        finally:
            cursor.close()
    finally:
        conn.close()


def get_lead_count() -> int:
    """
    Return the number of rows in TABLE_NAME.

    Returns 0 when no connection can be opened.
    """
    conn = get_connection()
    if not conn:
        return 0
    try:
        # The cursor is created inside the outer try so the connection is
        # still closed if cursor creation itself fails.
        cursor = conn.cursor()
        try:
            cursor.execute(f"SELECT COUNT(*) FROM {TABLE_NAME}")
            return cursor.fetchone()[0]
        finally:
            cursor.close()
    finally:
        conn.close()


def get_archive_count() -> int:
    """
    Return the number of rows in ARCHIVE_TABLE.

    Returns 0 when no connection can be opened.
    """
    conn = get_connection()
    if not conn:
        return 0
    try:
        # The cursor is created inside the outer try so the connection is
        # still closed if cursor creation itself fails.
        cursor = conn.cursor()
        try:
            cursor.execute(f"SELECT COUNT(*) FROM {ARCHIVE_TABLE}")
            return cursor.fetchone()[0]
        finally:
            cursor.close()
    finally:
        conn.close()


def sync_auto_increment() -> int:
    """
    Make TABLE_NAME's AUTO_INCREMENT continue past the IDs in both tables.

    Reads MAX(id) (0 when the table is empty) from TABLE_NAME and from
    ARCHIVE_TABLE. It then sets AUTO_INCREMENT on TABLE_NAME to one more than
    the larger value, so new rows do not reuse IDs that already exist in the
    archive. Call this before inserting new leads, especially after the
    current table has been reset.

    Returns:
        The AUTO_INCREMENT value that was requested. Returns 0 if no connection
        could be opened or a database error occurred; the error is printed.
    """
    conn = get_connection()
    if not conn:
        return 0

    try:
        # The cursor is created inside the try, so a failure here is reported
        # like any other Error and the connection is still closed.
        cursor = conn.cursor()
        try:
            # Get max ID from current table
            cursor.execute(f"SELECT COALESCE(MAX(id), 0) FROM {TABLE_NAME}")
            current_max = cursor.fetchone()[0]

            # Get max ID from archive table
            cursor.execute(f"SELECT COALESCE(MAX(id), 0) FROM {ARCHIVE_TABLE}")
            archive_max = cursor.fetchone()[0]

            # Continue from the highest of the two
            next_id = max(current_max, archive_max) + 1

            cursor.execute(f"ALTER TABLE {TABLE_NAME} AUTO_INCREMENT = {next_id}")
            conn.commit()

            print(f"[SYNC] Current table max ID: {current_max}")
            print(f"[SYNC] Archive table max ID: {archive_max}")
            print(f"[SYNC] Auto-increment set to: {next_id}")

            return next_id
        finally:
            cursor.close()
    except Error as e:
        print(f"[SYNC ERROR] {e}")
        return 0
    finally:
        conn.close()


def find_duplicates_in_table(table: str = TABLE_NAME) -> list[dict]:
    """
    List business+city+state groups that occur more than once in `table`.

    Grouping is a plain SQL GROUP BY on the stored column values. It does
    not use normalize_business_name, so near-duplicates that differ only in
    suffixes or punctuation are not reported. Useful for cleaning up existing
    duplicates.

    Args:
        table: Table name interpolated into the SQL. Pass only trusted
            constants such as TABLE_NAME or ARCHIVE_TABLE.

    Returns:
        Dict rows with the keys business, city, state, count and ids.
        - ids is a comma-separated list of IDs in ascending order.
        - Rows are ordered by count, largest group first.
        - Returns an empty list if no connection could be opened.

        NOTE: MySQL truncates GROUP_CONCAT output at group_concat_max_len,
        so `ids` may be cut short for very large groups.
    """
    conn = get_connection()
    if not conn:
        return []
    try:
        # The cursor is created inside the outer try so the connection is
        # still closed if cursor creation itself fails.
        cursor = conn.cursor(dictionary=True)
        try:
            sql = f"""
            SELECT business, city, state, COUNT(*) as count,
                   GROUP_CONCAT(id ORDER BY id) as ids
            FROM {table}
            GROUP BY business, city, state
            HAVING COUNT(*) > 1
            ORDER BY count DESC
            """
            cursor.execute(sql)
            return cursor.fetchall()
        finally:
            cursor.close()
    finally:
        conn.close()


if __name__ == "__main__":
    # Manual smoke test: connectivity, column count, row counts, ID sync,
    # and a report of exact business+city+state duplicates.
    print(f"Testing connection to {TABLE_NAME}...")
    probe_conn = get_connection()
    if probe_conn:
        print("Connected successfully")
        probe_cursor = probe_conn.cursor()
        probe_cursor.execute(f"DESCRIBE {TABLE_NAME}")
        column_rows = probe_cursor.fetchall()
        print(f"Table has {len(column_rows)} columns")
        probe_cursor.close()
        probe_conn.close()

    print(f"\nCurrent leads: {get_lead_count()}")
    print(f"Archive leads: {get_archive_count()}")

    print("\nChecking auto-increment sync...")
    sync_auto_increment()

    separator = "=" * 60
    print("\n" + separator)
    print("Checking for existing duplicates in current table...")
    print(separator)
    dup_groups = find_duplicates_in_table(TABLE_NAME)
    if not dup_groups:
        print("No duplicates found!")
    else:
        print(f"Found {len(dup_groups)} duplicate groups:")
        # Only the first 10 groups are shown.
        for group in dup_groups[:10]:
            print(f"  {group['business']} in {group['city']}, {group['state']}: {group['count']} entries (IDs: {group['ids']})")
        hidden = len(dup_groups) - 10
        if hidden > 0:
            print(f"  ... and {hidden} more")
