# Ultralytics 🚀 AGPL-3.0 License - https://ultralytics.com/license

import os
import platform
import re
import socket
import sys
from concurrent.futures import ThreadPoolExecutor
from math import isfinite
from pathlib import Path
from time import sleep, time

from ultralytics.utils import ENVIRONMENT, GIT, LOGGER, PYTHON_VERSION, RANK, SETTINGS, TESTS_RUNNING, Retry, colorstr

# Prefix (formatted via colorstr) used for this module's Platform log messages
PREFIX = colorstr("Platform: ")

# Configurable platform URL for debugging (e.g. ULTRALYTICS_PLATFORM_URL=http://localhost:3000)
# Trailing slashes are stripped so that paths can be appended with a single "/"
PLATFORM_URL = os.getenv("ULTRALYTICS_PLATFORM_URL", "https://platform.ultralytics.com").rstrip("/")
# Base URL of the webhook endpoints used by the functions in this module
PLATFORM_API_URL = f"{PLATFORM_URL}/api/webhooks"


def slugify(text):
    """Convert text to URL-safe slug (e.g., 'My Project 1' -> 'my-project-1').

    The text is lowercased, characters other than ``a-z``, ``0-9``, whitespace and hyphens are dropped, every run of
    whitespace and/or hyphens collapses to a single hyphen, and the result is capped at 128 characters with no leading
    or trailing hyphen.

    Args:
        text (Any): Value to slugify; non-string values are converted with ``str()``.

    Returns:
        (Any): The slug string, or ``text`` itself unchanged when it is falsy (e.g. ``None`` or ``""``).
    """
    if not text:
        return text
    cleaned = re.sub(r"[^a-z0-9\s-]", "", str(text).lower())
    # Collapse all whitespace (tabs/newlines too, not only spaces) together with hyphens into one separator
    slug = re.sub(r"[\s-]+", "-", cleaned).strip("-")
    # Truncation can expose a hyphen at the cut point, so trim the tail once more
    return slug[:128].rstrip("-")


try:
    # The asserts act as feature gates: any failure is caught below and leaves _api_key set to None
    assert not TESTS_RUNNING  # do not log pytest
    assert SETTINGS.get("platform", False) is True or os.getenv("ULTRALYTICS_API_KEY") or SETTINGS.get("api_key")
    # The environment variable takes precedence over the stored setting
    _api_key = os.getenv("ULTRALYTICS_API_KEY") or SETTINGS.get("api_key")
    assert _api_key  # verify API key is present

    import requests

    from ultralytics.utils.logger import ConsoleLogger, SystemLogger
    from ultralytics.utils.torch_utils import model_info_for_loggers

    _executor = ThreadPoolExecutor(max_workers=10)  # Bounded thread pool for async operations

except (AssertionError, ImportError):
    # Gate failed or a dependency is missing; names bound later in the try body (e.g. _executor) may be undefined
    _api_key = None


