# Copyright (c) 2025, benilerouge.ddns.net
# Licensed under the MIT License.

import matplotlib.pyplot as plt
import matplotlib as mpl
mpl.rcParams['toolbar'] = 'None'
from matplotlib.widgets import Button
import matplotlib.patches as patches
from datetime import datetime, timedelta, time
import sys
import numpy as np
import os
import math
import requests
from io import BytesIO
from PIL import Image
import tkinter as tk
from tkinter import filedialog, simpledialog, ttk
import tkinter.messagebox
import argparse
import xml.etree.ElementTree as ET
import threading
import concurrent.futures
import logging
from functools import lru_cache
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry

# --- Configuration ---
CONFIG = {
    'cache_dir': os.path.expanduser("~/.tile_cache_gsi"),
    'max_tile_cache': 100,
    'default_zoom': 16,
    'speed_threshold': 0.5,  # m/s pour détection arrêts
    'interpolation_interval': 1,  # secondes
    'network_timeout': 10,
    'max_retries': 3,
    'max_workers': 4,  # threads pour téléchargement tuiles
}

# --- Logging ---
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('gpx_editor.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)

# --- Arguments CLI ---
parser = argparse.ArgumentParser(description="Éditeur GPX Optimisé")
parser.add_argument('--server', choices=['seamlessphoto', 'std'], help='Serveur tuiles GSI')
parser.add_argument('--zoom', type=int, help='Zoom niveau 0-20')
parser.add_argument('--file', help='Fichier GPX à ouvrir directement')
args = parser.parse_args()

# --- Classes utilitaires ---

class TileCache:
    """Cache en mémoire pour les tuiles avec gestion LRU"""
    def __init__(self, max_size=100):
        self._cache = {}
        self._access_order = []
        self._lock = threading.Lock()
        self.max_size = max_size
    
    def get(self, key):
        with self._lock:
            if key in self._cache:
                # Mettre à jour l'ordre d'accès
                self._access_order.remove(key)
                self._access_order.append(key)
                return self._cache[key]
            return None
    
    def put(self, key, value):
        with self._lock:
            if key in self._cache:
                self._access_order.remove(key)
            elif len(self._cache) >= self.max_size:
                # Supprimer le moins récemment utilisé
                oldest_key = self._access_order.pop(0)
                del self._cache[oldest_key]
            
            self._cache[key] = value
            self._access_order.append(key)

class NetworkManager:
    """Gestionnaire réseau avec retry et timeout"""
    def __init__(self):
        self.session = self._create_session()
    
    def _create_session(self):
        session = requests.Session()
        retry_strategy = Retry(
            total=CONFIG['max_retries'],
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
        )
        adapter = HTTPAdapter(max_retries=retry_strategy)
        session.mount("http://", adapter)
        session.mount("https://", adapter)
        return session
    
    def download_with_retry(self, url, timeout=None):
        timeout = timeout or CONFIG['network_timeout']
        try:
            response = self.session.get(url, timeout=timeout)
            response.raise_for_status()
            return response
        except requests.exceptions.RequestException as e:
            logger.error(f"Erreur téléchargement {url}: {e}")
            raise

class ProgressDialog:
    """Dialog de progression pour les opérations longues - Version console uniquement"""
    def __init__(self, title, max_value=100):
        self.title = title
        self.max_value = max_value
        self.last_percent = -1
        print(f"\n--- {title} ---")
    
    def update(self, value, text=""):
        # Affichage console uniquement pour éviter les fenêtres parasites
        percent = int((value / self.max_value) * 100)
        if percent != self.last_percent and percent % 10 == 0:  # Afficher tous les 10%
            print(f"{self.title}: {percent}% - {text}")
            self.last_percent = percent
    
    def close(self):
        print(f"--- {self.title} terminé ---\n")

# --- Fonctions optimisées ---

def deg2num(lat_deg, lon_deg, zoom):
    """Conversion coordonnées géographiques vers tuiles"""
    lat_rad = math.radians(lat_deg)
    n = 2 ** zoom
    x = int((lon_deg + 180.0) / 360.0 * n)
    y = int((1.0 - math.log(math.tan(lat_rad) + 1 / math.cos(lat_rad)) / math.pi) / 2 * n)
    return x, y

def num2deg(x, y, zoom):
    """Conversion tuiles vers coordonnées géographiques"""
    n = 2 ** zoom
    lon = x / n * 360.0 - 180
    lat_rad = math.atan(math.sinh(math.pi * (1 - 2 * y / n)))
    lat = math.degrees(lat_rad)
    return lat, lon

def validate_gpx_points(points):
    """Valide la cohérence des points GPX"""
    errors = []
    
    if not points:
        errors.append("Aucun point GPX trouvé")
        return errors
    
    for i, point in enumerate(points):
        # Vérifier les coordonnées
        if not (-90 <= point['lat'] <= 90):
            errors.append(f"Point {i}: latitude invalide {point['lat']}")
        if not (-180 <= point['lon'] <= 180):
            errors.append(f"Point {i}: longitude invalide {point['lon']}")
            
        # Vérifier la progression temporelle
        if i > 0 and point['time'] < points[i-1]['time']:
            errors.append(f"Point {i}: temps non chronologique")
    
    return errors

def parse_gpx_optimized(file_content):
    """Parser GPX optimisé utilisant ElementTree"""
    try:
        root = ET.fromstring(file_content)
        points = []
        
        # Gestion des namespaces GPX
        namespaces = {
            'gpx': 'http://www.topografix.com/GPX/1/1',
            'gpx10': 'http://www.topografix.com/GPX/1/0'
        }
        
        # Essayer les deux versions de namespace
        trkpts = root.findall('.//gpx:trkpt', namespaces) or root.findall('.//gpx10:trkpt', namespaces)
        
        # Si pas de namespace, essayer sans
        if not trkpts:
            trkpts = root.findall('.//trkpt')
        
        for trkpt in trkpts:
            try:
                lat = float(trkpt.get('lat'))
                lon = float(trkpt.get('lon'))
                
                # Chercher elevation et time
                ele_elem = trkpt.find('.//ele') or trkpt.find('.//{*}ele')
                time_elem = trkpt.find('.//time') or trkpt.find('.//{*}time')
                
                if ele_elem is not None and time_elem is not None:
                    ele = float(ele_elem.text)
                    time_obj = datetime.strptime(
                        time_elem.text.replace('Z', '+00:00').replace('+00:00', 'Z'), 
                        "%Y-%m-%dT%H:%M:%SZ"
                    )
                    points.append({
                        'lat': lat, 'lon': lon, 'ele': ele, 'time': time_obj
                    })
            except (ValueError, AttributeError) as e:
                logger.warning(f"Erreur parsing point: {e}")
                continue
        
        return points
        
    except ET.ParseError as e:
        logger.error(f"Erreur parsing XML: {e}")
        # Fallback vers l'ancienne méthode
        return parse_gpx_fallback(file_content)

def parse_gpx_fallback(file_content):
    """Fallback vers l'ancienne méthode de parsing"""
    lines = file_content.strip().splitlines()
    points = []
    for i, line in enumerate(lines):
        if line.strip().startswith('<trkpt'):
            try:
                lat = float(line.split('lat="')[1].split('"')[0])
                lon = float(line.split('lon="')[1].split('"')[0])
                if i + 2 < len(lines):
                    ele = float(lines[i + 1].strip().replace('<ele>', '').replace('</ele>', ''))
                    time_str = lines[i + 2].strip().replace('<time>', '').replace('</time>', '')
                    t = datetime.strptime(time_str, "%Y-%m-%dT%H:%M:%SZ")
                    points.append({'lat': lat, 'lon': lon, 'ele': ele, 'time': t})
            except (ValueError, IndexError) as e:
                logger.warning(f"Erreur parsing ligne {i}: {e}")
                continue
    return points

@lru_cache(maxsize=128)
def haversine_cached(lat1, lon1, lat2, lon2):
    """Version cachée du calcul de distance Haversine"""
    R = 6371000  # Rayon Terre en mètres
    phi1, phi2 = math.radians(lat1), math.radians(lat2)
    dphi = math.radians(lat2 - lat1)
    dlambda = math.radians(lon2 - lon1)
    a = math.sin(dphi/2)**2 + math.cos(phi1)*math.cos(phi2)*math.sin(dlambda/2)**2
    return R * 2 * math.atan2(math.sqrt(a), math.sqrt(1 - a))

def calculate_speeds_vectorized(points):
    """Calcule les vitesses de manière vectorisée"""
    if len(points) < 2:
        return []
    
    # Extraction vectorisée des données
    lats = np.array([p['lat'] for p in points])
    lons = np.array([p['lon'] for p in points])
    times = np.array([(p['time'] - points[0]['time']).total_seconds() for p in points])
    
    # Calcul vectorisé des distances avec Haversine
    lat_rad = np.radians(lats)
    lon_rad = np.radians(lons)
    
    dlat = np.diff(lat_rad)
    dlon = np.diff(lon_rad)
    
    a = (np.sin(dlat/2)**2 + 
         np.cos(lat_rad[:-1]) * np.cos(lat_rad[1:]) * 
         np.sin(dlon/2)**2)
    
    distances = 6371000 * 2 * np.arctan2(np.sqrt(a), np.sqrt(1-a))
    dt = np.diff(times)
    
    # Éviter division par zéro
    speeds = np.divide(distances, dt, out=np.zeros_like(distances), where=dt!=0)
    return speeds.tolist()

def remove_duplicates_optimized(points):
    """Version optimisée de suppression des doublons"""
    if not points:
        return []
    
    unique_points = [points[0]]
    duplicate_count = 0
    
    # Pré-calculer les coordonnées arrondies
    coords = [(round(p['lat'], 6), round(p['lon'], 6)) for p in points]
    
    for i in range(1, len(points)):
        if coords[i] != coords[i-1]:
            unique_points.append(points[i])
        else:
            duplicate_count += 1
    
    logger.info(f"Doublons détectés et supprimés : {duplicate_count}")
    return unique_points

def interpolate_points_optimized(points, progress_callback=None):
    """Version optimisée de l'interpolation avec callback de progression"""
    if not points:
        return []
    
    interpolated = [points[0].copy()]
    total_segments = len(points) - 1
    
    for i in range(len(points) - 1):
        if progress_callback and i % 100 == 0:
            progress_callback(i / total_segments * 100, f"Interpolation: {i}/{total_segments}")
        
        pt_start = points[i]
        pt_end = points[i + 1]
        delta_t = int((pt_end['time'] - pt_start['time']).total_seconds())
        
        if delta_t <= 1:
            interpolated.append(pt_end.copy())
        else:
            # Interpolation vectorisée pour de meilleures performances
            fractions = np.linspace(0, 1, delta_t + 1)[1:-1]  # Exclure 0 et 1
            
            for frac in fractions:
                lat = pt_start['lat'] + frac * (pt_end['lat'] - pt_start['lat'])
                lon = pt_start['lon'] + frac * (pt_end['lon'] - pt_start['lon'])
                ele = pt_start['ele'] + frac * (pt_end['ele'] - pt_start['ele'])
                time_interp = pt_start['time'] + timedelta(seconds=frac * delta_t)
                interpolated.append({'lat': lat, 'lon': lon, 'ele': ele, 'time': time_interp})
            
            interpolated.append(pt_end.copy())
    
    return interpolated

# --- Classes principales ---

class TileManager:
    """Gestionnaire de tuiles optimisé avec cache et téléchargement asynchrone"""
    
    # Dictionnaire des serveurs de tuiles disponibles
    TILE_SERVERS = {
        # Serveur mondial pour cyclisme
        'cyclosm': {
            'url': 'https://dev.a.tile.openstreetmap.fr/cyclosm/{z}/{x}/{y}.png',
            'ext': 'png',
            'attribution': '© CyclOSM, © OpenStreetMap contributors'
        },
        # Serveur japonais
        'gsi_std': {
            'url': 'https://cyberjapandata.gsi.go.jp/xyz/std/{z}/{x}/{y}.png',
            'ext': 'png',
            'attribution': 'GSI Japan'
        }
    }
    
    def __init__(self, cache_dir=None):
        self.cache_dir = cache_dir or CONFIG['cache_dir']
        os.makedirs(self.cache_dir, exist_ok=True)
        self.memory_cache = TileCache(CONFIG['max_tile_cache'])
        self.network_manager = NetworkManager()
    
    def _get_tile_path(self, server, zoom, x, y):
        if server in self.TILE_SERVERS:
            ext = self.TILE_SERVERS[server]['ext']
        else:
            # Fallback pour compatibilité avec l'ancien code
            ext = 'jpg' if server == 'seamlessphoto' else 'png'
        return os.path.join(self.cache_dir, server, str(zoom), str(x), f"{y}.{ext}")
    
    def _download_single_tile(self, server, zoom, x, y):
        """Télécharge une tuile unique"""
        cache_key = (server, zoom, x, y)
        
        # Vérifier le cache mémoire
        cached_tile = self.memory_cache.get(cache_key)
        if cached_tile is not None:
            return cached_tile
        
        # Vérifier le cache disque
        tile_path = self._get_tile_path(server, zoom, x, y)
        if os.path.exists(tile_path):
            try:
                tile = Image.open(tile_path)
                self.memory_cache.put(cache_key, tile)
                return tile
            except Exception as e:
                logger.warning(f"Erreur lecture cache disque {tile_path}: {e}")
        
        # Télécharger depuis le serveur
        if server in self.TILE_SERVERS:
            url = self.TILE_SERVERS[server]['url'].format(z=zoom, x=x, y=y)
        else:
            # Fallback pour compatibilité avec l'ancien code
            ext = 'jpg' if server == 'seamlessphoto' else 'png'
            url = f"https://cyberjapandata.gsi.go.jp/xyz/{server}/{zoom}/{x}/{y}.{ext}"
        
        try:
            logger.debug(f"Téléchargement tuile : {url}")
            response = self.network_manager.download_with_retry(url)
            tile = Image.open(BytesIO(response.content))
            
            # Sauvegarder en cache disque
            os.makedirs(os.path.dirname(tile_path), exist_ok=True)
            tile.save(tile_path)
            
            # Mettre en cache mémoire
            self.memory_cache.put(cache_key, tile)
            
            return tile
            
        except Exception as e:
            logger.error(f"Erreur téléchargement tuile {url}: {e}")
            return None
    
    def download_tiles_async(self, server, zoom, tile_coords, progress_callback=None):
        """Télécharge les tuiles de manière asynchrone"""
        results = {}
        total_tiles = len(tile_coords)
        
        with concurrent.futures.ThreadPoolExecutor(max_workers=CONFIG['max_workers']) as executor:
            # Soumettre tous les téléchargements
            futures = {
                executor.submit(self._download_single_tile, server, zoom, x, y): (x, y)
                for x, y in tile_coords
            }
            
            # Récupérer les résultats au fur et à mesure
            for i, future in enumerate(concurrent.futures.as_completed(futures)):
                x, y = futures[future]
                try:
                    tile = future.result()
                    if tile is not None:
                        results[(x, y)] = tile
                    
                    if progress_callback:
                        progress_callback(
                            (i + 1) / total_tiles * 100, 
                            f"Tuiles téléchargées: {i + 1}/{total_tiles}"
                        )
                        
                except Exception as e:
                    logger.error(f"Erreur téléchargement tuile async {zoom}/{x}/{y}: {e}")
        
        return results

class GPXEditor:
    """Classe principale de l'éditeur GPX"""
    
    def __init__(self):
        self.points = []
        self.original_points = []
        self.intervals = []
        self.selected_points = []
        self.cursor_idx = 0
        self.colors = []
        self.speeds = []
        self.cut_by_buttons = False
        
        # Managers
        self.tile_manager = TileManager()
        
        # Interface
        self.fig = None
        self.ax = None
        self.cursor_circle = None
        self.point_markers = []
        self.time_annotation = None
        
        # CORRECTION CRITIQUE: Références pour les boutons
        self.btn_debut = None
        self.btn_fin = None
        
        # Configuration
        self.server = 'seamlessphoto'
        self.zoom = CONFIG['default_zoom']
    
    def load_gpx_file(self, filepath=None):
        """Charge un fichier GPX avec interface de progression"""
        if not filepath:
            filepath = self._select_file_dialog()
        
        if not filepath:
            logger.info("Aucun fichier sélectionné")
            return False
        
        try:
            with open(filepath, 'r', encoding='utf-8') as f:
                content = f.read()
            
            # Dialog de progression
            progress = ProgressDialog("Chargement GPX", 100)
            
            try:
                progress.update(10, "Parsing GPX...")
                points = parse_gpx_optimized(content)
                
                if not points:
                    raise ValueError("Aucun point trouvé dans le fichier GPX")
                
                # Validation
                progress.update(20, "Validation des données...")
                errors = validate_gpx_points(points)
                if errors:
                    logger.warning(f"Erreurs de validation: {errors}")
                
                # Suppression des doublons
                progress.update(40, "Suppression des doublons...")
                points = remove_duplicates_optimized(points)
                
                # Interpolation
                def interp_callback(value, text):
                    progress.update(40 + value * 0.4, text)
                
                progress.update(40, "Interpolation des points...")
                points = interpolate_points_optimized(points, interp_callback)
                
                # Calculs
                progress.update(85, "Calcul des vitesses...")
                self._calculate_derived_data(points)
                
                # Finalisation
                progress.update(95, "Finalisation...")
                self.points = points
                self.original_points = points.copy()
                self._shift_times_to_zero()
                
                progress.update(100, "Terminé!")
                progress.close()
                
                logger.info(f"GPX chargé avec succès: {len(self.points)} points")
                return True
                
            except Exception as e:
                progress.close()
                raise e
                
        except Exception as e:
            error_msg = f"Erreur lors du chargement du fichier GPX: {e}"
            logger.error(error_msg)
            tkinter.messagebox.showerror("Erreur", error_msg)
            return False
    
    def _select_file_dialog(self):
        """Interface de sélection de fichier"""
        root = tk.Tk()
        root.withdraw()  # Cacher la fenêtre principale
        root.attributes('-topmost', True)  # Toujours au premier plan
        
        try:
            filepath = filedialog.askopenfilename(
                title="Sélectionnez un fichier GPX",
                filetypes=[("Fichiers GPX", "*.gpx"), ("Tous fichiers", "*.*")],
                parent=root
            )
            return filepath
        finally:
            root.quit()
            root.destroy()
    
    def _calculate_derived_data(self, points):
        """Calcule les données dérivées (vitesses, points d'arrêt)"""
        # Calcul des vitesses
        self.speeds = calculate_speeds_vectorized(points)
        
        # Détection des points d'arrêt
        points_arret = set()
        for i in range(1, len(self.speeds) - 1):
            if (self.speeds[i-1] < CONFIG['speed_threshold'] and 
                self.speeds[i] < CONFIG['speed_threshold']):
                points_arret.add(i)
        
        # Attribution des couleurs
        self.colors = []
        for i in range(len(points)):
            if i in points_arret:
                self.colors.append('red')
            else:
                self.colors.append('blue')
    
    def _shift_times_to_zero(self):
        """Décale les temps pour commencer à zéro mais garde le format datetime"""
        if not self.points:
            return
        
        t0 = self.points[0]['time']
        base_date = t0.date()  # Extraire la date
        
        for i, p in enumerate(self.points):
            # Calculer le décalage en secondes depuis le début
            delta_seconds = (p['time'] - t0).total_seconds()
            # Créer un nouveau datetime avec la même date mais temps décalé
            new_time = datetime.combine(base_date, time(0, 0, 0)) + timedelta(seconds=delta_seconds)
            p['time'] = new_time
    
    def setup_interface(self):
        """Configure l'interface matplotlib"""
        self.fig, self.ax = plt.subplots(figsize=(12, 9))
        plt.subplots_adjust(top=0.87, bottom=0.15)
        
        # Configuration de la fenêtre
        self.fig.canvas.toolbar_visible = False
        self.fig.canvas.manager.set_window_title("Éditeur GPX Optimisé")
        
        # Instructions
        instructions = (
            "←/→ : Navigation | ↑/↓ : Zoom | Cliquer avec la souris pour les longs déplacements | Entrée : Marquer début/fin de découpe | z : Annuler | a : Tout annuler | o : Terminer | Boutons [Début] et [Fin] : Couper avant/après le point sélectionné"
        )
        
        self.fig.text(0.5, 0.95, instructions, ha='center', va='top', 
                     fontsize=9, wrap=True, bbox=dict(boxstyle="round,pad=0.5", 
                     facecolor="lightblue", alpha=0.8))
        
        # Curseur
        lats = [pt['lat'] for pt in self.points]
        lons = [pt['lon'] for pt in self.points]
        
        self.cursor_circle = patches.Circle(
            (lons[0], lats[0]), radius=0, color='black', alpha=0.7, zorder=50
        )
        self.ax.add_patch(self.cursor_circle)
        
        # CORRECTION DES BOUTONS - Maintenir les références
        ax_btn_debut = plt.axes([0.35, 0.02, 0.15, 0.06])
        self.btn_debut = Button(ax_btn_debut, "Début")
        self.btn_debut.on_clicked(self._cut_beginning)
        
        ax_btn_fin = plt.axes([0.52, 0.02, 0.15, 0.06])
        self.btn_fin = Button(ax_btn_fin, "Fin")
        self.btn_fin.on_clicked(self._cut_end)
        
        # Événements
        self.fig.canvas.mpl_connect('key_press_event', self._on_key)
        self.fig.canvas.mpl_connect('button_press_event', self._on_click)
    
    def draw_map(self):
        """Dessine la carte avec tuiles et trace"""
        if not self.points:
            return
        
        lats = [pt['lat'] for pt in self.points]
        lons = [pt['lon'] for pt in self.points]
        
        # Dialog de progression pour le téléchargement des tuiles
        progress = ProgressDialog("Chargement de la carte", 100)
        
        try:
            self.ax.clear()
            
            # Calcul des limites avec marge
            marge_lon = max((max(lons) - min(lons)) * 0.05, 0.005)
            marge_lat = max((max(lats) - min(lats)) * 0.05, 0.005)
            
            self.ax.set_xlim(min(lons) - marge_lon, max(lons) + marge_lon)
            self.ax.set_ylim(min(lats) - marge_lat, max(lats) + marge_lat)
            self.ax.set_aspect('equal', adjustable='datalim')
            self.ax.set_xticklabels([])
            self.ax.set_yticklabels([])
            
            # Calcul des tuiles nécessaires
            progress.update(10, "Calcul des tuiles...")
            lat_min, lat_max = min(lats), max(lats)
            lon_min, lon_max = min(lons), max(lons)
            
            x_min, y_max = deg2num(lat_min, lon_min, self.zoom)
            x_max, y_min = deg2num(lat_max, lon_max, self.zoom)
            
            tile_coords = [(x, y) for x in range(x_min, x_max + 1) 
                          for y in range(y_min, y_max + 1)]
            
            # Téléchargement asynchrone des tuiles
            def tile_progress(value, text):
                progress.update(10 + value * 0.6, text)
            
            tiles = self.tile_manager.download_tiles_async(
                self.server, self.zoom, tile_coords, tile_progress
            )
            
            # Affichage des tuiles
            progress.update(70, "Affichage des tuiles...")
            for (x, y), tile in tiles.items():
                if tile is not None:
                    lat_top, lon_left = num2deg(x, y, self.zoom)
                    lat_bottom, lon_right = num2deg(x + 1, y + 1, self.zoom)
                    self.ax.imshow(tile, extent=[lon_left, lon_right, lat_bottom, lat_top], 
                                 aspect='auto', zorder=0)
            
            # Tracé de la route
            progress.update(80, "Tracé de la route...")
            self.ax.plot(lons, lats, marker='o', linestyle='-', color='blue', 
                        alpha=0.6, linewidth=1, markersize=2, zorder=10)
            
            # Points colorés selon vitesse
            sizes = [150 if c == 'red' else 20 for c in self.colors]
            self.ax.scatter(lons, lats, c=self.colors, s=sizes, alpha=0.7, zorder=15)
            
            # Marqueurs début/fin
            self.ax.scatter(lons[0], lats[0], c='yellow', s=200, marker='*', 
                          edgecolors='black', linewidths=2, zorder=30, label="Départ")
            self.ax.scatter(lons[-1], lats[-1], c='red', s=200, marker='X', 
                          edgecolors='black', linewidths=2, zorder=30, label="Arrivée")
            
            # Points de repère temporel (toutes les 5 minutes)
            if len(self.points) > 1:
                progress.update(90, "Marqueurs temporels...")
                self._add_time_markers()
            
            # Curseur
            self.ax.add_patch(self.cursor_circle)
            self._update_cursor()
            
            # Légende
            self.ax.legend(loc='upper right')
            plt.grid(True, alpha=0.3)
            
            # Attribution des tuiles (en bas à gauche)
            if hasattr(self.tile_manager, 'TILE_SERVERS') and self.server in self.tile_manager.TILE_SERVERS:
                attribution = self.tile_manager.TILE_SERVERS[self.server]['attribution']
                self.fig.text(0.02, 0.02, attribution, fontsize=8, alpha=0.7, 
                             bbox=dict(boxstyle="round,pad=0.3", facecolor="white", alpha=0.8))
            
            progress.update(100, "Terminé!")
            self.fig.canvas.draw_idle()
            
        finally:
            progress.close()
            
            
    def _add_time_markers(self):
        """Ajoute des marqueurs temporels toutes les 5 minutes"""
        lats = [pt['lat'] for pt in self.points]
        lons = [pt['lon'] for pt in self.points]
        
        total_sec = int((self.points[-1]['time'] - self.points[0]['time']).total_seconds())
        interval_sec = 300  # 5 minutes
        marks = list(range(interval_sec, total_sec + interval_sec, interval_sec))
        points_sec = [(p['time'] - self.points[0]['time']).total_seconds() for p in self.points]

        indices_verts = []
        for m in marks:
            diffs = [abs(ps - m) for ps in points_sec]
            idx_min = diffs.index(min(diffs))
            indices_verts.append(idx_min)

        if indices_verts:
            self.ax.scatter([lons[i] for i in indices_verts],
                           [lats[i] for i in indices_verts],
                           c='green', s=80, marker='o', zorder=20,
                           label="Repères 5min", alpha=0.8)
    
    def _update_cursor(self, center=False):
        """Met à jour la position du curseur"""
        if not self.points or self.cursor_idx >= len(self.points):
            return
            
        lats = [pt['lat'] for pt in self.points]
        lons = [pt['lon'] for pt in self.points]
        
        self.cursor_circle.center = (lons[self.cursor_idx], lats[self.cursor_idx])
        
        if center:
            w = self.ax.get_xlim()[1] - self.ax.get_xlim()[0]
            h = self.ax.get_ylim()[1] - self.ax.get_ylim()[0]
            self.ax.set_xlim(lons[self.cursor_idx] - w / 2, lons[self.cursor_idx] + w / 2)
            self.ax.set_ylim(lats[self.cursor_idx] - h / 2, lats[self.cursor_idx] + h / 2)

        # Ajuster la taille du curseur selon le zoom
        pix_to_data_x = (self.ax.get_xlim()[1] - self.ax.get_xlim()[0]) / self.ax.bbox.width
        pix_to_data_y = (self.ax.get_ylim()[1] - self.ax.get_ylim()[0]) / self.ax.bbox.height
        r = ((pix_to_data_x + pix_to_data_y) / 2) * 10
        self.cursor_circle.radius = r
        
        self.fig.canvas.draw_idle()
    
    def _zoom(self, factor):
        """Zoom centré sur le curseur"""
        if not self.points:
            return
            
        lons = [pt['lon'] for pt in self.points]
        lats = [pt['lat'] for pt in self.points]
        
        w = self.ax.get_xlim()[1] - self.ax.get_xlim()[0]
        h = self.ax.get_ylim()[1] - self.ax.get_ylim()[0]
        w_new = w * factor
        h_new = h * factor
        
        self.ax.set_xlim(lons[self.cursor_idx] - w_new / 2, lons[self.cursor_idx] + w_new / 2)
        self.ax.set_ylim(lats[self.cursor_idx] - h_new / 2, lats[self.cursor_idx] + h_new / 2)
        self.ax.set_aspect('equal', adjustable='datalim')
        
        self._update_cursor()
    
    def _find_closest_point(self, x, y):
        """Trouve le point le plus proche des coordonnées données"""
        lons = [pt['lon'] for pt in self.points]
        lats = [pt['lat'] for pt in self.points]
        
        dists = [(lons[i] - x) ** 2 + (lats[i] - y) ** 2 for i in range(len(self.points))]
        return int(np.argmin(dists))
    
    def _show_time_annotation(self, idx):
        """Affiche l'annotation temporelle"""
        if self.time_annotation:
            self.time_annotation.remove()
        
        lons = [pt['lon'] for pt in self.points]
        lats = [pt['lat'] for pt in self.points]
        
        t = self.points[idx]['time']
        
        if isinstance(t, datetime):
            # Calculer les secondes depuis minuit
            total_seconds = int((t - datetime.combine(t.date(), time(0, 0, 0))).total_seconds())
        else:  # timedelta
            total_seconds = int(t.total_seconds())
        
        mm, ss = divmod(total_seconds, 60)
        hh, mm = divmod(mm, 60)
        
        if hh > 0:
            time_str = f"{hh:02d}:{mm:02d}:{ss:02d}"
        else:
            time_str = f"{mm:02d}:{ss:02d}"
        
        self.time_annotation = self.ax.annotate(
            time_str, 
            xy=(lons[idx], lats[idx]), 
            xytext=(15, 15),
            textcoords='offset points', 
            bbox=dict(boxstyle="round,pad=0.3", fc="yellow", alpha=0.8),
            arrowprops=dict(arrowstyle="->", color="black"),
            fontsize=10, fontweight='bold'
        )

    
    def _clear_point_markers(self):
        """Efface tous les marqueurs de points"""
        for marker in self.point_markers:
            marker.remove()
        self.point_markers.clear()
        self.fig.canvas.draw_idle()
    
    # --- Gestionnaires d'événements ---
    
    def _on_click(self, event):
        """Gestionnaire de clic souris"""
        if event.inaxes != self.ax or event.xdata is None or event.ydata is None:
            return
        
        idx = self._find_closest_point(event.xdata, event.ydata)
        self.cursor_idx = idx
        self._update_cursor(center=True)
        self._show_time_annotation(idx)
        
        logger.debug(f"Curseur déplacé à l'indice {idx}")
    
    def _on_key(self, event):
        """Gestionnaire d'événements clavier"""
        if not self.points:
            return
            
        if event.key == 'right':
            if self.cursor_idx < len(self.points) - 1:
                self.cursor_idx += 1
                self._update_cursor(center=True)
                self._show_time_annotation(self.cursor_idx)
                
        elif event.key == 'left':
            if self.cursor_idx > 0:
                self.cursor_idx -= 1
                self._update_cursor(center=True)
                self._show_time_annotation(self.cursor_idx)
                
        elif event.key == 'up':
            self._zoom(0.5)  # Zoom in
            
        elif event.key == 'down':
            self._zoom(2.0)  # Zoom out
            
        elif event.key == 'enter':
            self._handle_selection()
            
        elif event.key == 'z':
            self._undo_selection()
            
        elif event.key == 'a':
            self._clear_all_selections()
            
        elif event.key == 'o':
            self._finish_editing()
            
        elif event.key == ' ':
            self._show_time_annotation(self.cursor_idx)
    
    def _handle_selection(self):
        """Gère la sélection de zones à supprimer"""
        lons = [pt['lon'] for pt in self.points]
        lats = [pt['lat'] for pt in self.points]
        
        if len(self.selected_points) % 2 == 0:
            # Début de sélection
            self.selected_points.append(self.cursor_idx)
            marker = self.ax.plot(lons[self.cursor_idx], lats[self.cursor_idx], 
                                'gs', markersize=15, zorder=25)
            self.point_markers.extend(marker)
            logger.info(f"Début de section sélectionnée : index {self.cursor_idx}")
        else:
            # Fin de sélection
            start = self.selected_points[-1]
            if self.cursor_idx <= start:
                logger.error(f"Fin doit être > début ({start})")
                return
            
            self.selected_points.append(self.cursor_idx)
            self.intervals.append((start, self.cursor_idx))
            
            marker = self.ax.plot(lons[self.cursor_idx], lats[self.cursor_idx], 
                                'gs', markersize=15, zorder=25)
            self.point_markers.extend(marker)
            
            logger.info(f"Fin de section sélectionnée : index {self.cursor_idx}")
        
        self.fig.canvas.draw_idle()
    
    def _undo_selection(self):
        """Annule la dernière sélection"""
        if self.selected_points:
            if len(self.selected_points) % 2 == 0 and self.intervals:
                # Annuler la dernière paire complète
                self.intervals.pop()
                logger.info("Dernière section annulée.")
                # Supprimer 2 marqueurs
                for _ in range(2):
                    if self.point_markers:
                        self.point_markers.pop().remove()
            else:
                # Annuler le dernier point de début
                logger.info("Dernier point de sélection annulé.")
                if self.point_markers:
                    self.point_markers.pop().remove()
            
            self.selected_points.pop()
            self.fig.canvas.draw_idle()
        else:
            logger.info("Aucune sélection à annuler.")
    
    def _clear_all_selections(self):
        """Annule toutes les sélections"""
        self.intervals.clear()
        self.selected_points.clear()
        self._clear_point_markers()
        logger.info("Toutes les sélections annulées.")
    
    def _finish_editing(self):
        """Termine l'édition"""
        if len(self.selected_points) % 2 != 0:
            logger.warning("Veuillez terminer la dernière zone de sélection avant de conclure.")
            return
        
        logger.info("Édition terminée, génération du fichier...")
        plt.close()
    
    def _cut_beginning(self, event):
        """Coupe le début de la trace"""
        print(f"DEBUG: _cut_beginning appelé avec cursor_idx = {self.cursor_idx}")
        
        if self.cursor_idx >= len(self.points):
            logger.warning("Indice hors bornes pour coupe début.")
            print("ERREUR: Indice hors bornes pour coupe début.")
            return
        
        if self.cursor_idx == 0:
            logger.warning("Impossible de couper au début, déjà au premier point.")
            print("ATTENTION: Impossible de couper au début, déjà au premier point.")
            return
        
        try:
            points_supprimés = self.cursor_idx
            self.points = self.points[self.cursor_idx:]
            self._shift_times_to_zero()
            self._calculate_derived_data(self.points)
            
            logger.info(f"Trace tronquée au début. {points_supprimés} points supprimés. Nouveau total: {len(self.points)}")
            print(f"SUCCÈS: {points_supprimés} points supprimés au début. Nouveau total: {len(self.points)}")
            
            self.cursor_idx = 0
            self.intervals.clear()
            self.selected_points.clear()
            self.cut_by_buttons = True
            
            self._clear_point_markers()
            
            # CORRECTION: Redessiner la carte ET recentrer la vue
            self.draw_map()
            # Forcer le recentrage de la vue sur le nouveau point de départ
            self._update_cursor(center=True)
            
        except Exception as e:
            logger.error(f"Erreur lors de la coupe début: {e}")
            print(f"ERREUR: Erreur lors de la coupe début: {e}")


    def _cut_end(self, event):
        """Coupe la fin de la trace"""
        print(f"DEBUG: _cut_end appelé avec cursor_idx = {self.cursor_idx}")
        
        if self.cursor_idx >= len(self.points):
            logger.warning("Indice hors bornes pour coupe fin.")
            print("ERREUR: Indice hors bornes pour coupe fin.")
            return
            
        if self.cursor_idx == len(self.points) - 1:
            logger.warning("Impossible de couper à la fin, déjà au dernier point.")
            print("ATTENTION: Impossible de couper à la fin, déjà au dernier point.")
            return
        
        try:
            points_supprimés = len(self.points) - self.cursor_idx - 1
            self.points = self.points[:self.cursor_idx + 1]
            self._shift_times_to_zero()
            self._calculate_derived_data(self.points)
            
            logger.info(f"Trace tronquée à la fin. {points_supprimés} points supprimés. Nouveau total: {len(self.points)}")
            print(f"SUCCÈS: {points_supprimés} points supprimés à la fin. Nouveau total: {len(self.points)}")
            
            # S'assurer que le curseur reste valide
            if self.cursor_idx >= len(self.points):
                self.cursor_idx = len(self.points) - 1
                
            self.intervals.clear()
            self.selected_points.clear()
            self.cut_by_buttons = True
            
            self._clear_point_markers()
            
            # CORRECTION: Redessiner la carte ET recentrer la vue
            self.draw_map()
            # Forcer le recentrage de la vue sur le point courant
            self._update_cursor(center=True)
            
        except Exception as e:
            logger.error(f"Erreur lors de la coupe fin: {e}")
            print(f"ERREUR: Erreur lors de la coupe fin: {e}")
    
    def generate_output(self):
        """Génère le fichier GPX de sortie"""
        if len(self.intervals) == 0:
            if self.cut_by_buttons:
                # Pas de sélection manuelle mais coupe par bouton
                kept_intervals = [(0, len(self.points) - 1)]
            else:
                logger.info("Aucune section sélectionnée, sortie sans modification.")
                return False
        else:
            # Ajuster les intervalles (enlever 1 point au début et à la fin)
            trimmed_intervals = []
            for start, end in self.intervals:
                new_start = start + 1 if start + 1 < end else end
                new_end = end - 1 if end - 1 > start else start
                if new_start <= new_end:
                    trimmed_intervals.append((new_start, new_end))
            
            # Inverser les intervalles pour obtenir les parties à garder
            kept_intervals = self._invert_intervals(trimmed_intervals, len(self.points))
        
        # Construire les points filtrés
        filtered_points = self._build_filtered_points(kept_intervals)
        
        if not filtered_points:
            logger.warning("Aucun point à sauvegarder après filtrage.")
            return False
        
        # Générer le GPX
        gpx_content = self._generate_gpx_content(filtered_points)
        
        # Sauvegarder
        return self._save_gpx_dialog(gpx_content)
    
    def _invert_intervals(self, intervals, total_length):
        """Inverse les intervalles pour obtenir les parties à conserver"""
        if not intervals:
            return [(0, total_length - 1)]
        
        intervals = sorted(intervals)
        result = []
        prev_end = -1
        
        for start, end in intervals:
            if start > prev_end + 1:
                result.append((prev_end + 1, start - 1))
            prev_end = end
        
        if prev_end < total_length - 1:
            result.append((prev_end + 1, total_length - 1))
        
        return result
    
    def _build_filtered_points(self, kept_intervals):
        """Construit la liste des points filtrés avec ajustement temporel"""
        filtered_points = []
        time_removed_seconds = 0  # Garder en secondes
        prev_end = -1

        for start, end in kept_intervals:
            if prev_end >= 0:
                gap_seconds = (self.points[start]['time'] - self.points[prev_end + 1]['time']).total_seconds()
                time_removed_seconds += gap_seconds
            
            for idx in range(start, end + 1):
                pt = self.points[idx].copy()
                # Ajuster en gardant le format datetime
                new_time = pt['time'] - timedelta(seconds=time_removed_seconds)
                pt['time'] = new_time
                filtered_points.append(pt)
            
            prev_end = end

        # Remettre à zéro le temps de départ en gardant la date
        if filtered_points:
            t0 = filtered_points[0]['time']
            base_date = t0.date()
            
            for pt in filtered_points:
                delta_seconds = (pt['time'] - t0).total_seconds()
                pt['time'] = datetime.combine(base_date, time(0, 0, 0)) + timedelta(seconds=delta_seconds)

        return filtered_points
    
    def _generate_gpx_content(self, points):
        """Génère le contenu GPX"""
        gpx_header = '''<?xml version="1.0" encoding="UTF-8"?>
<gpx version="1.1" creator="Éditeur GPX Optimisé"
  xmlns="http://www.topografix.com/GPX/1/1"
  xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
  xsi:schemaLocation="http://www.topografix.com/GPX/1/1
                      http://www.topografix.com/GPX/1/1/gpx.xsd">
<trk>
  <name>Trace corrigée</name>
  <trkseg>'''

        gpx_footer = '''  </trkseg>
</trk>
</gpx>'''

        output_lines = []
        for pt in points:
            time_str = self._format_time(pt['time'])
            output_lines.append(f'    <trkpt lat="{pt["lat"]:.6f}" lon="{pt["lon"]:.6f}">')
            output_lines.append(f'      <ele>{pt["ele"]:.2f}</ele>')
            output_lines.append(f'      <time>{time_str}</time>')
            output_lines.append(f'    </trkpt>')

        return "\n".join([gpx_header] + output_lines + [gpx_footer])
    
    def _format_time(self, t):
        """Formate le temps pour le GPX"""
        if isinstance(t, datetime):
            return t.strftime("%Y-%m-%dT%H:%M:%SZ")
        elif isinstance(t, timedelta):
            # Si on a encore des timedelta, les convertir avec la date du jour
            base_datetime = datetime.combine(datetime.now().date(), time(0, 0, 0))
            return (base_datetime + t).strftime("%Y-%m-%dT%H:%M:%SZ")
        else:
            raise TypeError(f"Type inattendu pour le temps : {type(t)}")
    
    def _save_gpx_dialog(self, gpx_content):
        """Dialog de sauvegarde du fichier GPX"""
        root = tk.Tk()
        root.withdraw()  # Cacher la fenêtre principale
        root.attributes('-topmost', True)  # Toujours au premier plan
        
        try:
            file_path = filedialog.asksaveasfilename(
                title="Enregistrer le fichier GPX corrigé",
                defaultextension=".gpx",
                filetypes=[("Fichiers GPX", "*.gpx"), ("Tous fichiers", "*.*")],
                parent=root
            )
            
            if file_path:
                try:
                    with open(file_path, "w", encoding="utf-8") as f:
                        f.write(gpx_content)
                    
                    success_msg = f"Le fichier GPX a été sauvegardé :\n{os.path.abspath(file_path)}"
                    tkinter.messagebox.showinfo("Sauvegarde réussie", success_msg, parent=root)
                    logger.info(f"Fichier sauvegardé: {file_path}")
                    return True
                    
                except Exception as e:
                    error_msg = f"Erreur lors de la sauvegarde:\n{e}"
                    tkinter.messagebox.showerror("Erreur de sauvegarde", error_msg, parent=root)
                    logger.error(f"Erreur sauvegarde: {e}")
                    return False
            else:
                tkinter.messagebox.showwarning("Sauvegarde annulée", "Aucun fichier sauvegardé.", parent=root)
                return False
        finally:
            root.quit()
            root.destroy()

# --- Interface de configuration ---

class ConfigDialog:
    """Gestionnaire centralisé des dialogs de configuration"""
    
    @staticmethod
    def ask_server():
        """Interface de sélection du serveur de tuiles"""
        root = tk.Tk()
        root.withdraw()
        root.attributes('-topmost', True)
        try:
            choices_text = (
                "Serveurs de tuiles disponibles :\n\n"
                "1 - cyclosm (cartes vélo mondiales)\n"
                "2 - gsi_std (cartes standard Japon)\n\n"
                "Tapez le numéro (défaut=1) :"
            )
            
            server_map = {
                '1': 'cyclosm',
                '2': 'gsi_std',
                '': 'cyclosm'  # défaut
            }
            
            while True:
                choix = simpledialog.askstring(
                    "Serveur tuiles",
                    choices_text,
                    parent=root
                )
                
                if choix is None:
                    return None
                    
                if choix in server_map:
                    return server_map[choix]
                else:
                    tkinter.messagebox.showwarning(
                        "Entrée incorrecte", 
                        "Merci de taper 1 ou 2.",
                        parent=root
                    )
        finally:
            root.quit()
            root.destroy()
    
    @staticmethod
    def ask_zoom():
        """Interface de sélection du niveau de zoom"""
        root = tk.Tk()
        root.withdraw()  # Cacher la fenêtre principale
        root.attributes('-topmost', True)  # Toujours au premier plan
        
        try:
            zoom_str = simpledialog.askstring(
                "Niveau de zoom", 
                "Niveau de zoom (12-18 recommandé, défaut=16) :",
                parent=root
            )
            
            if zoom_str:
                try:
                    z = int(zoom_str)
                    if 0 <= z <= 20:
                        return z
                except (ValueError, TypeError):
                    pass
            
            return CONFIG['default_zoom']
        finally:
            root.quit()
            root.destroy()

    
    @staticmethod
    def ask_zoom():
        """Interface de sélection du niveau de zoom"""
        root = tk.Tk()
        root.withdraw()  # Cacher la fenêtre principale
        root.attributes('-topmost', True)  # Toujours au premier plan
        
        try:
            zoom_str = simpledialog.askstring(
                "Niveau de zoom", 
                "Niveau de zoom (12-18 recommandé, défaut=16) :",
                parent=root
            )
            
            if zoom_str:
                try:
                    z = int(zoom_str)
                    if 0 <= z <= 20:
                        return z
                except (ValueError, TypeError):
                    pass
            
            return CONFIG['default_zoom']
        finally:
            root.quit()
            root.destroy()

# --- Test simple des boutons ---

def test_buttons():
    """Test simple pour vérifier que les boutons fonctionnent"""
    import matplotlib.pyplot as plt
    from matplotlib.widgets import Button
    
    fig, ax = plt.subplots()
    plt.subplots_adjust(bottom=0.2)
    
    def test_callback(event):
        print("Bouton test cliqué!")
        ax.clear()
        ax.text(0.5, 0.5, "Bouton cliqué!", transform=ax.transAxes, 
                ha='center', va='center', fontsize=14, color='red')
        fig.canvas.draw()
    
    # Créer le bouton avec référence maintenue
    ax_btn = plt.axes([0.4, 0.05, 0.2, 0.075])
    btn = Button(ax_btn, 'Test')
    btn.on_clicked(test_callback)
    
    # CRITIQUE: Maintenir la référence
    fig.btn_test = btn
    
    print("Cliquez sur le bouton 'Test' pour vérifier le fonctionnement")
    plt.show()

# --- Fonction principale ---

def main():
    """Fonction principale de l'application"""
    try:
        # Créer l'éditeur
        editor = GPXEditor()
        
        # Configuration depuis les arguments CLI ou interface
        editor.server = args.server if args.server else ConfigDialog.ask_server()
        editor.zoom = args.zoom if (args.zoom is not None and 0 <= args.zoom <= 20) else ConfigDialog.ask_zoom()
        
        logger.info(f"Configuration: serveur={editor.server}, zoom={editor.zoom}")
        
        # Charger le fichier GPX
        if not editor.load_gpx_file(args.file):
            logger.error("Impossible de charger le fichier GPX")
            sys.exit(1)
        
        # Configurer l'interface
        editor.setup_interface()
        
        # Dessiner la carte
        editor.draw_map()
        
        # Afficher l'interface
        logger.info("Interface prête. Utilisez les raccourcis clavier pour naviguer.")
        plt.show()
        
        # Générer le fichier de sortie après fermeture
        if editor.generate_output():
            logger.info("Fichier GPX généré avec succès")
        else:
            logger.info("Aucun fichier généré")
            
    except KeyboardInterrupt:
        logger.info("Arrêt demandé par l'utilisateur")
        sys.exit(0)
    except Exception as e:
        logger.error(f"Erreur fatale: {e}")
        tkinter.messagebox.showerror("Erreur fatale", f"Une erreur inattendue s'est produite:\n{e}")
        sys.exit(1)

if __name__ == "__main__":
    # Pour tester uniquement les boutons, décommentez la ligne suivante :
    # test_buttons()
    main()