import xml.etree.ElementTree as ET
import numpy as np
import xarray as xr

gpx_path = "ctd/20250917/ruskin/55ECEC04-9D45-421F-B2A9-CFAFF6EA5888_20250917_0834.gpx"

# Load the GPX document. ElementTree spells namespaced tags as "{uri}local", so a
# root tag that begins with "{" means the root element lives in a namespace; grab
# that URI and bind it to the "gpx" prefix for the find/findall calls below.
root = ET.parse(gpx_path).getroot()
if root.tag.startswith("{"):
    ns_uri = root.tag.partition("}")[0].strip("{")
else:
    ns_uri = ""
ns = {}
if ns_uri:
    ns["gpx"] = ns_uri

def _find_time(trkpt):
    """Return the whitespace-stripped text of *trkpt*'s direct ``<time>`` child.

    Uses the module-level ``ns_uri`` / ``ns`` detected from the root element, so
    the lookup works for both namespaced and un-namespaced GPX files.
    Returns None when there is no ``<time>`` child or it has no text.
    """
    if ns_uri:
        time_elem = trkpt.find("gpx:time", ns)
    else:
        time_elem = trkpt.find("time")
    if time_elem is None or not time_elem.text:
        return None
    return time_elem.text.strip()

def _find_acc(trkpt):
    """Return the first parseable ``<acc>`` value found below *trkpt*, else NaN.

    Searches all descendants (e.g. the ``<extensions>`` subtree) and accepts any
    namespace prefix (``<foo:acc>``), because only the local tag name is compared.
    Matching elements with missing or non-numeric text are skipped.
    NOTE(review): presumably a position-accuracy value; units not shown here.
    """
    for elem in trkpt.findall(".//*"):
        tag = elem.tag
        # Namespaced tags are "{uri}local" (Clark notation). Compare the local
        # name exactly: a bare endswith("acc") would also match <hacc>/<vacc>.
        # Non-string tags (comments/processing instructions) are ignored.
        if not isinstance(tag, str) or tag.rpartition("}")[2] != "acc":
            continue
        if not elem.text:
            continue
        try:
            return float(elem.text)
        except ValueError:
            continue  # non-numeric text: keep looking
    return np.nan

def _to_utc_datetime64(ts):
    """Convert a GPX (xsd:dateTime) timestamp string to a naive-UTC ``datetime64[ns]``.

    A trailing ``+hh:mm`` / ``-hh:mm`` offset is applied explicitly rather than
    handed to NumPy, whose parsing of timezone-aware strings is deprecated.
    Strings without an offset are taken to already be UTC.
    """
    if len(ts) > 6 and ts[-6] in "+-" and ts[-3] == ":":
        sign = 1 if ts[-6] == "+" else -1
        offset_min = sign * (int(ts[-5:-3]) * 60 + int(ts[-2:]))
        # local = UTC + offset  =>  UTC = local - offset
        return np.datetime64(ts[:-6], "ns") - np.timedelta64(offset_min, "m")
    return np.datetime64(ts, "ns")


lats, lons, tstrings, accs = [], [], [], []

# Iterate all track points
trkpt_iter = root.findall(".//gpx:trkpt", ns) if ns_uri else root.findall(".//trkpt")
for trkpt in trkpt_iter:
    ts = _find_time(trkpt)
    if not ts:
        continue  # skip points without time (keeps arrays aligned)

    # 'Z' (or RFC 3339 'z') marks UTC. Numpy datetime64 is timezone-naive, so drop
    # only that trailing designator; explicit offsets are resolved further down.
    if ts[-1] in "Zz":
        ts = ts[:-1]
    tstrings.append(ts)

    lats.append(float(trkpt.attrib["lat"]))
    lons.append(float(trkpt.attrib["lon"]))
    accs.append(_find_acc(trkpt))

# Convert to naive-UTC datetime64[ns] for xarray
times = np.array([_to_utc_datetime64(s) for s in tstrings], dtype="datetime64[ns]")

# Build Dataset
ds = xr.Dataset(
    data_vars=dict(
        lat=("time", np.asarray(lats, dtype=float)),
        lon=("time", np.asarray(lons, dtype=float)),
        acc=("time", np.asarray(accs, dtype=float)),  # remove if not needed
    ),
    coords=dict(time=times),
)

ds.to_netcdf("ctd/20250917/output_gpx.nc")
print(ds)