def resolve_platform_uri(uri, hard=True):
    """Resolve ul:// URIs to signed URLs by authenticating with Ultralytics Platform.

    Formats:
        ul://username/datasets/slug  -> Returns signed URL to NDJSON file
        ul://username/project/model  -> Returns signed URL to .pt file

    Args:
        uri (str): Platform URI starting with "ul://".
        hard (bool): Whether to raise an error if resolution fails.

    Returns:
        (str | None): Signed URL on success, None if not found and hard=False.

    Raises:
        ValueError: If API key is missing/invalid or URI format is wrong (including empty path segments).
        PermissionError: If access is denied.
        RuntimeError: If resource is not ready (e.g., dataset still processing) or the response is unexpected.
        FileNotFoundError: If resource not found and hard=True.
        ConnectionError: If network request fails and hard=True.
        requests.HTTPError: If the platform answers with any other 4xx/5xx status.
    """
    import requests

    path = uri[5:]  # Remove "ul://"
    parts = path.split("/")

    api_key = os.getenv("ULTRALYTICS_API_KEY") or SETTINGS.get("api_key")
    if not api_key:
        raise ValueError(f"ULTRALYTICS_API_KEY required for '{uri}'. Get key at {PLATFORM_URL}/settings")

    base = PLATFORM_API_URL
    headers = {"Authorization": f"Bearer {api_key}"}

    # Exactly three non-empty segments are required; an empty one (e.g. "ul://user//model") would otherwise
    # produce a malformed endpoint URL instead of a clear format error
    well_formed = len(parts) == 3 and all(parts)

    # ul://username/datasets/slug
    if well_formed and parts[1] == "datasets":
        username, _, slug = parts
        url = f"{base}/datasets/{username}/{slug}/export"

    # ul://username/project/model
    elif well_formed:
        username, project, model = parts
        url = f"{base}/models/{username}/{project}/{model}/download"

    else:
        raise ValueError(f"Invalid platform URI: {uri}. Use ul://user/datasets/name or ul://user/project/model")

    # (connect_timeout, read_timeout) — short connect so retries are fast, long read for server-side generation
    timeout = (10, 3600) if "/datasets/" in url else (10, 90)

    try:
        for attempt in range(3):
            try:
                # Redirects are not followed: the signed URL is read from the Location header below
                r = requests.head(url, headers=headers, allow_redirects=False, timeout=timeout)
                break
            except requests.exceptions.ConnectionError as e:
                LOGGER.warning(f"Retry {attempt + 1}/3 failed for {uri}: {e}")
                if attempt >= 2:
                    raise
                sleep(2 * (2**attempt))  # 2s, 4s backoff
    except Exception as e:
        if hard:
            raise ConnectionError(f"Failed to resolve {uri}: {e}") from e
        LOGGER.warning(f"Failed to resolve {uri}: {e}")
        return None

    # Handle redirect responses (301, 302, 303, 307, 308)
    if 300 <= r.status_code < 400 and "location" in r.headers:
        return r.headers["location"]  # Return signed URL

    # Handle error responses
    if r.status_code == 401:
        raise ValueError(f"Invalid ULTRALYTICS_API_KEY for '{uri}'")
    if r.status_code == 403:
        raise PermissionError(f"Access denied for '{uri}'. Check dataset/model visibility settings.")
    if r.status_code == 404:
        if hard:
            raise FileNotFoundError(f"Not found on platform: {uri}")
        LOGGER.warning(f"Not found on platform: {uri}")
        return None
    if r.status_code == 409:
        raise RuntimeError(f"Resource not ready: {uri}. Dataset may still be processing.")

    # Unexpected response: remaining 4xx/5xx raise HTTPError here, anything else falls through to RuntimeError
    r.raise_for_status()
    raise RuntimeError(f"Unexpected response from platform for '{uri}': {r.status_code}")


def _interp_plot(plot, n=101):
    """Interpolate plot curve data to n points to reduce storage size."""
    import numpy as np

    if not plot.get("x") or not plot.get("y"):
        return plot  # No interpolation needed (e.g., confusion_matrix)

    x, y = np.array(plot["x"]), np.array(plot["y"])
    if len(x) <= n:
        return plot  # Already small enough

    # New x values (101 points gives clean 0.01 increments: 0, 0.01, 0.02, ..., 1.0)
    x_new = np.linspace(x[0], x[-1], n)

    # Interpolate y values (handle both 1D and 2D arrays)
    if y.ndim == 1:
        y_new = np.interp(x_new, x, y)
    else:
        y_new = np.array([np.interp(x_new, x, yi) for yi in y])

    # Also interpolate ap if present (for PR curves)
    result = {**plot, "x": x_new.tolist(), "y": y_new.tolist()}
    if "ap" in plot:
        result["ap"] = plot["ap"]  # Keep AP values as-is (per-class scalars)

    return result


def _sanitize_json_value(value):
    """Replace non-finite floats in payloads with None so requests JSON encoding succeeds."""
    if isinstance(value, dict):
        return {k: _sanitize_json_value(v) for k, v in value.items()}
    if isinstance(value, (list, tuple)):
        return [_sanitize_json_value(v) for v in value]
    if isinstance(value, float):
        return value if isfinite(value) else None  # avoid "Out of range float values are not JSON compliant" warnings
    return value


