import tqdm

from helper_code import *
import numpy as np, scipy as sp, scipy.stats, os, sys, joblib

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from sklearn.metrics import roc_auc_score
from evaluation_2022.evaluate_model import compute_cost

def FCL(fin,fout):
    """Build one fully connected block.

    The block is Linear(fin, fout) -> BatchNorm1d(fout) -> LeakyReLU ->
    Dropout(p=0.8), wrapped in an ``nn.Sequential``.
    """
    layers = [
        nn.Linear(fin, fout),
        nn.BatchNorm1d(fout),
        nn.LeakyReLU(),
        nn.Dropout(0.8),
    ]
    return nn.Sequential(*layers)


class NN(nn.Module):
    """Two-hidden-layer classifier: 82 inputs -> 128 -> 128 -> 2 logits.

    The hidden layers are ``FCL`` blocks; the final layer is a plain
    ``nn.Linear`` that returns unnormalised scores (no softmax is applied).
    """

    def __init__(self):
        super().__init__()

        # Attribute names are part of the saved state_dict keys; keep them stable.
        self.L1 = FCL(82, 128)
        self.L2 = FCL(128, 128)
        self.fc = nn.Linear(128, 2)

    def forward(self,x):
        """Return the two-class logits for a batch of feature vectors."""
        hidden = self.L2(self.L1(x))
        return self.fc(hidden)


import copy
class Dataset:
    """In-memory container of features and integer class labels.

    Implements ``__len__`` and ``__getitem__`` so it can be handed to
    ``torch.utils.data.DataLoader``.

    Parameters
    ----------
    features : array-like, one row per sample
        NaN/inf entries are replaced via ``np.nan_to_num``.
    murmurs, outcomes : array-like, one row per sample
        One-hot style rows; they are reduced to class indices with
        ``np.argmax(..., axis=1)``. NOTE: a row without any 1 (all zeros)
        therefore maps to class index 0.
    """

    def __init__(self,features,murmurs,outcomes):
        # np.nan_to_num (copy=True by default) and np.argmax both return new
        # arrays, so the caller's data is never modified or aliased and no
        # extra deep copy is required.
        self.features = np.nan_to_num(features)
        self.murmurs = np.argmax(murmurs,axis=1)
        self.outcomes = np.argmax(outcomes,axis=1)

    def __len__(self):
        """Return the number of samples."""
        return self.features.shape[0]

    def __getitem__(self, item):
        """Return ``(feature_row, murmur_index, outcome_index)`` for sample ``item``."""
        return self.features[item,:],self.murmurs[item],self.outcomes[item]

    def split(self, ratio=0.75):
        """Split into ``(TRAIN, VALID)`` datasets without shuffling.

        The first ``int(ratio * len(self))`` samples form the training set and
        the remaining samples the validation set. Both parts own independent
        copies of their slices.

        Parameters
        ----------
        ratio : float, default 0.75
            Fraction of samples assigned to the training part.

        Raises
        ------
        ValueError
            If ``ratio`` is outside ``[0, 1]``.
        """
        if not 0.0 <= ratio <= 1.0:
            raise ValueError('ratio must be within [0, 1].')

        cut = int(ratio*len(self))

        # Shallow copies keep the class (and any extra attributes); the arrays
        # are then replaced by copies of only the rows each part needs instead
        # of deep-copying the full dataset twice.
        TRAIN = copy.copy(self)
        VALID = copy.copy(self)
        for name in ('features', 'murmurs', 'outcomes'):
            values = getattr(self, name)
            setattr(TRAIN, name, values[:cut].copy())
            setattr(VALID, name, values[cut:].copy())
        return TRAIN,VALID


