#!/usr/bin/env python3
"""
FLHIP Hotel Geolocation Enricher
Uses Google Geocoding API to fill missing county, zip, lat, lng data for hotel leads
"""

import os
import sys
import time
import argparse
import requests
import mysql.connector
from mysql.connector import Error
from dotenv import load_dotenv
from pathlib import Path

# Load environment variables before any setting below is read
load_dotenv()

# Google Geocoding API settings
GOOGLE_API_KEY = os.environ.get("GOOGLE_API_KEY")
GEOCODE_BASE_URL = "https://maps.googleapis.com/maps/api/geocode/json"

# Keyword arguments handed to mysql.connector.connect(); only the host
# has a fallback, the other values are None when the variable is unset
DB_CONFIG = {
    'host': os.environ.get('DB_HOST', 'localhost'),
    'user': os.environ.get('DB_USER'),
    'password': os.environ.get('DB_PASSWORD'),
    'database': os.environ.get('DB_NAME'),
}

# Table holding the hotel lead rows
TABLE_NAME = "hotel_leads_usa"

# Progress tracking lives in a "tracking" folder next to this script
TRACKING_DIR = Path(__file__).resolve().parent.joinpath("tracking")
PROGRESS_FILE = TRACKING_DIR.joinpath("hotel_geolocation_enricher_progress.txt")


def get_db_connection():
    """
    Open a MySQL connection using the settings in DB_CONFIG.

    Returns the connection object. If the connector raises Error, the
    problem is printed and the process exits with status 1.
    """
    try:
        return mysql.connector.connect(**DB_CONFIG)
    except Error as exc:
        print(f"[ERROR] Database connection error: {exc}")
        sys.exit(1)


def get_records_needing_geolocation(conn, limit=100):
    """
    Fetch records that have a city and state but no coordinates yet.

    A row qualifies when city and state are both non-empty and lat or lng
    is NULL. A missing county or zip on its own does NOT select a row.

    Args:
        conn: Open database connection; must support cursor(dictionary=True).
        limit: Maximum number of rows to return.

    Returns:
        A list of dict rows (id, business, address, city, state, zip,
        county, lat, lng) in ascending id order.
    """
    query = f"""
        SELECT id, business, address, city, state, zip, county, lat, lng
        FROM {TABLE_NAME}
        WHERE
            city IS NOT NULL AND city != ''
            AND state IS NOT NULL AND state != ''
            AND (lat IS NULL OR lng IS NULL)
        ORDER BY id ASC
        LIMIT %s
    """

    cursor = conn.cursor(dictionary=True)
    try:
        cursor.execute(query, (limit,))
        return cursor.fetchall()
    finally:
        # Release the cursor even when the query raises
        cursor.close()


def get_total_needing_geolocation(conn):
    """
    Count the records that have a city and state but no coordinates yet.

    Uses the same filter as the record fetch: city and state non-empty,
    and lat or lng NULL.

    Args:
        conn: Open database connection.

    Returns:
        The number of matching rows.
    """
    query = f"""
        SELECT COUNT(*)
        FROM {TABLE_NAME}
        WHERE
            city IS NOT NULL AND city != ''
            AND state IS NOT NULL AND state != ''
            AND (lat IS NULL OR lng IS NULL)
    """

    cursor = conn.cursor()
    try:
        cursor.execute(query)
        return cursor.fetchone()[0]
    finally:
        # Release the cursor even when the query raises
        cursor.close()


def _extract_geo_fields(result):
    """Pull lat, lng, county and zip out of one Geocoding API result dict."""
    location = result['geometry']['location']

    county = None
    postal_code = None
    for component in result['address_components']:
        types = component['types']

        # County is usually "administrative_area_level_2"
        if 'administrative_area_level_2' in types:
            county = component['long_name']
            # Remove "County" suffix if present for cleaner data
            if county and county.endswith(' County'):
                county = county[:-7]

        if 'postal_code' in types:
            postal_code = component['short_name']

    return {
        'lat': location['lat'],
        'lng': location['lng'],
        'county': county,
        'zip': postal_code,
    }