def _send(event, data, project, name, model_id=None, retry=2):
    """POST a single training event to the Platform metrics webhook, retrying on failure.

    Args:
        event (str): Event name placed in the payload (e.g. "epoch_end").
        data (Any): Event data; non-finite floats are replaced with None before sending.
        project (str): Project identifier placed in the payload.
        name (str): Run/model name placed in the payload.
        model_id (str | None): When truthy, sent as "modelId".
        retry (int): Value passed to Retry(times=...).

    Returns:
        (dict | None): Parsed JSON response, or None after a non-retryable client error or when sending fails.
    """
    body = {"event": event, "project": project, "name": name, "data": _sanitize_json_value(data)}
    if model_id:
        body["modelId"] = model_id

    def post():
        resp = requests.post(
            f"{PLATFORM_API_URL}/training/metrics",
            json=body,
            headers={"Authorization": f"Bearer {_api_key}"},
            timeout=30,
        )
        status = resp.status_code
        # 408 (timeout) and 429 (rate limit) are worth retrying; so is anything outside the 4xx range
        if status in {408, 429} or not 400 <= status < 500:
            resp.raise_for_status()
            return resp.json()

        # Remaining client errors: surface the server's message and return instead of raising, so no retry happens
        try:
            reason = resp.json().get("error", resp.reason)
        except Exception:
            reason = resp.reason
        LOGGER.warning(f"{PREFIX}{reason}")
        return None

    post_with_retry = Retry(times=retry, delay=1)(post)
    try:
        return post_with_retry()
    except Exception as e:
        LOGGER.debug(f"{PREFIX}Failed to send {event}: {e}")
        return None


def _send_async(event, data, project, name, model_id=None):
    """Queue an event for delivery via _send() on the module's bounded thread pool.

    The submitted future is discarded, so the result of the send is not reported back to the caller.
    """
    send_args = (event, data, project, name, model_id)
    _executor.submit(_send, *send_args)


def _handle_control_response(trainer, ctx, response):
    """Apply centralized stop signals returned by Platform webhook responses.

    Notes:
        ``ctx["cancelled"]`` is the durable cancellation signal. During startup, trainer setup later resets
        ``trainer.stop``, so early stop requests still rely on ``on_pretrain_routine_end()`` to reapply the flag after
        setup completes.
    """
    if response and response.get("cancelled"):
        ctx["cancelled"] = True
        trainer.stop = True
        LOGGER.info(f"{PREFIX}Training cancelled from Platform ⚠️")


def _upload_model(model_path, project, name, progress=False, retry=1, model_id=None):
    """Upload model checkpoint to Platform via signed URL.

    Every failure mode (missing file, failed or malformed signed-URL response, failed upload) is reported as a None
    return so that callers running inside training callbacks are not interrupted.

    Args:
        model_path (str | Path): Local checkpoint file to upload.
        project (str): Project identifier sent with the signed-URL request.
        name (str): Run/model name sent with the signed-URL request.
        progress (bool): Forwarded to safe_upload() to toggle its progress bar.
        retry (int): Forwarded to safe_upload() as its retry setting.
        model_id (str | None): When truthy, sent as "modelId" with the signed-URL request.

    Returns:
        (str | None): The "gcsPath" value of the signed-URL response when the upload succeeds, otherwise None.
    """
    from ultralytics.utils.uploads import safe_upload

    model_path = Path(model_path)
    if not model_path.exists():
        LOGGER.warning(f"{PREFIX}Model file not found: {model_path}")
        return None

    # Get signed upload URL from Platform (server sanitizes filename for storage safety)
    @Retry(times=3, delay=2)
    def get_signed_url():
        payload = {"project": project, "name": name, "filename": model_path.name}
        if model_id:
            payload["modelId"] = model_id  # Direct lookup avoids slug mismatch from auto-increment
        r = requests.post(
            f"{PLATFORM_API_URL}/models/upload",
            json=payload,
            headers={"Authorization": f"Bearer {_api_key}"},
            timeout=30,
        )
        r.raise_for_status()
        return r.json()

    try:
        data = get_signed_url()
    except Exception as e:
        LOGGER.warning(f"{PREFIX}Failed to get upload URL: {e}")
        return None

    # A 2xx response without a usable "uploadUrl" must degrade to a warning rather than raise KeyError/TypeError
    upload_url = data.get("uploadUrl") if isinstance(data, dict) else None
    if not upload_url:
        LOGGER.warning(f"{PREFIX}Failed to get upload URL: response is missing 'uploadUrl'")
        return None

    # Upload to GCS using safe_upload with retry logic and optional progress bar
    if safe_upload(file=model_path, url=upload_url, retry=retry, progress=progress):
        return data.get("gcsPath")
    return None


