import numpy as np
import xarray as xr
import pyrsktools
import glob
import seawater
import datetime
import subprocess, os
from pathlib import Path
from getInletX import getInletX
import scipy.io as sio
from datetime import datetime, timedelta

# --- Output directory layout -------------------------------------------------
# All products live under <script dir>/ctd/<cruise_name>/.
top = Path(__file__).resolve().parent

cruise_name = '20250917'
maindir = top / 'ctd' / cruise_name
maindir.mkdir(parents=True, exist_ok=True)
ruskin_dir = maindir / 'ruskin'
ruskin_dir.mkdir(parents=True, exist_ok=True)
print(ruskin_dir)
ctdgrid_csv_dir = maindir / 'ctdgrid_csv'
ctdgrid_csv_dir.mkdir(parents=True, exist_ok=True)
casts_dir = maindir / 'casts'
casts_dir.mkdir(parents=True, exist_ok=True)

casts_binned_csv_dir = maindir / 'casts_binned_csv'
casts_binned_csv_dir.mkdir(parents=True, exist_ok=True)

# NOTE(review): `dir` shadows the builtin; kept only in case later code
# references it.  Prefer `cruise_name` directly.
dir = cruise_name
minp = 20   # casts must be 20 m deep to count...

# Cast sample-index ranges, determined by hand.  Ruskin software should do
# this, but didn't seem to work 2022..  (Currently unused below; kept for
# reference.)
Casts = {'A1': [1809, 5421],
         'A2': [9858, 12095],
         'A3': [15449, 16777],
         'A5': [19067, 20571],
         'S8': [23106, 23863],
         'S5.5': [25909, 26884],
         'S5': [30157, 31927],
         'S12.5': [36805, 37914],
         'S12': [40040, 41094],
         }

# Station names and nominal water depths [m], in downcast order: the nth
# detected downcast is labelled names[n] with water depth depths[n].
names = ['H1', 'A2', 'A3', 'A5', 'S8', 'S5.5', 'S5', 'S12.5', 'S12', 'S4.5', 'S4.25', 'S4']
depths = [337, 104, 98, 107, 79, 88, 114, 81, 83, 156, 183, 201]

# RSK file exported from the instrument (serial 206664).
# (A dead assignment to an older export, ..._1122.rsk, was removed here.)
ruskinname = ruskin_dir / '206664_20250917_1646.rsk'
print(ruskinname)