def geocode_address(address, city, state, zip_code=None):
    """
    Use the Google Geocoding API to look up location data for an address.

    Args:
        address: Street address, may be None or empty.
        city: City name, may be None or empty.
        state: State name or code, may be None or empty.
        zip_code: Optional postal code; non-string values are accepted.

    Returns:
        A dict with keys 'lat', 'lng', 'county' and 'zip' taken from the
        first API result ('county' and 'zip' are None when the result has
        no such component), or None when the API key is missing, no
        address text is available, the request fails, the API status is
        not 'OK', or the response cannot be parsed.
    """
    if not GOOGLE_API_KEY:
        print("[ERROR] GOOGLE_API_KEY not set in environment")
        return None

    # Values come straight from the database, so coerce to str (a numeric
    # zip would make ", ".join raise TypeError) and drop blank pieces.
    address_parts = []
    for part in (address, city, state, zip_code):
        if part:
            text = str(part).strip()
            if text:
                address_parts.append(text)

    full_address = ", ".join(address_parts)

    if not full_address:
        return None

    params = {
        'address': full_address,
        'key': GOOGLE_API_KEY
    }

    try:
        response = requests.get(GEOCODE_BASE_URL, params=params, timeout=10)
        data = response.json()

        if data['status'] != 'OK':
            print(f"  [WARN] Geocoding failed: {data['status']}")
            return None

        return _extract_geo_fields(data['results'][0])

    except requests.exceptions.RequestException as e:
        print(f"  [ERROR] API request error: {e}")
        return None
    except (KeyError, IndexError, TypeError, ValueError) as e:
        # ValueError: body was not valid JSON; TypeError: JSON had an
        # unexpected shape. Either way one bad response must not abort
        # the whole run.
        print(f"  [ERROR] Parse error: {e}")
        return None


def update_record_geolocation(conn, record_id, geo_data):
    """
    Write the supplied geolocation fields to one record.

    Only the keys 'county', 'zip', 'lat' and 'lng' are considered. Empty
    or None county/zip values are skipped; lat/lng are skipped only when
    None, so a coordinate of 0.0 is still written.

    Args:
        conn: Open database connection.
        record_id: Value of the id column of the row to update.
        geo_data: Dict with any of the keys listed above.

    Returns:
        True when the UPDATE was executed and committed; False when there
        was nothing to write or the database raised Error (in which case
        the transaction is rolled back).
    """
    # Column names come from these fixed tuples, never from input, so
    # interpolating them into the SQL text is safe.
    updates = []
    values = []

    for column in ('county', 'zip'):
        if geo_data.get(column):
            updates.append(f"{column} = %s")
            values.append(geo_data[column])

    for column in ('lat', 'lng'):
        # 0.0 is a valid coordinate, so test against None, not truthiness
        if geo_data.get(column) is not None:
            updates.append(f"{column} = %s")
            values.append(geo_data[column])

    if not updates:
        return False

    query = f"""
        UPDATE {TABLE_NAME}
        SET {', '.join(updates)}
        WHERE id = %s
    """
    values.append(record_id)

    # Open the cursor only once there is work to do, and always close it
    cursor = conn.cursor()
    try:
        cursor.execute(query, values)
        conn.commit()
        return True
    except Error as e:
        print(f"  [ERROR] Update error: {e}")
        conn.rollback()
        return False
    finally:
        cursor.close()


def save_progress(last_processed_id):
    """
    Record the id of the most recently processed row in PROGRESS_FILE.

    The tracking directory is created when it does not exist yet, and any
    previous file content is overwritten.
    """
    TRACKING_DIR.mkdir(parents=True, exist_ok=True)
    PROGRESS_FILE.write_text(str(last_processed_id))


def load_progress():
    """
    Read back the id stored in PROGRESS_FILE.

    Returns the stored integer, or None when the file does not exist or
    its content cannot be interpreted as an integer.
    """
    if not PROGRESS_FILE.exists():
        return None

    try:
        # Reading stays inside the try: a decode failure is a ValueError too
        return int(PROGRESS_FILE.read_text().strip())
    except ValueError:
        return None


def _pick_missing_fields(record, geo_data):
    """Return the geocoded values for the fields the record still lacks."""
    return {
        field: geo_data[field]
        for field in ('county', 'zip', 'lat', 'lng')
        # `not record[field]` already covers None, '' and 0
        if not record[field] and geo_data.get(field)
    }


