import numpy as np
import xarray as xr
import pyrsktools
import glob
import seawater
import datetime
import subprocess, os
import pandas as pd

from getInletX import getInletX
import scipy.io as sio
from datetime import datetime, timedelta


#cruise_name = '20240918'
#dir = cruise_name
# NOTE: `dir` shadows the builtin; kept because the rest of the script uses it.
dir = './'
minp = 2   # casts must reach at least this seapressure (dbar) to be saved
           # NOTE(review): an earlier comment said "20 m" but the code uses 2.

# First, let's see what columns are actually in the file
df_temp = pd.read_excel('311-2025_CTD_meta.xlsx', skiprows=0)
print("Available columns:")
print(df_temp.columns.tolist())
print("\nFirst few rows:")
print(df_temp.head())

# Now read with correct column names for datetime parsing
# You'll need to update the column names below based on the output above
df = pd.read_excel('311-2025_CTD_meta.xlsx', parse_dates=[['Date', 'Time_UTC']])
print(df)
# Decimal-degree positions; the sheet stores longitude positive-west, hence
# the sign flip.  (These module-level arrays appear unused below, where the
# per-row values from `df` are read instead.)
lat = df['Lat_dec-deg'].values
lon = -df['Long_dec-deg'].values
if True:
    # For each row of the CTD metadata sheet: open the matching Ruskin .rsk
    # file, find the requested downcast, convert it to an xarray Dataset with
    # per-channel metadata, and (if deep enough) save it as
    # ./casts/<serial>_<timestamp>.nc
    for _, info in df.iterrows():
        print('Looking for')
        print("INFO:", info['Cast number'])
        # rsk to xarray
        fname = f'ruskin/{info.fname}.rsk'
        print(fname)
        with pyrsktools.open(fname) as rsk:
            print(rsk)
            channel_names = rsk.channels.keys()
            for nn, cast in enumerate(rsk.casts(pyrsktools.Region.CAST_DOWN)):
                if nn + 1 != info['Cast number']:
                    continue
                # hoisted: the original re-ran cast.npsamples() per channel
                samples = cast.npsamples()
                times = [t.replace(tzinfo=None) for t in samples['timestamp']]
                time = np.array(times).astype('datetime64[ns]')
                ds = xr.Dataset(coords={'time': time})
                ds.attrs = {'start_time': str(time[0]),
                            'end_time': str(time[-1]),
                            'serial_number': rsk.instrument.serial,
                            'model': rsk.instrument.model,
                            'firmware_version': rsk.instrument.firmware_version,
                            'firmware_type': str(rsk.instrument.firmware_type),
                            }
                for channel in channel_names:
                    # best-effort: not every declared channel has samples
                    try:
                        ds[channel] = ('time', samples[channel])
                        ds[channel].attrs = {'units': rsk.channels[channel].units,
                                             'derived': str(rsk.channels[channel].derived),
                                             'long_name': rsk.channels[channel].name}
                    except Exception as err:  # was a silent bare `except: pass`
                        print(f'skipping channel {channel}: {err}')
                if ds['seapressure_00'].max() > minp:
                    # potential density referenced to the surface; -10 dbar
                    # approximately removes atmospheric pressure from the
                    # absolute pressure channel
                    ds['potential_density'] = ('time', seawater.eos80.pden(
                        ds.salinity_00, ds.temperature_00, ds.pressure_00 - 10.0, 0))
                    ds['station_name'] = info['Station']
                    ds['water_depth'] = info['Bottom depth (m)']
                    ds['lat'] = info['Lat_dec-deg']
                    ds['lon'] = -info['Long_dec-deg']
                    ds['datetime'] = info['Date_Time_UTC']
                    ds['Event'] = info['Event_no']
                    ds['serial'] = rsk.instrument.serial
                    outname = f'./casts/{rsk.instrument.serial}'
                    outname += '_' + time[0].astype('datetime64[ms]').astype(datetime).strftime('%Y%m%d%H%M%S') + '.nc'
                    ds.to_netcdf(outname)
                else:
                    # was: `dsadsadas` -- an undefined name used as a crash to
                    # halt the script; raise an informative error instead
                    raise ValueError(
                        f'cast {info["Cast number"]} in {fname} is shallower than {minp} dbar')
                # we found the requested cast, so stop scanning this file
                break


    # make a grid in Ruskin format:
