#!/usr/bin/env python3
import sys
import argparse
import urllib.request
import urllib.parse
import json
import math
from datetime import datetime, timedelta, timezone

def haversine_distance(lat1, lon1, lat2, lon2):
    """Return the great-circle distance in kilometres between two points.

    All four arguments are in decimal degrees. The haversine formula is
    evaluated on a sphere with a mean Earth radius of 6371 km.
    """
    earth_radius_km = 6371.0

    # Angular differences and the two latitudes, converted to radians.
    delta_phi = math.radians(lat2 - lat1)
    delta_lambda = math.radians(lon2 - lon1)
    phi1 = math.radians(lat1)
    phi2 = math.radians(lat2)

    # Haversine terms for the latitude and longitude differences.
    hav_lat = math.sin(delta_phi / 2)**2
    hav_lon = math.sin(delta_lambda / 2)**2

    chord = hav_lat + math.cos(phi1) * math.cos(phi2) * hav_lon
    central_angle = 2 * math.atan2(math.sqrt(chord), math.sqrt(1 - chord))

    return earth_radius_km * central_angle

def geocode_place(name, timeout=10.0):
    """Resolve a free-text place name to an ID via the Entur geocoder.

    Calls the ``autocomplete`` endpoint restricted to the ``venue`` layer and
    asks for at most 10 candidates.

    Args:
        name: Free-text search string (e.g. ``"Oslo S"``).
        timeout: Socket timeout in seconds for the HTTP request. Without one,
            a stalled connection would block the caller indefinitely.

    Returns:
        A ``(id, name)`` tuple. The first candidate whose id starts with
        ``"NSR:StopPlace:"`` wins; if none does, the first candidate of any
        kind is returned. ``(None, None)`` is returned when there are no
        candidates or when the request or response parsing fails (the error
        is reported on stderr).
    """
    params = {
        "text": name,
        "size": 10,
        "layers": "venue",
    }
    full_url = "https://api.entur.io/geocoder/v1/autocomplete?" + urllib.parse.urlencode(params)

    headers = {"ET-Client-Name": "antigravity-oslorouting"}
    req = urllib.request.Request(full_url, headers=headers, method="GET")
    # Deliberately best-effort: any failure (network, timeout, malformed
    # payload) is reported and turned into the (None, None) sentinel.
    try:
        with urllib.request.urlopen(req, timeout=timeout) as response:
            res = json.loads(response.read().decode("utf-8"))
        features = res.get("features", [])
        for feat in features:
            props = feat.get("properties", {})
            gid = props.get("id")
            # Prefer a real stop place over any other kind of venue.
            if gid and gid.startswith("NSR:StopPlace:"):
                return gid, props.get("name")
        if features:
            # No stop place among the candidates: fall back to the top hit.
            props = features[0].get("properties", {})
            return props.get("id"), props.get("name")
    except Exception as e:
        print(f"Error geocoding '{name}': {e}", file=sys.stderr)
    return None, None