def prepare_features(data_folder, model_folder, verbose):
    """Load every patient in ``data_folder`` and build the training matrices.

    Creates ``model_folder`` if needed, then for each patient file extracts the
    feature vector (``get_features``) and one-hot encodes the murmur and
    outcome labels. A label that is not one of the known classes yields an
    all-zero row.

    Returns ``(features, murmurs, outcomes)``, each stacked with one row per
    patient. Raises ``Exception`` when no patient files are found.
    """
    if verbose >= 1:
        print('Finding data files...')

    # Locate the per-patient files; an empty folder is a hard error.
    patient_files = find_patient_files(data_folder)
    total = len(patient_files)
    if total == 0:
        raise Exception('No data was provided.')

    # Make sure the output folder for the model exists.
    os.makedirs(model_folder, exist_ok=True)

    if verbose >= 1:
        print('Extracting features and labels from the Challenge data...')

    murmur_classes = ['Present', 'Unknown', 'Absent']
    outcome_classes = ['Abnormal', 'Normal']

    def one_hot(label, classes):
        # Integer indicator vector; stays all zeros for an unrecognised label.
        encoded = np.zeros(len(classes), dtype=int)
        if label in classes:
            encoded[classes.index(label)] = 1
        return encoded

    feature_rows = []
    murmur_rows = []
    outcome_rows = []

    for position, patient_file in enumerate(tqdm.tqdm(patient_files), start=1):
        if verbose >= 2:
            print('    {}/{}...'.format(position, total))

        # Patient metadata plus the associated recordings.
        patient_data = load_patient_data(patient_file)
        patient_recordings = load_recordings(data_folder, patient_data)

        feature_rows.append(get_features(patient_data, patient_recordings))
        murmur_rows.append(one_hot(get_murmur(patient_data), murmur_classes))
        outcome_rows.append(one_hot(get_outcome(patient_data), outcome_classes))

    return np.vstack(feature_rows), np.vstack(murmur_rows), np.vstack(outcome_rows)



def train_outcome_model_features(features,murmurs,outcomes, model_folder,ensemble_id):
    """Train one outcome classifier and checkpoint its best epoch.

    The data is split (unshuffled) into training and validation parts via
    ``Dataset.split``. An ``NN`` model is trained for 30 epochs with Adam on
    cross-entropy plus a term ``mean(4 * p * (1 - p))`` that penalises
    low-confidence softmax outputs. After every epoch the decision threshold on
    the class-1 ('Normal') probability is tuned on the validation part with
    ``optimize_threshold`` and the resulting ``compute_cost`` value is compared
    with the best one so far; on improvement a checkpoint is written to
    ``{model_folder}/model_outcome_{ensemble_id}``.

    Parameters
    ----------
    features, murmurs, outcomes : arrays with one row per patient, as returned
        by ``prepare_features``. The murmur labels are not used for training.
    model_folder : folder that receives the checkpoint (must already exist).
    ensemble_id : suffix that distinguishes checkpoints of ensemble members.

    Returns
    -------
    None. The checkpoint dict holds ``model`` (state_dict), ``epoch``,
    ``score`` (validation cost) and ``threshold`` (integer percent, 0-99).
    """
    train_set, valid_set = Dataset(features, murmurs, outcomes).split()

    train_loader = DataLoader(train_set, batch_size=64, shuffle=True)
    valid_loader = DataLoader(valid_set, batch_size=64, shuffle=False)
    mdl = NN()
    opt = optim.Adam(mdl.parameters())
    loss = nn.CrossEntropyLoss()
    outcome_classes = ['Abnormal', 'Normal']

    # Only checkpoints with a validation cost below this bound are saved.
    best_cost = 30000

    for epoch in range(30):
        mdl.train()
        for x, _, o in train_loader:
            # BatchNorm1d cannot compute batch statistics from a single sample
            # in training mode and raises; skip such a trailing batch.
            if x.shape[0] < 2:
                continue

            opt.zero_grad()

            x = x.float()
            o = o.long()

            y = mdl(x)
            p = torch.softmax(y,dim=-1)
            # -4*p*(p-1) == 4*p*(1-p): maximal at p = 0.5, zero at p in {0, 1}.
            J = loss(y, o) + torch.mean(-4*p*(p-1))
            J.backward()
            opt.step()

        mdl.eval()
        outputs = []
        targets = []
        # No gradients are needed for validation; avoid building autograd graphs.
        with torch.no_grad():
            for x, _, o in valid_loader:
                y = torch.softmax(mdl(x.float()), dim=-1)
                targets.append(o.long().cpu().numpy())
                outputs.append(y.cpu().numpy())

        outputs = np.concatenate(outputs, 0)
        targets = np.concatenate(targets, 0)

        # AUROC is informational only; it is undefined when the validation
        # part contains a single class, which must not abort training.
        try:
            AUROC = roc_auc_score(y_true=targets, y_score=outputs[:, 1])
        except ValueError:
            AUROC = float('nan')

        # plot=False: showing a figure every epoch can block unattended
        # training (and fails or accumulates figures on headless machines).
        _, threshld = optimize_threshold(targets, outputs[:, 1], plot=False)

        targetsHOT = np.stack([1 - targets, targets], axis=1)

        outputsHOT = outputs[:,1] > (threshld/100)
        outputsHOT = np.stack([1 - outputsHOT, outputsHOT], axis=1)

        outcome_cost = compute_cost(targetsHOT, outputsHOT, outcome_classes, outcome_classes)

        if outcome_cost < best_cost:
            print(epoch, AUROC, outcome_cost)
            print('SAVE')
            best_cost = outcome_cost
            # Store plain Python scalars so the checkpoint does not embed
            # pickled NumPy scalar objects.
            torch.save({'model': mdl.state_dict(),
                        'epoch': epoch,
                        'score': float(outcome_cost),
                        'threshold': int(threshld)},
                       f'{model_folder}/model_outcome_{ensemble_id}',
                       )