def main():
    """
    Command-line entry point: geocode hotel records that lack coordinates.

    Parses the CLI options, checks the API key, connects to the database
    and processes up to --limit records in batches of --batch-size. Each
    record is geocoded and only its missing fields are written back.
    Unless --auto is given, the user is asked before each further batch.
    With --dry-run the records are listed but neither the API nor the
    database is touched. A summary is printed at the end.
    """
    parser = argparse.ArgumentParser(description='FLHIP Hotel Geolocation Enricher')
    parser.add_argument('--limit', type=int, default=100, help='Number of records to process')
    parser.add_argument('--batch-size', type=int, default=50, help='Records per batch')
    parser.add_argument('--dry-run', action='store_true', help='Show what would be done without making changes')
    parser.add_argument('--delay', type=float, default=0.1, help='Delay between API calls (seconds)')
    parser.add_argument('--auto', action='store_true', help='Run without prompts (for automated/cron jobs)')
    args = parser.parse_args()

    # Reject values that would otherwise surface later as an SQL error
    # (negative LIMIT) or a ValueError from time.sleep()
    if args.batch_size < 1:
        parser.error("--batch-size must be at least 1")
    if args.delay < 0:
        parser.error("--delay must not be negative")

    print("=" * 60)
    print("FLHIP HOTEL GEOLOCATION ENRICHER")
    print("=" * 60)

    # Check API key
    if not GOOGLE_API_KEY:
        print("[ERROR] GOOGLE_API_KEY not found in environment")
        print("   Set it in your .env file or export it")
        sys.exit(1)

    print(f"[OK] Google API Key: ...{GOOGLE_API_KEY[-8:]}")

    # Connect to database
    conn = get_db_connection()
    try:
        print(f"[OK] Connected to database: {DB_CONFIG['database']}")
        print(f"[OK] Using table: {TABLE_NAME}")

        # Get total count
        total_needed = get_total_needing_geolocation(conn)
        print(f"[INFO] Hotel records needing geolocation: {total_needed}")

        if total_needed == 0:
            print("[OK] All hotel records already have geolocation data!")
            return

        if args.dry_run:
            print("\n[DRY RUN] No changes will be made\n")

        # Process records
        processed = 0
        updated = 0
        failed = 0

        # Ids handled in this run that were NOT written back (dry run,
        # geocoding failure, DB failure, nothing new). They still match
        # the selection query, so without this set every batch would
        # fetch the same rows again and repeat the API calls for them.
        unresolved_ids = set()

        while processed < args.limit:
            # Get batch of records; over-fetch by the number of rows known
            # to be still pending so a full batch of new rows remains
            # after filtering them out
            batch_size = min(args.batch_size, args.limit - processed)
            candidates = get_records_needing_geolocation(
                conn, batch_size + len(unresolved_ids)
            )
            records = [
                row for row in candidates if row['id'] not in unresolved_ids
            ][:batch_size]

            if not records:
                print("\n[OK] No more records to process")
                break

            print(f"\n--- Processing batch of {len(records)} hotel records ---")

            for record in records:
                record_id = record['id']
                business = record['business'] or 'Unknown'
                city = record['city']
                state = record['state']
                address = record['address']

                print(f"\n[{processed + 1}/{args.limit}] ID {record_id}: {business[:40]}")
                print(f"   Location: {address or 'No address'}, {city}, {state}")
                print(f"   Current: county={record['county']}, zip={record['zip']}, lat={record['lat']}, lng={record['lng']}")

                if args.dry_run:
                    print("   [SKIP] Dry run - skipping API call")
                    unresolved_ids.add(record_id)
                    processed += 1
                    continue

                # Call Google Geocoding API
                geo_data = geocode_address(address, city, state, record['zip'])
                resolved = False

                if geo_data:
                    # Only update fields that are currently missing
                    update_data = _pick_missing_fields(record, geo_data)

                    if update_data:
                        if update_record_geolocation(conn, record_id, update_data):
                            print(f"   [UPDATED] {update_data}")
                            updated += 1
                            resolved = True
                        else:
                            print("   [FAILED] Could not update database")
                            failed += 1
                    else:
                        print("   [SKIP] No new data to update")
                else:
                    print("   [FAILED] Geocoding failed")
                    failed += 1

                if not resolved:
                    unresolved_ids.add(record_id)

                processed += 1
                save_progress(record_id)

                # Rate limiting
                time.sleep(args.delay)

            # Prompt to continue if batch complete (skip if --auto)
            if not args.auto and processed < args.limit and processed % args.batch_size == 0:
                remaining = args.limit - processed
                try:
                    response = input(f"\n[INFO] Processed {processed}. Continue with next {min(args.batch_size, remaining)}? (y/n): ")
                except EOFError:
                    # stdin is closed (e.g. run from cron without --auto):
                    # treat it as "no" instead of crashing
                    break
                if response.lower() != 'y':
                    break

        # Summary
        print("\n" + "=" * 60)
        print("SUMMARY")
        print("=" * 60)
        print(f"   Total processed: {processed}")
        print(f"   Updated: {updated}")
        print(f"   Failed: {failed}")
        print(f"   Skipped: {processed - updated - failed}")
    finally:
        # Close the connection on every exit path, including errors
        conn.close()

    print("\n[DONE]")


# Run the enricher only when this file is executed as a script,
# not when it is imported as a module
if __name__ == "__main__":
    main()