def _upload_model_async(model_path, project, name, model_id=None):
    """Queue a checkpoint upload via _upload_model() on the module's bounded thread pool.

    The submitted future is discarded, so the upload result is not reported back to the caller.
    """
    upload_args = (model_path, project, name)
    _executor.submit(_upload_model, *upload_args, model_id=model_id)


def _get_environment_info():
    """Gather host, interpreter, hardware and git metadata describing the current training environment.

    Returns:
        (dict): Always contains version, host, OS, Python, CPU, command-line, RAM and disk entries. Git entries
            ("gitRepository", "gitBranch", "gitCommit") and GPU entries ("gpuCount", "gpuType") are added only when
            available; failures while collecting them are ignored.
    """
    import shutil

    import psutil
    import torch

    from ultralytics import __version__
    from ultralytics.utils.torch_utils import get_cpu_info, get_gpu_info

    gib = 1 << 30  # bytes per GiB
    ram_bytes = psutil.virtual_memory().total
    disk_bytes = shutil.disk_usage("/").total

    env = {
        "ultralyticsVersion": __version__,
        "hostname": socket.gethostname(),
        "os": platform.platform(),
        "environment": ENVIRONMENT,
        "pythonVersion": PYTHON_VERSION,
        "pythonExecutable": sys.executable,
        "cpuCount": os.cpu_count() or 0,
        "cpu": get_cpu_info(),
        "command": " ".join(sys.argv),
        "totalRamGb": round(ram_bytes / gib, 1),
        "totalDiskGb": round(disk_bytes / gib, 1),
    }

    # Git metadata is best-effort: each field is recorded only when it has a value
    try:
        if GIT.is_repo:
            for key, attr in (("gitRepository", "origin"), ("gitBranch", "branch")):
                value = getattr(GIT, attr)
                if value:
                    env[key] = value
            commit = GIT.commit
            if commit:
                env["gitCommit"] = commit[:12]  # Short hash
    except Exception:
        pass

    # GPU metadata is best-effort too, reported only when CUDA is available
    try:
        if torch.cuda.is_available():
            gpu_count = torch.cuda.device_count()
            env["gpuCount"] = gpu_count
            env["gpuType"] = get_gpu_info(0) if gpu_count > 0 else None
    except Exception:
        pass

    return env


def _get_project_name(trainer):
    """Derive the slugified (project, name) pair from the trainer's arguments.

    For a project of the form "owner/rest", only the part after the first "/" is slugified and the owner part is
    kept verbatim; otherwise the whole project string is slugified. A missing name falls back to "train".

    Returns:
        (tuple): The project identifier and the slugified run name.
    """
    raw_project = str(trainer.args.project)
    owner, slash, remainder = raw_project.partition("/")
    project = f"{owner}/{slugify(remainder)}" if slash else slugify(raw_project)
    run_name = slugify(str(trainer.args.name or "train"))
    return project, run_name