import matplotlib.pyplot as plt
def optimize_threshold(targets,outputs,plot=True):
    """Scan decision thresholds 0.00, 0.01, ..., 0.99 and pick the cheapest.

    ``targets`` holds 0/1 labels and ``outputs`` the scores for class 1
    ('Normal'); a sample is predicted as class 1 when its score exceeds the
    threshold. Each threshold is scored with ``compute_cost``.

    When ``plot`` is true the cost curve is drawn and shown with matplotlib.

    Returns ``(lowest_cost, threshold_percent)`` where ``threshold_percent`` is
    the index (0-99) of the first threshold reaching the lowest cost.
    """
    outcome_classes = ['Abnormal', 'Normal']
    # Column 0 = 'Abnormal', column 1 = 'Normal'; identical for every threshold.
    targetsHOT = np.stack([1 - targets, targets], axis=1)

    def cost_at(percent):
        predicted = outputs > percent/100
        predictedHOT = np.stack([1 - predicted, predicted], axis=1)
        return compute_cost(targetsHOT, predictedHOT, outcome_classes, outcome_classes)

    result = [cost_at(percent) for percent in range(0, 100)]

    if plot:
        plt.plot(result)
        plt.grid()
        plt.show()
    return np.min(result),np.argmin(result)

from scipy.signal import butter,filtfilt,hilbert

# 15–90 Hz; 55–150 Hz; 100–250 Hz; 200–450 Hz; 400–800 Hz

# Third-order Butterworth band-pass coefficients, one (B, A) pair per band.
# Band edges are given in Hz and divided by 2000 to form the normalised
# cut-offs passed to butter (presumably the Nyquist frequency of 4 kHz audio).
(B0, A0), (B1, A1), (B2, A2), (B3, A3), (B4, A4) = (
    butter(3, [low_hz/2000, high_hz/2000], 'bandpass')
    for low_hz, high_hz in ((15, 90), (55, 150), (100, 250), (200, 450), (400, 800))
)

def band_envelope(x,B,A):
    """Return the amplitude envelope of ``x`` within one frequency band.

    ``x`` is filtered forward and backward with the filter ``(B, A)``
    (``filtfilt``), and the magnitude of the analytic signal (``hilbert``) of
    the result is returned.
    """
    band_limited = filtfilt(B, A, x)
    analytic = hilbert(band_limited)
    return np.abs(analytic)