if True:
    # Bin-average every saved cast onto a common 1-m depth grid, collect
    # per-cast metadata, and save the result as ./CTDGridRuskin.nc
    dat = glob.glob('./casts/*.nc')
    dat.sort()
    Ncasts = len(dat)
    depthbins = np.arange(0, 325, 1)   # 1-m bin edges, 0..324 m
    depths = depthbins[:-1] + 0.5      # bin centres
    grid = xr.Dataset(coords={'depths':depths, 'cast':np.arange(Ncasts)})
    # output variable name -> source variable name in the per-cast files
    sources = {'pressure': 'seapressure_00', 'temperature': 'temperature_00',
            'conductivity': 'conductivity_00', 'chlorophyll':'chlorophyll_00',
            'oxygensaturation':'oxygensaturation_00', 'salinity': 'salinity_00',
            'par':'par_00', 'potential_density':'potential_density'}

    # NOTE(review): time is preallocated as float and later assigned
    # datetime64 values -- relies on implicit conversion matching the
    # epoch-ns units attribute set below; confirm this round-trips.
    grid['time'] = ('cast', np.zeros(Ncasts))
    for s in sources:
        grid[s] = (('depths', 'cast'), np.zeros((len(depths), Ncasts)))

    # placeholder strings sized long enough for netCDF fixed-width storage
    grid['station_name'] = (('cast'), ['long_empty_string'] * Ncasts)
    grid['water_depth'] = (('cast', np.zeros(Ncasts)))
    grid['serial'] = (('cast'), ['long_empty_string'] * Ncasts)
    grid['Event'] = (('cast'), ['long_empty_string'] * Ncasts)
    grid['longitude'] = (('cast', np.zeros(Ncasts)))
    grid['latitude'] = (('cast', np.zeros(Ncasts)))
    for nn, d in enumerate(dat):
        with xr.open_dataset(d) as ds:
            grid['time'][nn] = ds.time[0]
            for s in sources:
                # bin mean = (sum of samples per bin) / (count per bin);
                # empty bins give 0/0 -> NaN (with a runtime warning)
                grid[s][:, nn] = (np.histogram(ds.depth_00, depthbins, weights=ds[sources[s]])[0] /
                                np.histogram(ds.depth_00, depthbins)[0])
                if nn == len(dat)-1:
                    # copy units/long_name attrs from the last cast read
                    grid[s].attrs = ds[sources[s]].attrs

        # NOTE(review): ds is accessed after the with-block has closed the
        # file; this works only if these scalars are already in memory --
        # consider moving these six lines inside the with-block.
        grid['station_name'][nn] = ds.station_name
        grid['water_depth'][nn] = ds.water_depth
        grid['longitude'][nn] = ds.lon
        grid['latitude'][nn] = ds.lat
        grid['Event'][nn] = ds.Event
        grid['serial'][nn] = ds.serial

    grid.time.attrs = {"units": 'nanoseconds since 1970-01-01T00:00:00'}


    grid.to_netcdf(f'./CTDGridRuskin.nc')