if True:
    # rsk to xarray: write each detected downcast to its own netCDF file
    # under casts_dir, named <serial>_<YYYYmmddHHMMSS>.nc.
    with pyrsktools.open(ruskinname) as rsk:
        print(rsk)
        channel_names = rsk.channels.keys()
        print('CASTS',   rsk.casts(pyrsktools.Region.CAST_DOWN))
        for nn, cast in enumerate(rsk.casts(pyrsktools.Region.CAST_DOWN)):
            print(nn, '\n')
            # hoisted: npsamples() was previously re-read once per channel
            samples = cast.npsamples()
            times = samples['timestamp']
            # drop timezone so the conversion to datetime64 is unambiguous
            times = [t.replace(tzinfo=None) for t in times]
            time = np.array(times).astype('datetime64[ns]')
            ds = xr.Dataset(coords={'time':time})
            ds.attrs = {'start_time': str(time[0]),
                        'end_time': str(time[-1]),
                        'serial_number': rsk.instrument.serial,
                        'model': rsk.instrument.model,
                        'firmware_version': rsk.instrument.firmware_version,
                        'firmware_type': str(rsk.instrument.firmware_type),
                        }
            print('Names', channel_names)
            for channel in channel_names:
                try:
                    ds[channel] = ('time', samples[channel])
                    ds[channel].attrs = {'units': rsk.channels[channel].units,
                                        'derived': str(rsk.channels[channel].derived),
                                        'long_name': rsk.channels[channel].name}
                except ValueError:
                    # channel declared but not present in this cast's samples
                    print(f'No {channel} found')
            # keep only casts that got deeper than minp dbar
            if ds['seapressure_00'].max() > minp:
                # BUG FIX: original wrote `if names == ''` (comparing the
                # whole list to a string, always False) and only checked
                # after the name had already been assigned.
                if names[nn] == '':
                    raise RuntimeError('name is empty, but cast has enough data')
                # potential density referenced to the surface; pressure-10
                # presumably corrects atmospheric offset -- TODO confirm
                ds['potential_density'] = ('time', seawater.eos80.pden(ds.salinity_00, ds.temperature_00, ds.pressure_00-10.0, 0))
                ds['station_name'] = names[nn]
                ds['water_depth'] = depths[nn]
                outname = casts_dir / Path(f'{rsk.instrument.serial}_{time[0].astype("datetime64[ms]").astype(datetime).strftime("%Y%m%d%H%M%S")}.nc')
                ds.to_netcdf(outname)


    # make a grid in Ruskin format: bin every cast file onto a common 1-m
    # depth grid, dimensions (depths, cast).

    dat = sorted(casts_dir.glob('*.nc'))
    Ncasts = len(dat)
    depthbins = np.arange(0, 325, 1)
    depths = depthbins[:-1] + 0.5   # bin centres (overwrites the station-depth list above)
    grid = xr.Dataset(coords={'depths':depths, 'cast':np.arange(Ncasts)})
    # gridded name -> variable name in the cast files
    sources = {'pressure': 'seapressure_00', 'temperature': 'temperature_00',
            'conductivity': 'conductivity_00', 'chlorophyll':'chlorophyll_00',
            'oxygensaturation':'oxygensaturation_00', 'salinity': 'salinity_00',
            'par':'par_00', 'potential_density':'potential_density'}

    grid['time'] = ('cast', np.zeros(Ncasts))
    for s in sources:
        grid[s] = (('depths', 'cast'), np.zeros((len(depths), Ncasts)))

    # placeholder string reserves enough width for the fixed-size dtype
    grid['station_name'] = (('cast'), ['long_empty_string'] * Ncasts)
    grid['water_depth'] = ('cast', np.zeros(Ncasts))
    for nn, d in enumerate(dat):
        with xr.open_dataset(d) as ds:
            grid['time'][nn] = ds.time[0]
            for s in sources:
                # per-bin mean: sum of samples in bin / count in bin;
                # empty bins give 0/0 -> NaN (expected)
                grid[s][:, nn] = (np.histogram(ds.depth_00, depthbins, weights=ds[sources[s]])[0] /
                                np.histogram(ds.depth_00, depthbins)[0])
                if nn == len(dat)-1:
                    # copy units/long_name attrs once (all casts share them)
                    grid[s].attrs = ds[sources[s]].attrs
            # BUG FIX: these two reads were outside the `with`, touching a
            # closed, lazily-loaded dataset.
            grid['station_name'][nn] = ds.station_name
            grid['water_depth'][nn] = ds.water_depth

    grid.time.attrs = {"units": 'nanoseconds since 1970-01-01T00:00:00'}


    grid.to_netcdf(maindir / 'CTDGridRuskin.nc')

    # get the geographic grid: attach GPS positions (interpolated in time
    # from the gpx track file) to each cast, plus along-/across-inlet
    # coordinates from getInletX.
    if True:
        with xr.open_dataset(maindir / 'output_gpx.nc', engine='netcdf4') as geods:
            with xr.open_dataset(maindir / 'CTDGridRuskin.nc', engine='netcdf4') as cgrid:
                # np.interp assumes both time axes are in comparable numeric
                # units (ns since epoch) -- TODO confirm output_gpx.nc agrees
                cgrid['longitude'] = ('cast', np.interp(cgrid.time, geods.time, geods.lon))
                cgrid['latitude'] = ('cast', np.interp(cgrid.time, geods.time, geods.lat))

                print(cgrid.longitude.values)

                # index by time rather than cast number from here on
                cgrid = cgrid.swap_dims({'cast':'time'})

                x, y = getInletX(cgrid.longitude, cgrid.latitude)
                cgrid['alongx'] = ('time', x, {'units':'dist from S4 [km]'})
                cgrid['acrossx'] = ('time', y, {'units':'dist Thalweg [km]'})
                # cgrid = cgrid.sortby('alongx')

                cgrid.to_netcdf(maindir / 'CTDGridGeoRuskin.nc')

    # convert CTDgridGeo to Old-names CTD grid (legacy variable names)
    rename = {'conductivity': 'cond',
            'temperature': 'temp',
            'pressure': 'pres',
            'oxygensaturation': 'O2sat',
            'salinity': 'sal',
            'potential_density': 'pden',
            'chlorophyll': 'Flu',
            'longitude': 'lon',
            'latitude': 'lat',
            'station_name': 'id',
            'water_depth': 'water_depth',
            'par': 'Par'}

    # BUG FIX: previously this re-used `cgrid` after its `with` block had
    # closed the backing file, relying on xarray transparently re-opening it.
    # Re-open the file written above instead.
    with xr.open_dataset(maindir / 'CTDGridGeoRuskin.nc') as geo_grid:
        ds = geo_grid.rename(rename)
        ds.to_netcdf(maindir / 'CtdGridNew.nc')