def power(x):
    """Return the root-mean-square value of ``x``.

    Despite the name this is the RMS amplitude, i.e. ``sqrt(mean(x**2))``.

    Integer (and boolean) input is promoted to float64 before squaring:
    squaring e.g. int16 samples in their own dtype wraps around silently and
    yields a wrong or NaN result. Floating-point input keeps its dtype.
    """
    samples = np.asarray(x)
    if samples.dtype.kind in 'iub':
        samples = samples.astype(np.float64)
    return np.sqrt(np.mean(samples**2))



# Extract features from the data.
def get_features(data, recordings):
    """Build the fixed-length float32 feature vector for one patient.

    Layout (82 values):
      [age in months, female flag, male flag, height, weight, pregnancy status,
       number of recording features (always 75), 5 locations x 15 features]

    Per-location features (rows AV, MV, PV, TV, PhC), in column order:
      0 mean, 1 standard deviation, 2 skewness, 3 sample count / 4000,
      4 RMS of the raw signal, 5-9 RMS of the five band envelopes,
      10-13 correlation of envelope pairs (0,1), (1,2), (3,4), (4,0),
      14 kurtosis.

    Rows stay zero for locations without a (non-empty) recording, and all rows
    stay zero when the number of locations differs from the number of
    recordings. If a location occurs more than once, the last recording wins.
    NaN values are replaced via ``np.nan_to_num`` before returning.
    """
    # Age group -> approximate age in months for the middle of the group;
    # an unrecognised group gives NaN.
    age_group = get_age(data)
    months_by_group = (
        ('Neonate', 0.5),
        ('Infant', 6),
        ('Child', 6 * 12),
        ('Adolescent', 15 * 12),
        ('Young Adult', 20 * 12),
    )
    age = float('nan')
    for group_name, months in months_by_group:
        if compare_strings(age_group, group_name):
            age = months
            break

    # Sex as one-hot [female, male]; anything else stays [0, 0].
    sex = get_sex(data)
    sex_features = np.zeros(2, dtype=int)
    for slot, label in enumerate(('Female', 'Male')):
        if compare_strings(sex, label):
            sex_features[slot] = 1
            break

    # Remaining demographic values are used as provided by the helpers.
    height = get_height(data)
    weight = get_weight(data)
    is_pregnant = get_pregnancy_status(data)

    # Signal features: one row per auscultation location.
    locations = get_locations(data)

    recording_locations = ['AV', 'MV', 'PV', 'TV', 'PhC']
    band_filters = ((B0, A0), (B1, A1), (B2, A2), (B3, A3), (B4, A4))
    envelope_pairs = ((0, 1), (1, 2), (3, 4), (4, 0))
    recording_features = np.zeros((len(recording_locations), 15), dtype=float)

    if len(locations) == len(recordings):
        for location, recording in zip(locations, recordings):
            for site_index, site in enumerate(recording_locations):
                if not compare_strings(location, site) or np.size(recording) == 0:
                    continue

                # View into the feature matrix; writes land in recording_features.
                row = recording_features[site_index]
                row[0] = np.mean(recording)
                row[1] = np.std(recording)
                row[2] = sp.stats.skew(recording)
                row[3] = recording.shape[0]/4000
                row[4] = power(recording)

                envelopes = [band_envelope(recording, b, a) for b, a in band_filters]
                for offset, envelope in enumerate(envelopes):
                    row[5 + offset] = power(envelope)
                for offset, (first, second) in enumerate(envelope_pairs):
                    row[10 + offset] = np.corrcoef(envelopes[first], envelopes[second])[0, 1]

                row[14] = sp.stats.kurtosis(recording)

    flat_recording_features = recording_features.flatten()

    features = np.hstack((
        [age],
        sex_features,
        [height],
        [weight],
        [is_pregnant],
        [len(flat_recording_features)],
        flat_recording_features,
    ))
    features = np.nan_to_num(features)
    return np.asarray(features, dtype=np.float32)