if True:
    # get the geographic grid:
    if False:
        # Dead code: interpolate cast positions from the instrument's GPS
        # track.  NOTE(review): `ruskinname` is undefined here (it exists
        # only in a commented-out line near the top of the script), so
        # enabling this branch requires defining it first.
        with pyrsktools.open(ruskinname) as rsk:
            Ngeo = len(list(rsk.geodata()))
            geods = xr.Dataset(coords={'sample': np.arange(0, Ngeo)})
            geods['time'] = ('sample', np.zeros(Ngeo))
            geods['latitude'] = ('sample', np.zeros(Ngeo))
            geods['longitude'] = ('sample', np.zeros(Ngeo))

            for nn, geo in enumerate(rsk.geodata()):
                geods['time'][nn] = np.datetime64(geo.timestamp.replace(tzinfo=None))
                geods['latitude'][nn] = geo.latitude
                geods['longitude'][nn] = geo.longitude
        geods.time.attrs = {"units": 'nanoseconds since 1970-01-01T00:00:00'}

        geods.to_netcdf(f'{dir}/Geo.nc')

        with xr.open_dataset(f'{dir}/Geo.nc', engine='netcdf4') as geods, xr.open_dataset(f'{dir}/CTDGridRuskin.nc', engine='netcdf4') as cgrid:
            cgrid['longitude'] = ('cast', np.interp(cgrid.time, geods.time, geods.longitude))
            cgrid['latitude'] = ('cast', np.interp(cgrid.time, geods.time, geods.latitude))


    #lat = 48 + np.array([43.205, 44.28, 45.76, 45.49, 44.55, 43.69, 43.00, 42.30, 41.82, 40.50, 39.55, 38.28]) / 60.
    #lon = -123 - np.array([14.395, 17.29, 19.32, 22.10, 23.56, 25.05, 27.00, 29.10, 30.00, 30.00, 30.16, 30.10]) / 60.
    # Add along-/across-inlet coordinates and write renamed copies of the grid.
    with xr.open_dataset(f'{dir}/CTDGridRuskin.nc', engine='netcdf4') as cgrid:

        #cgrid['longitude'] = ('cast', lon)
        # cgrid['latitude'] = ('cast', lat)

            # (the extra indentation below is historical; these lines are
            # still inside the with-block and run unchanged)
            cgrid = cgrid.swap_dims({'cast':'time'})

            x, y = getInletX(cgrid.longitude, cgrid.latitude)
            cgrid['alongx'] = ('time', x, {'units':'dist from S4 [km]'})
            cgrid['acrossx'] = ('time', y, {'units':'dist Thalweg [km]'})
            # cgrid = cgrid.sortby('alongx')

            cgrid.to_netcdf(f'{dir}/CTDGridGeoRuskin.nc')

            # convert CTDgridGeo to Old-names CTD grid
            rename = {'conductivity': 'cond',
                    'temperature': 'temp',
                    'pressure': 'pres',
                    'oxygensaturation': 'O2sat',
                    'salinity': 'sal',
                    'potential_density': 'pden',
                    'chlorophyll': 'Flu',
                    'longitude': 'lon',
                    'latitude': 'lat',
                    'station_name': 'id',
                    'water_depth': 'water_depth',
                    'par': 'Par'}

            ds = cgrid.rename(rename)
            ds.to_netcdf(f'./CtdGridNew.nc')

    # save as matlab

    cout = {}
    def datetime642matlab(dt):
        """Convert a numpy datetime64 to a MATLAB datenum (float days)."""
        dt = dt.astype('datetime64[ms]').astype(datetime)
        return datetime2matlab(dt)

    def datetime2matlab(dt):
        """Convert a python datetime to a MATLAB datenum.

        MATLAB datenum day 1 is proleptic 0000-Jan-1, i.e. 366 days before
        Python's ordinal day 1 (0001-Jan-1); the fractional part is the
        time of day.
        """
        mdn = dt + timedelta(days=366)
        # total_seconds() keeps sub-second precision; the original used
        # .seconds, which silently dropped microseconds
        frac = (dt - datetime(dt.year, dt.month, dt.day)).total_seconds()
        return mdn.toordinal() + frac / (24.0 * 60.0 * 60.0)

    # Dump the whole grid to one MATLAB .mat file, then have octave repack
    # the separate variables into a single struct.
    with xr.open_dataset(f'{dir}/CtdGridNew.nc') as ds:
        cout['time'] = np.array([datetime642matlab(t) for t in ds.time.values])
        cout['depths'] = ds.depths.values
        for k in ds.keys():
            cout[k] = ds[k].values

        outname = f'{dir}/CtdGridNew.mat'
        sio.savemat(outname, cout, format='5')
        subprocess.call(['octave', 'saveasstruct.m', outname, 'cgrid'])

    # Save individual casts (one netcdf and one .mat per cast):
    os.makedirs('./ctdgrid_csv/', exist_ok=True)  # was try/os.mkdir/bare-except

    with xr.open_dataset(f'{dir}/CtdGridNew.nc') as cgrid:
        # .sizes replaces the deprecated Dataset.dims[...] integer lookup
        for i in range(cgrid.sizes['time']):
            ds = cgrid.isel(time=i)
            stn = str(ds.id.values).strip('.')  # was `id` (shadowed builtin)
            stamp = ds.time.values.astype('datetime64[ms]').astype(datetime).strftime('%Y%m%d%H%M')
            name = f'{stamp}_{stn}'
            ds.to_netcdf(f'{dir}/{name}.nc')

            # csv export (currently disabled):
            # df = cgrid.isel(time=i).to_pandas()
            # csv_name = f'{dir}/ctdgrid_csv/ctd' + str(cgrid.id[i].values) + '.csv'
            # df.to_csv(csv_name)

            # matlab: same struct-repack dance as the full grid above
            cout = {}
            cout['time'] = datetime642matlab(ds.time.values)
            cout['depths'] = ds.depths.values
            for k in ds.keys():
                cout[k] = ds[k].values
            outname = f'{dir}/{name}New.mat'
            sio.savemat(outname, cout)
            subprocess.call(['octave', 'saveasstruct.m', outname, 'ctd'])