def on_pretrain_routine_start(trainer):
    """Initialize Platform logging at training start.

    Creates the per-run context dict on ``trainer.platform``, starts console capture, and sends a synchronous
    "training_started" event. If the Platform does not return a model ID, tracking is disabled for the run:
    console capture is stopped again and ``trainer.platform`` is set to None.

    Args:
        trainer: Trainer instance; only acted on when RANK is -1 or 0 and ``trainer.args.project`` is set.
    """
    if RANK not in {-1, 0} or not trainer.args.project:
        return

    project, name = _get_project_name(trainer)
    LOGGER.info(f"{PREFIX}Streaming training metrics to Platform")

    # Single dict for all platform callback state (like trainer.hub_session for HUB callbacks)
    ctx = {"model_id": None, "last_upload": time(), "cancelled": False, "console_logger": None, "system_logger": None}
    trainer.platform = ctx

    # Create callback to send console output to Platform
    def send_console_output(content, line_count, chunk_id):
        """Send batched console output to Platform webhook."""
        _send_async(
            "console_output",
            {"chunkId": chunk_id, "content": content, "lineCount": line_count},
            project,
            name,
            ctx["model_id"],
        )

    # Start console capture with batching (5 lines or 5 seconds)
    ctx["console_logger"] = ConsoleLogger(batch_size=5, flush_interval=5.0, on_flush=send_console_output)
    ctx["console_logger"].start_capture()

    # Collect environment info (W&B-style metadata)
    environment = _get_environment_info()

    # Build trainArgs - callback runs before get_dataset() so args.data is still original (e.g., ul:// URIs)
    # Note: model_info is sent later in on_fit_epoch_end (epoch 0) when the model is actually loaded
    train_args = {k: str(v) for k, v in vars(trainer.args).items()}

    # Send synchronously to get modelId for subsequent webhooks (critical, more retries)
    response = _send(
        "training_started",
        {
            "trainArgs": train_args,
            "epochs": trainer.epochs,
            "device": str(trainer.device),
            "environment": environment,
        },
        project,
        name,
        retry=4,
    )
    if response and response.get("modelId"):
        ctx["model_id"] = response["modelId"]
        # Server returns actual slug (may differ from requested name due to auto-increment, e.g. "train" → "train-2")
        if response.get("modelSlug"):
            ctx["model_slug"] = response["modelSlug"]
            url = f"{PLATFORM_URL}/{project}/{ctx['model_slug']}"
            LOGGER.info(f"{PREFIX}View model at {url}")
        # Note: trainer.stop is set in on_pretrain_routine_end (after _setup_train resets it)
        _handle_control_response(trainer, ctx, response)
    else:
        LOGGER.warning(f"{PREFIX}Training will not be tracked on Platform")
        # on_train_end() returns early once trainer.platform is None, so the capture started above must be stopped
        # here or it would keep running (and keep queuing console_output events) for the rest of the run
        if ctx["console_logger"]:
            ctx["console_logger"].stop_capture()
            ctx["console_logger"] = None
        trainer.platform = None  # Disable further callbacks


def on_pretrain_routine_end(trainer):
    """Re-apply a cancellation received before training started, after _setup_train resets trainer.stop.

    Does nothing when the trainer has no Platform context or the context is not marked as cancelled.
    """
    ctx = getattr(trainer, "platform", None)
    if not ctx or not ctx["cancelled"]:
        return

    LOGGER.info(f"{PREFIX}Training cancelled from Platform before starting ✅")
    trainer.stop = True


def on_fit_epoch_end(trainer):
    """Report per-epoch training metrics and system metrics to Platform in the background.

    Does nothing when the trainer has no Platform context, RANK is not -1 or 0, or no project is configured. The
    response to the "epoch_end" event is checked for a cancellation signal on the background thread.
    """
    ctx = getattr(trainer, "platform", None)
    if not ctx or RANK not in {-1, 0} or not trainer.args.project:
        return

    project, name = _get_project_name(trainer)

    # Training losses first, then the trainer's metrics (which win on key clashes)
    metrics = dict(trainer.label_loss_items(trainer.tloss, prefix="train"))
    metrics.update(trainer.metrics)

    optimizer = trainer.optimizer
    if optimizer and optimizer.param_groups:
        metrics["lr"] = optimizer.param_groups[0]["lr"]

    # Model info is only gathered on the first epoch and travels as its own field, not inside metrics
    model_info = None
    if trainer.epoch == 0:
        try:
            raw_info = model_info_for_loggers(trainer)
            model_info = {
                "parameters": raw_info.get("model/parameters", 0),
                "gflops": raw_info.get("model/GFLOPs", 0),
                "speedMs": raw_info.get("model/speed_PyTorch(ms)", 0),
            }
        except Exception:
            pass

    # System metrics are best-effort; the SystemLogger is created once and kept in the context for reuse
    try:
        system_logger = ctx["system_logger"] or SystemLogger()
        ctx["system_logger"] = system_logger
        system = system_logger.get_metrics(rates=True)
    except Exception:
        system = {}

    payload = {
        "epoch": trainer.epoch,
        "metrics": metrics,
        "system": system,
        "fitness": trainer.fitness,
        "best_fitness": trainer.best_fitness,
    }
    if model_info:
        payload["modelInfo"] = model_info

    def report_epoch():
        """Send epoch_end and apply any cancellation signal in the response (runs in background thread)."""
        reply = _send("epoch_end", payload, project, name, ctx["model_id"], retry=1)
        _handle_control_response(trainer, ctx, reply)

    _executor.submit(report_epoch)