def get_nearby_stops(lat, lon, radius_km, timeout=10.0):
    """Fetch venues around a coordinate from the Entur reverse geocoder.

    Args:
        lat: Latitude of the search centre, sent as ``point.lat``.
        lon: Longitude of the search centre, sent as ``point.lon``.
        radius_km: Search radius, sent as ``boundary.circle.radius``.
        timeout: Socket timeout in seconds for the HTTP request. Without one,
            a stalled connection would block the caller indefinitely.

    Returns:
        The decoded JSON response (at most 50 results are requested from the
        ``venue`` layer), or ``None`` if the request or decoding fails (the
        error is reported on stderr).
    """
    params = {
        "point.lat": lat,
        "point.lon": lon,
        "boundary.circle.radius": radius_km,
        "layers": "venue",
        "size": 50,
    }
    full_url = "https://api.entur.io/geocoder/v1/reverse?" + urllib.parse.urlencode(params)

    headers = {"ET-Client-Name": "antigravity-oslorouting"}
    req = urllib.request.Request(full_url, headers=headers, method="GET")
    # Deliberately best-effort: callers treat None as "no stops available".
    try:
        with urllib.request.urlopen(req, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except Exception as e:
        print(f"Error reverse geocoding: {e}", file=sys.stderr)
        return None

def query_entur_api(query, variables=None, timeout=15.0):
    """POST a GraphQL query to the Entur journey planner (v3).

    Args:
        query: GraphQL document as a string.
        variables: Optional mapping of GraphQL variables; omitted from the
            request body when empty or ``None``.
        timeout: Socket timeout in seconds for the HTTP request. Without one,
            a single stalled request would block every later query.

    Returns:
        The decoded JSON response, or ``None`` if the request or decoding
        fails (the error is reported on stderr).
    """
    url = "https://api.entur.io/journey-planner/v3/graphql"
    headers = {
        "Content-Type": "application/json",
        "ET-Client-Name": "antigravity-oslorouting",
    }
    data = {"query": query}
    if variables:
        data["variables"] = variables

    payload = json.dumps(data).encode("utf-8")
    req = urllib.request.Request(url, data=payload, headers=headers, method="POST")
    # Deliberately best-effort: callers skip the stop when None comes back.
    try:
        with urllib.request.urlopen(req, timeout=timeout) as response:
            return json.loads(response.read().decode("utf-8"))
    except Exception as e:
        print(f"Error querying Entur API: {e}", file=sys.stderr)
        return None

# GraphQL document passed to query_entur_api() by main(), once per candidate
# stop and query time. Variables: $fromStop / $toStop are place IDs and
# $dateTime is an ISO-8601 timestamp. It requests up to 30 trip patterns and,
# for every leg, the mode, distance, duration, expected start/end times, the
# name and coordinates of both end points, and the line's public code.
# NOTE(review): arriveBy is true, so $dateTime is presumably treated as the
# latest arrival time rather than a departure time -- confirm against the
# journey-planner schema.
trip_query = """
query GetTripsFromStop($fromStop: String!, $toStop: String!, $dateTime: DateTime!) {
  trip(
    from: { place: $fromStop }
    to: { place: $toStop }
    dateTime: $dateTime
    arriveBy: true
    numTripPatterns: 30
  ) {
    tripPatterns {
      startTime
      endTime
      duration
      legs {
        mode
        distance
        duration
        expectedStartTime
        expectedEndTime
        fromPlace {
          name
          latitude
          longitude
        }
        toPlace {
          name
          latitude
          longitude
        }
        line {
          publicCode
          transportMode
        }
      }
    }
  }
}
"""

def parse_iso_datetime(dt_str):
    """Parse an ISO-8601 timestamp with a UTC offset into an aware datetime.

    Accepts ``YYYY-MM-DDTHH:MM:SS`` with optional fractional seconds, followed
    by an offset written as ``+HH:MM`` or ``+HHMM``.

    Args:
        dt_str: Timestamp string, e.g. ``"2024-01-01T08:22:00+01:00"``.

    Returns:
        A timezone-aware ``datetime``.

    Raises:
        ValueError: If the string is malformed or carries no UTC offset.
    """
    text = dt_str
    # Drop the colon of a trailing "+HH:MM" offset so "%z" also works on
    # interpreters that only understand "+HHMM". The sign check keeps a
    # timestamp without an offset from having its seconds colon removed.
    if len(text) >= 6 and text[-3] == ":" and text[-6] in "+-":
        text = text[:-3] + text[-2:]
    # A dot can only appear in the fractional-seconds part of this format.
    fmt = "%Y-%m-%dT%H:%M:%S.%f%z" if "." in text else "%Y-%m-%dT%H:%M:%S%z"
    return datetime.strptime(text, fmt)

def _pattern_key(pattern):
    """Build a hashable identity for a trip pattern from its times and legs."""
    legs_sig = tuple(
        (
            leg.get("mode"),
            (leg.get("line") or {}).get("publicCode"),
            (leg.get("fromPlace") or {}).get("name"),
            leg.get("expectedStartTime"),
            leg.get("expectedEndTime"),
        )
        for leg in pattern.get("legs") or []
    )
    return (pattern.get("startTime"), pattern.get("endTime"), legs_sig)


def _print_trip(rank, trip, buffer_min, jog_speed):
    """Print one ranked trip in the layout shared by both result sections."""
    ride_min = trip["ride_duration_sec"] / 60.0
    jog_min = trip["jog_duration_sec"] / 60.0
    total_elapsed_min = trip["total_duration_sec"] / 60.0
    jog_off_str = trip["jog_off_dt"].strftime("%H:%M:%S")
    board_time_str = trip["board_dt"].strftime("%H:%M:%S")
    arrival_str = trip["end_dt"].strftime("%H:%M:%S")
    legs_str = " -> ".join(
        f"{leg['mode'].upper()} {leg['line']['publicCode'] if leg.get('line') else ''} ({leg['duration'] / 60.0:.1f} min)"
        for leg in trip["transit_legs"]
    )
    print(f"\n#{rank}: From: {trip['source_stop_name']} ({trip['source_stop_dist']:.2f} km)")
    print(f"  JOG OFF: {jog_off_str} (Jog: {jog_min:.1f} min + {buffer_min} min buffer at {jog_speed} km/h)")
    print(f"  Board Vehicle: {board_time_str} | Arr at Destination: {arrival_str}")
    print(f"  Transfer Buffer: {trip['transfer_minutes']:.2f} min | Ride duration: {ride_min:.1f} min (Total elapsed: {total_elapsed_min:.1f} min)")
    print(f"  Route: {legs_str}")


def main():
    """Command-line entry point.

    Steps:
      1. Parse and validate the arguments (no network traffic on bad input).
      2. Geocode the destination and list stops within the radius of the
         start coordinates.
      3. For every stop, ask the journey planner for trips to the destination
         at the target time and three minutes after it.
      4. Keep trips that use only the allowed modes, board within the radius
         and arrive within the transfer window widened by two minutes.
      5. Print "perfect" matches (inside the window) and "close" matches
         (inside the widened window only), latest jog-off time first.

    Exits with status 1 on invalid arguments, when the destination cannot be
    geocoded, or when no stops are found.
    """
    parser = argparse.ArgumentParser(description="Find best transit connection from coordinate radius to a destination.")
    parser.add_argument("coords", help="Starting coordinates in 'lat,lon' format (e.g. '59.9224,10.7582')")
    parser.add_argument("destination", help="Name of the destination stop/station (e.g. 'Oslo S')")
    parser.add_argument("--target-time", default="08:22", help="Departure time of connection at destination, in HH:MM format (default: 08:22)")
    parser.add_argument("--transfer-min", default="5-6", help="Desired transfer window range in minutes (default: 5-6)")
    parser.add_argument("--radius", type=float, default=1.5, help="Search radius in kilometers for starting stops (default: 1.5)")
    parser.add_argument("--modes", default="bus,metro,rail", help="Allowed transport modes, comma-separated (default: bus,metro,rail)")
    parser.add_argument("--date", help="Date in YYYY-MM-DD format. Defaults to today or tomorrow depending on target time.")
    parser.add_argument("--jog-speed", type=float, default=8.0, help="Jogging speed from coordinates to the stop in km/h (default: 8.0)")
    parser.add_argument("--buffer-min", type=float, default=3.0, help="Buffer time in minutes to arrive at the stop before departure (default: 3.0)")

    args = parser.parse_args()

    # ---- Argument validation (all of it happens before any network call) ----

    # Parse coordinates
    try:
        lat_str, lon_str = args.coords.split(",")
        center_lat = float(lat_str.strip())
        center_lon = float(lon_str.strip())
    except ValueError:
        print("Error: Coords must be in 'lat,lon' format (e.g. '59.933093,10.794331')", file=sys.stderr)
        sys.exit(1)

    # Parse transfer window
    try:
        if "-" in args.transfer_min:
            t_min_str, t_max_str = args.transfer_min.split("-")
            t_min = float(t_min_str.strip())
            t_max = float(t_max_str.strip())
        else:
            t_min = float(args.transfer_min)
            t_max = t_min
    except ValueError:
        print("Error: --transfer-min must be a number or a range like '5-6'", file=sys.stderr)
        sys.exit(1)
    if t_min > t_max:
        # Tolerate a reversed range such as '6-5'.
        t_min, t_max = t_max, t_min

    # Parse target time; a malformed value must not surface as a traceback.
    try:
        hh_str, mm_str = args.target_time.split(":")
        target_hh = int(hh_str)
        target_mm = int(mm_str)
        if not (0 <= target_hh <= 23 and 0 <= target_mm <= 59):
            raise ValueError(args.target_time)
    except ValueError:
        print("Error: --target-time must be in HH:MM format (e.g. '08:22')", file=sys.stderr)
        sys.exit(1)

    # The jog duration divides by the speed, so it has to be positive.
    if args.jog_speed <= 0:
        print("Error: --jog-speed must be greater than 0", file=sys.stderr)
        sys.exit(1)

    explicit_date = None
    if args.date:
        try:
            explicit_date = datetime.strptime(args.date, "%Y-%m-%d").date()
        except ValueError:
            print("Error: --date must be in YYYY-MM-DD format", file=sys.stderr)
            sys.exit(1)

    allowed_modes = {m.strip().lower() for m in args.modes.split(",")}

    # Trips outside the user's window widened by this margin are never shown,
    # so the same bounds are used for the pre-filter below.
    close_margin_min = 2.0
    keep_min = t_min - close_margin_min
    keep_max = t_max + close_margin_min

    # ---- Destination ----

    dest_id, dest_name = geocode_place(args.destination)
    if not dest_id:
        print(f"Error: Could not find destination '{args.destination}' via Geocoder.", file=sys.stderr)
        sys.exit(1)
    print(f"Destination matched to: {dest_name} (ID: {dest_id})")

    # ---- Target datetime ----
    # The target time is interpreted in the local timezone of the machine
    # running this script.
    if explicit_date is not None:
        target_date = explicit_date
    else:
        now = datetime.now().astimezone()
        today_target = now.replace(hour=target_hh, minute=target_mm, second=0, microsecond=0)
        # If the target time has already passed today, search for tomorrow.
        target_date = (now + timedelta(days=1)).date() if now >= today_target else now.date()

    # Attach the local offset that applies on the target date itself, so a
    # DST change between today and that date does not shift the target.
    naive_target = datetime.combine(target_date, datetime.min.time()).replace(hour=target_hh, minute=target_mm)
    target_dt = naive_target.astimezone()
    print(f"Target connection departure: {target_dt.strftime('%Y-%m-%d %H:%M:%S %Z')}")
    print(f"Searching for arrivals between {(target_dt - timedelta(minutes=t_max)).strftime('%H:%M:%S')} and {(target_dt - timedelta(minutes=t_min)).strftime('%H:%M:%S')} (transfer window: {args.transfer_min} min)")

    # ---- Nearby stops ----

    stops_res = get_nearby_stops(center_lat, center_lon, args.radius)
    if not stops_res or "features" not in stops_res or not stops_res["features"]:
        print(f"Error: Could not find any stops within {args.radius} km of coordinates.", file=sys.stderr)
        sys.exit(1)

    features = stops_res["features"]
    print(f"Found {len(features)} stop places within {args.radius} km radius.")

    # We will query with arriveBy: true at target time
    # Just in case, also query slightly later to cover boundary conditions
    target_datetimes = [
        target_dt.isoformat(),
        (target_dt + timedelta(minutes=3)).isoformat(),
    ]

    # ---- Collect trip patterns from every stop inside the radius ----

    all_patterns = []
    seen_pattern_keys = set()

    for feat in features:
        props = feat.get("properties", {})
        stop_id = props.get("id")
        stop_name = props.get("name")
        stop_coords = feat.get("geometry", {}).get("coordinates", [])

        if not stop_id or not stop_name or len(stop_coords) < 2:
            continue

        # Coordinates are read as [lon, lat].
        dist = haversine_distance(center_lat, center_lon, stop_coords[1], stop_coords[0])
        if dist > args.radius:
            continue

        print(f"Querying: {stop_name} ({dist:.2f} km)...", end="\r")

        for dt_val in target_datetimes:
            variables = {
                "fromStop": stop_id,
                "toStop": dest_id,
                "dateTime": dt_val,
            }
            res = query_entur_api(trip_query, variables)
            trip_data = ((res or {}).get("data") or {}).get("trip") or {}
            for p in trip_data.get("tripPatterns") or []:
                # The key includes the legs so that two different routes that
                # merely share start/end times are both kept; it only removes
                # the repeats produced by the two overlapping queries.
                key = _pattern_key(p)
                if key not in seen_pattern_keys:
                    seen_pattern_keys.add(key)
                    all_patterns.append(p)

    print(f"\nTotal unique trip patterns retrieved: {len(all_patterns)}")

    # ---- Filter by mode, transfer window and boarding-stop distance ----

    filtered_trips = []
    for p in all_patterns:
        transit_legs = [leg for leg in p["legs"] if leg["mode"] != "foot"]
        if not transit_legs or any(leg["mode"] not in allowed_modes for leg in transit_legs):
            continue

        end_time_dt = parse_iso_datetime(p["endTime"])
        # Positive value: minutes between arriving and the target departure.
        transfer_minutes = (target_dt - end_time_dt).total_seconds() / 60.0
        if not (keep_min <= transfer_minutes <= keep_max):
            continue

        ride_duration_sec = sum(leg["duration"] for leg in transit_legs)

        # Use the actual first transit boarding stop for jogging calculations
        first_transit_leg = transit_legs[0]
        board_place = first_transit_leg["fromPlace"]
        actual_stop_lat = board_place["latitude"]
        actual_stop_lon = board_place["longitude"]
        if actual_stop_lat is None or actual_stop_lon is None:
            continue

        actual_stop_dist = haversine_distance(center_lat, center_lon, actual_stop_lat, actual_stop_lon)

        # Ensure the actual boarding stop is within the radius
        if actual_stop_dist > args.radius:
            continue

        transit_dep_time_dt = parse_iso_datetime(first_transit_leg["expectedStartTime"])

        # Latest moment to start jogging: boarding time minus the jog itself
        # minus the safety buffer at the stop.
        jog_duration_sec = (actual_stop_dist / args.jog_speed) * 3600.0
        jog_off_dt = transit_dep_time_dt - timedelta(seconds=jog_duration_sec) - timedelta(minutes=args.buffer_min)

        # Total duration runs from the jog-off time to arrival at the destination.
        total_duration_sec = (end_time_dt - jog_off_dt).total_seconds()

        filtered_trips.append({
            "pattern": p,
            "transfer_minutes": transfer_minutes,
            "ride_duration_sec": ride_duration_sec,
            "total_duration_sec": total_duration_sec,
            "jog_duration_sec": jog_duration_sec,
            "jog_off_dt": jog_off_dt,
            "board_dt": transit_dep_time_dt,
            "end_dt": end_time_dt,
            "source_stop_name": board_place["name"],
            "source_stop_dist": actual_stop_dist,
            "transit_legs": transit_legs,
        })

    # Deduplicate on the actual transit boarding parameters: the same vehicle
    # reached from different queried stops differs only in its walking legs.
    unique_filtered = []
    seen_trips = set()
    for t in filtered_trips:
        route_sig = "-".join(
            f"{leg['mode']}_{leg['line']['publicCode'] if leg.get('line') else ''}"
            for leg in t["transit_legs"]
        )
        trip_key = (
            t["source_stop_name"],
            t["transit_legs"][0]["expectedStartTime"],
            t["pattern"]["endTime"],
            route_sig,
        )
        if trip_key not in seen_trips:
            seen_trips.add(trip_key)
            unique_filtered.append(t)

    # Sort by latest jog off time (latest first); the sort is stable, so the
    # two partitions below keep that order.
    unique_filtered.sort(key=lambda x: x["jog_off_dt"], reverse=True)

    # Perfect matches are strictly inside the transfer window range; every
    # other surviving trip lies within the widened window only.
    perfect_matches = [t for t in unique_filtered if t_min <= t["transfer_minutes"] <= t_max]
    close_only = [t for t in unique_filtered if not (t_min <= t["transfer_minutes"] <= t_max)]

    # ---- Report ----

    dest_label = (dest_name or args.destination).upper()
    print("\n" + "="*80)
    print(f"RESULTS FOR ROUTE TO {dest_label} AT {args.target_time}")
    print("="*80)

    print(f"\n--- PERFECT MATCHES ({args.transfer_min} min transfer window, sorted by latest jog off time) ---")
    if not perfect_matches:
        print("No perfect matches found.")
    for rank, t in enumerate(perfect_matches, start=1):
        _print_trip(rank, t, args.buffer_min, args.jog_speed)

    print(f"\n--- CLOSE MATCHES (within ±2 min of window, sorted by latest jog off time) ---")
    for rank, t in enumerate(close_only, start=1):
        _print_trip(rank, t, args.buffer_min, args.jog_speed)
    if not close_only:
        print("No other close matches found.")

# Run the command-line tool only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