#!mkdir 20240925/casts_binned_csv/
# Start the binned CSV from scratch; the loop below appends one cast at a
# time.  (os.remove replaces the non-portable `os.system('rm ...')` shell
# call; a missing file is fine on the first run.)
try:
    os.remove('eos311202409_binned.csv')
except FileNotFoundError:
    pass


# Mapping from the short internal variable names to the verbose column
# headers used in the exported CSV.  NOTE: insertion order is load-bearing --
# list(columnrename.keys())[1:] below selects and orders the CSV columns
# (the first key, 'depths', is the index, not a column).
columnrename = {
    'depths': 'depth [m]',
    'id': 'Station',
    'lon': 'longitude [degE]',
    'lat': 'latitude [degN]',
    'water_depth': 'Bot.Depth [m]',
    'year': 'year',
    'month': 'month',
    'day': 'day',
    'hour': 'hour',
    'minute': 'minute',
    'temp': 'temperature [oC]',
    'sal': 'salinity [PSS-78]',
    'pres': 'Pressure [dbar]',
    'pden': 'potential density [kg m-3]',
    'Flu': 'chlorophyll fluorescence [ug L-1]',
    'Par': 'PAR [umol-photons m-2 s-1]',
    'O2conc': 'O2 conc [umol kg-1]',
    'O2sat': 'O2 saturation [%]',
    'cond': 'conductivity [S/dm]',
    'alongx': 'along-track [km]',
    'acrossx': 'across-track [km]',
    'O2sat0': 'O2 conc at saturation [umol kg-1]'
}

# Export every cast, one appended block per cast, into a single CSV with the
# verbose headers from `columnrename`.
with xr.open_dataset('./CtdGridNew.nc') as ds:
    # O2 concentration at saturation from salinity/temperature
    ds['O2sat0'] = (('depths', 'time'), seawater.satO2(ds.sal, ds.temp))
    # NOTE(review): if O2sat is percent saturation (as its CSV header says),
    # this product is a factor of 100 off a true concentration -- confirm.
    ds['O2conc'] = ds.O2sat0 * ds.O2sat
    dates = ds.time.values
    # Break the datetime64 stamps into calendar fields for the CSV columns.
    years = dates.astype('datetime64[Y]').astype(int) + 1970
    months = dates.astype('datetime64[M]').astype(int) % 12 + 1
    # day-of-month as timedelta64[D]; recast to a plain int per row below
    days = dates.astype('datetime64[D]') - dates.astype('datetime64[M]') + 1
    hours = dates.astype('datetime64[h]').astype(int) % 24
    minutes = dates.astype('datetime64[m]').astype(int) % 60
    print(hours, minutes, dates)
    ds['year'] = ('time', years)
    ds['month'] = ('time', months)
    ds['day'] = ('time', days)
    ds['hour'] = ('time', hours)
    ds['minute'] = ('time', minutes)
    for nn in range(len(ds.time)):
        pro = ds.isel(time=nn)
        # drop depth bins with no data (NaN temperature)
        pro = pro.where(np.isfinite(pro.temp), drop=True)
        df = pro.to_pandas()
        # overwrite day with a plain integer (it is timedelta64 in ds)
        df['day'] = days[nn].astype('int32')
        df = df.astype({'year':'int32', 'day':'int32', 'month':'int32', 'hour':'int32', 'minute':'int32'})
        df.index.name = 'depth [m]'

        # select/order output columns ([1:] skips 'depths', the index);
        # O2sat0 is a helper column and is dropped before renaming
        df = df.loc[:, list(columnrename.keys())[1:]]
        df = df.drop(columns=['O2sat0'])
        df = df.rename(columns=columnrename)
        # (name is built but currently unused in this loop)
        name = pro.time.values.astype('datetime64[ms]').astype(datetime).strftime('%Y%m%d%H%M')
        name = f'{name}_{pro.id.values}'

        #df.to_csv(f'./name.csv')
        #print(df['year'].dtype)
        # append to the combined CSV; write the header only for the first cast
        df.to_csv('eos311202409_binned.csv', mode='a', header=(nn==0))