# save as matlab

cout = {}

def datetime642matlab(dt):
    """Convert a numpy datetime64 to a MATLAB serial date number (datenum)."""
    # ms precision is enough here; sub-ms is dropped by the cast
    dt = dt.astype('datetime64[ms]').astype(datetime)
    return datetime2matlab(dt)

def datetime2matlab(dt):
    """Convert a python datetime to a MATLAB datenum.

    MATLAB datenums count days from year 0000, which is 366 days before
    python's proleptic ordinal day 1 (0001-01-01), hence the offset.
    """
    mdn = dt + timedelta(days=366)
    # fraction of the day elapsed since midnight; total_seconds() keeps the
    # sub-second part that the original `.seconds` silently truncated
    frac = (dt - datetime(dt.year, dt.month, dt.day, 0, 0, 0)).total_seconds()

    return mdn.toordinal() + frac / (24.0 * 60.0 * 60.0)

# Export the renamed grid as a MATLAB .mat file and post-process with octave.
with xr.open_dataset(maindir / 'CtdGridNew.nc') as ds:
    cout['time'] = np.array([datetime642matlab(t) for t in ds.time.values])
    cout['depths'] = ds.depths.values
    for k in ds.keys():
        cout[k] = ds[k].values

    matname = maindir / 'CtdGridNew.mat'
    sio.savemat(matname, cout, format='5')
    # BUG FIX: octave was previously handed the stale `outname` (the last
    # per-cast .nc file from the loop above) instead of the .mat just written.
    subprocess.call(['octave', 'saveasstruct.m', str(matname), 'cgrid'])

# Save individual casts: make sure the csv output directory exists
# (already created as ctdgrid_csv_dir above; repeating is harmless, but
# unlike the old bare try/except this no longer hides real errors such as
# permission failures)
(maindir / 'ctdgrid_csv').mkdir(exist_ok=True)

# Write each cast out three ways: netCDF, csv, and MATLAB .mat (the .mat is
# then restructured by the saveasstruct.m octave script).
with xr.open_dataset(maindir / 'CtdGridNew.nc') as cgrid:
    # .sizes instead of the deprecated .dims mapping access
    for i in range(cgrid.sizes['time']):
        ds = cgrid.isel(time=i)
        # station id, e.g. 'S5.5' -> file-safe 'S55' (dots stripped)
        stn = str(ds.id.values).strip('.')
        name = f'{cruise_name}_{stn}'.replace('.', '')
        ds.to_netcdf(f'{maindir}/{name}.nc')

        # csv:

        df = cgrid.isel(time=i).to_pandas()
        # NOTE(review): unlike `name`, the csv name keeps dots in the id
        csv_name = f'{maindir}/ctdgrid_csv/ctd' + str(cgrid.id[i].values) + '.csv'
        df.to_csv(csv_name)

        # matlab:
        cout = {}
        cout['time'] = datetime642matlab(ds.time.values)
        cout['depths'] = ds.depths.values
        for k in ds.keys():
            cout[k] = ds[k].values
        outname = f'{maindir}/{name}New.mat'
        sio.savemat(outname, cout)
        subprocess.call(['octave', 'saveasstruct.m', outname, 'ctd'])