def on_model_save(trainer):
    """Queue a background checkpoint upload, at most once every 15 minutes.

    Does nothing when the trainer has no Platform context, RANK is not -1 or 0, or no project is configured. The
    best checkpoint is preferred when it exists on disk, otherwise the last checkpoint is used.
    """
    ctx = getattr(trainer, "platform", None)
    if not ctx or RANK not in {-1, 0} or not trainer.args.project:
        return

    # Rate limit: skip unless at least 900 seconds (15 minutes) have passed since the last recorded upload
    seconds_since_upload = time() - ctx["last_upload"]
    if seconds_since_upload < 900:
        return

    best = trainer.best
    checkpoint = best if best and Path(best).exists() else trainer.last
    if not checkpoint:
        return

    project, name = _get_project_name(trainer)
    _upload_model_async(checkpoint, project, name, model_id=ctx["model_id"])
    ctx["last_upload"] = time()


def on_train_end(trainer):
    """Finish Platform tracking: stop console capture, upload the best checkpoint, and report final results.

    Does nothing when the trainer has no Platform context, RANK is not -1 or 0, or no project is configured. The
    final "training_complete" event carries metrics, class names and validation plot data.
    """
    ctx = getattr(trainer, "platform", None)
    if not ctx or RANK not in {-1, 0} or not trainer.args.project:
        return

    project, name = _get_project_name(trainer)

    if ctx["cancelled"]:
        LOGGER.info(f"{PREFIX}Uploading partial results for cancelled training")

    # Console capture is no longer needed once training has finished
    console_logger = ctx["console_logger"]
    if console_logger:
        console_logger.stop_capture()
        ctx["console_logger"] = None

    # The final upload is blocking (with a progress bar) so that it completes before this callback returns
    gcs_path = model_size = None
    best = trainer.best
    if best and Path(best).exists():
        model_size = Path(best).stat().st_size
        gcs_path = _upload_model(best, project, name, progress=True, retry=3, model_id=ctx["model_id"])
        if not gcs_path:
            LOGGER.warning(f"{PREFIX}Model will not be available for download on Platform (upload failed)")

    validator = getattr(trainer, "validator", None)

    # One plot per type: trainer plots take priority, validator plots only fill in types not seen yet
    plots_by_type = {}
    for info in getattr(trainer, "plots", {}).values():
        plot_data = info.get("data")
        if plot_data and plot_data.get("type"):
            plots_by_type[plot_data["type"]] = plot_data
    for info in getattr(validator, "plots", {}).values():
        plot_data = info.get("data")
        if plot_data and plot_data.get("type") and plot_data["type"] not in plots_by_type:
            plots_by_type[plot_data["type"]] = plot_data
    plots = [_interp_plot(plot) for plot in plots_by_type.values()]  # Interpolate curves to reduce size

    # Class names come from the validator when it has them, otherwise from the trainer's data
    names = getattr(validator, "names", None) or (trainer.data or {}).get("names")
    if isinstance(names, dict):
        class_names = list(names.values())
    elif names:
        class_names = list(names)
    else:
        class_names = None

    results = {
        "metrics": {**trainer.metrics, "fitness": trainer.fitness},
        "bestEpoch": getattr(trainer, "best_epoch", trainer.epoch),
        "bestFitness": trainer.best_fitness,
        "modelPath": gcs_path,  # Only the GCS path is sent, never the local path
        "modelSize": model_size,
    }
    _send(
        "training_complete",
        {"results": results, "classNames": class_names, "plots": plots},
        project,
        name,
        ctx["model_id"],
        retry=4,  # Critical, more retries
    )

    # Prefer the slug stored in the context (when present) over the locally derived name
    url = f"{PLATFORM_URL}/{project}/{ctx.get('model_slug', name)}"
    LOGGER.info(f"{PREFIX}View results at {url}")


# Callbacks are registered only when an API key was resolved at import time; otherwise none are exposed
callbacks = {}
if _api_key:
    callbacks = {
        "on_pretrain_routine_start": on_pretrain_routine_start,
        "on_pretrain_routine_end": on_pretrain_routine_end,
        "on_fit_epoch_end": on_fit_epoch_end,
        "on_model_save": on_model_save,
        "on_train_end": on_train_end,
    }
