#!/usr/bin/env python

# Edit this script to add your team's training code.

from helper_code import *
import numpy as np, os
from torch.optim import Adam
from torch.nn import CrossEntropyLoss
import torch.nn.functional as F
import torch
import torch.nn as nn
#import wfdb
#import neurokit2 as nk
#import pickle
#from wfdb import processing
#import math


################################################################################
#
# Network
#
################################################################################

class FFNN(nn.Module):
    """Three-hidden-layer MLP that emits a (present, absent) logit pair per label.

    :param feature_size: width of each input row.
    :param num_labels: number of labels; the last layer has ``2 * num_labels`` outputs.
    """

    def __init__(self, feature_size, num_labels):
        super(FFNN, self).__init__()
        self.fc1 = nn.Linear(feature_size, 1024)
        self.fc2 = nn.Linear(1024, 1024)
        self.fc3 = nn.Linear(1024, 1024)
        self.fc4 = nn.Linear(1024, num_labels * 2)
        self.dropout_dense = nn.Dropout(p=0.315)

    def forward(self, input):
        """Return logits of shape ``(batch * num_labels, 2)``.

        Rows are ordered sample-major: consecutive groups of ``num_labels`` rows
        belong to the same input sample.
        """
        x = F.relu(self.fc1(input))
        x = self.dropout_dense(x)
        x = F.relu(self.fc2(x))
        x = self.dropout_dense(x)
        x = F.relu(self.fc3(x))
        x = self.dropout_dense(x)
        x = self.fc4(x)

        # fc4 has 2 * num_labels outputs, so pairing consecutive columns gives one
        # (present, absent) row per label. Deriving this from the tensor itself
        # (not a module-level class count) keeps the net usable for any num_labels.
        return x.reshape(-1, 2)


class CNNBinaryOutput(nn.Module):
    """Convolutional classifier that emits a (present, absent) logit pair per class.

    The input is viewed as ``(batch, channels_input, 1, input_size)`` and passed
    through five (1 x 3) convolutions, each followed by ELU, max-pooling and
    dropout, then global average pooling and four fully connected layers.

    :param input_size: number of samples per lead in one input window.
    :param channels_input: number of input channels (ECG leads).
    :param output_size: width of the last linear layer; must be even because it
        is read as ``output_size // 2`` classes with two logits each.
    :param batch_size: stored on the instance; not used by the network itself.
    :raises ValueError: if ``output_size`` is odd.
    """

    def __init__(self, input_size, channels_input, output_size, batch_size=64):
        super(CNNBinaryOutput, self).__init__()
        if output_size % 2 != 0:
            raise ValueError(
                "output_size must be even (two logits per class), got {}".format(output_size))
        self.conv1 = nn.Conv2d(channels_input, 64, (1, 3))
        self.pool = nn.MaxPool2d((1, 2), stride=2)
        self.conv2 = nn.Conv2d(64, 64, (1, 3))
        self.conv3 = nn.Conv2d(64, 64, (1, 3))
        self.conv4 = nn.Conv2d(64, 64, (1, 3))
        self.conv5 = nn.Conv2d(64, 64, (1, 3))
        self.dropout_conv = nn.Dropout(p=0.1625)
        self.dropout_dense = nn.Dropout(p=0.315)

        # Not used in forward(); kept (it has no parameters) so attribute access keeps working.
        self.avg_pool = nn.AvgPool2d((64, 2))

        self.fc1 = nn.Linear(64, 512)
        self.fc2 = nn.Linear(512, 512)
        self.fc3 = nn.Linear(512, 128)
        self.fc4 = nn.Linear(128, output_size)

        self.input_size = input_size
        self.batch_size = batch_size
        self.input_channels = channels_input

        # Not used in forward(); kept (it has no parameters) so attribute access keeps working.
        self.softmax = nn.Softmax(dim=2)

    def forward(self, x):
        """Return logits of shape ``(batch * output_size // 2, 2)``.

        Rows are ordered sample-major: consecutive groups of ``output_size // 2``
        rows belong to the same input sample.
        """
        # Reshape input
        x = x.view(x.shape[0], self.input_channels, 1, self.input_size)

        # Conv Forward Pass
        out = self.pool(F.elu(self.conv1(x)))
        out = self.dropout_conv(out)
        out = self.pool(F.elu(self.conv2(out)))
        out = self.dropout_conv(out)
        out = self.pool(F.elu(self.conv3(out)))
        out = self.dropout_conv(out)
        out = self.pool(F.elu(self.conv4(out)))
        out = self.dropout_conv(out)
        out = self.pool(F.elu(self.conv5(out)))
        out = self.dropout_conv(out)

        out = F.adaptive_avg_pool2d(out, (1, 1))

        out = out.squeeze(dim=3).squeeze(dim=2)

        # FC Pass
        out = F.elu(self.fc1(out))
        out = self.dropout_dense(out)
        out = F.elu(self.fc2(out))
        out = self.dropout_dense(out)
        out = F.elu(self.fc3(out))
        out = self.dropout_dense(out)
        out = self.fc4(out)

        # Pair consecutive outputs into (present, absent) rows. The pair count comes
        # from the layer width (validated even in __init__), not a module-level constant.
        return out.reshape(-1, 2)


################################################################################
#
# Helper functions
#
################################################################################

"""
Helper for get_mini_batches: splits an array into consecutive chunks of at most batch_size elements, each converted to a torch tensor.
"""


def split(rand_array, batch_size):
    """Cut ``rand_array`` into consecutive slices of ``batch_size`` elements.

    The final slice holds the remainder and may be shorter. Every slice is
    returned as a ``torch.tensor`` copy, in the original order.
    """
    return [
        torch.tensor(rand_array[start:start + batch_size])
        for start in np.arange(0, rand_array.size, batch_size)
    ]


"""
Returns a list of shuffled index tensors, each holding batch_size indices (mini-batches); an incomplete last batch is dropped. Callers cast them to torch.LongTensor before indexing.
"""


def get_mini_batches(max_size, batch_size, rng=None):
    """Shuffle ``range(max_size)`` and group it into full mini-batches of indices.

    Only complete batches are returned; the leftover ``max_size % batch_size``
    indices are dropped. When ``max_size < batch_size`` (including 0), the
    result is an empty list instead of an error.

    :param max_size: number of samples to index.
    :param batch_size: number of indices per mini-batch.
    :param rng: optional ``numpy.random.Generator`` for reproducible shuffling;
        a fresh default generator is used when omitted.
    :return: list of 1-D int64 tensors of length ``batch_size``.
    """
    if rng is None:
        rng = np.random.default_rng()
    indices = np.arange(max_size)
    rng.shuffle(indices)
    # Build only complete batches, so there is never a partial tail to inspect or drop.
    num_full_batches = max_size // batch_size
    return [
        torch.tensor(indices[k * batch_size:(k + 1) * batch_size])
        for k in range(num_full_batches)
    ]


"""
Preprocess recordings: downsample them to roughly 100 Hz and slice them into N-second chunks.
Use in conjunction with build_ts_data (the output of this function is its `records` argument).
"""


def preprocess_recordings(recordings, frequencies, verbose=True):
    """Decimate each recording to about 100 Hz and cut it into 3-second windows.

    For a recording sampled at ``f`` Hz, every ``int(f / 100)``-th sample is kept.
    Decimation starts from each of the first (at most 5) phase offsets, giving
    up to 5 shifted copies of the signal. Each copy is cut into consecutive,
    non-overlapping 300-sample windows, at most 40 per copy.

    Recordings shorter than 3 s, or sampled below 100 Hz (stride 0), yield no
    windows at all.

    :param recordings: sequence of 2-D arrays indexed as ``[lead, sample]``.
    :param frequencies: sampling frequency (Hz) of each recording, same order.
    :param verbose: print progress when True.
    :return: list with one ``np.ndarray`` per recording. Its windows are stacked
        copy-major as ``(num_copies * num_windows, num_leads, 300)``; the array is
        a 1-D empty array when there are no windows.
    """
    chunk_seconds = 3         # length of one window in seconds
    target_hz = 100           # nominal rate after decimation
    max_windows_per_copy = 40
    max_copies = 5
    window = chunk_seconds * target_hz

    if verbose: print("Norming Data...")
    output = []
    for position, (record, frequency) in enumerate(zip(recordings, frequencies), start=1):
        if verbose:
            print('    {}/{}...'.format(position, len(recordings)))

        duration = int(record.shape[1] / frequency)
        stride = int(frequency / target_hz)
        num_windows = min(int(duration / chunk_seconds), max_windows_per_copy)

        decimated_copies = [record[:, offset::stride] for offset in range(min(stride, max_copies))]
        windows = [
            decimated[:, k * window:(k + 1) * window]
            for decimated in decimated_copies
            for k in range(num_windows)
        ]
        output.append(np.array(windows))
    return output


"""
Flattens the per-recording lists of chunks into single tensors, repeating each recording's label and aux data once per chunk.
@:param records a list of lists of chunks (length = len(input_labels) = len(auxiliary_data))
@:return finished preprocessed tensors for labels, records and aux data
"""


def build_ts_data(records, input_labels, auxiliary_data, verbose=True):
    """Flatten per-recording chunk collections into one batch of tensors.

    Each chunk of ``records[i]`` becomes one row of the output. It is paired with
    a copy of ``input_labels[i]`` and ``auxiliary_data[i]``, so all three outputs
    have the same number of rows.

    :return: ``(chunks, labels, aux)`` as tensors built with ``torch.from_numpy``.
    """
    if verbose: print("Rebuild tensors...")
    series, label_rows, aux_rows = [], [], []
    for position, (chunks, label, aux) in enumerate(zip(records, input_labels, auxiliary_data), start=1):
        if verbose:
            print('    {}/{}...'.format(position, len(records)))
        for chunk in chunks:
            series.append(chunk)
            label_rows.append(label)
            aux_rows.append(aux)

    return tuple(torch.from_numpy(np.array(rows)) for rows in (series, label_rows, aux_rows))


def get_age_and_sex(header):
    """Return the ``(age, sex)`` pair read from a recording header.

    Age is whatever ``get_age`` reports, or NaN when it reports ``None``.
    Sex is encoded as 0 for female, 1 for male, and NaN for anything else.
    """
    reported_age = get_age(header)
    age = float('nan') if reported_age is None else reported_age

    reported_sex = get_sex(header)
    female_tokens = ('Female', 'female', 'F', 'f')
    male_tokens = ('Male', 'male', 'M', 'm')
    if reported_sex in female_tokens:
        sex = 0
    elif reported_sex in male_tokens:
        sex = 1
    else:
        sex = float('nan')

    return age, sex


def get_binary_labels(labels, num_labels=None):
    """Expand multi-hot class labels into (present, absent) one-hot pairs.

    Column ``2 * j`` of the result is 1 where class ``j`` is present (any
    non-zero entry in ``labels[:, j]``), and column ``2 * j + 1`` is 1 where it
    is absent. Classes beyond the width of ``labels`` are marked absent.

    :param labels: 2-D tensor or array of shape ``(num_samples, num_classes)``.
    :param num_labels: number of classes to encode; defaults to ``len(relevant_classes)``.
    :return: float tensor of shape ``(num_samples, 2 * num_labels)``.
    :raises IndexError: if a non-zero label lies in a column ``>= num_labels``.
    """
    if num_labels is None:
        num_labels = len(relevant_classes)

    labels = torch.as_tensor(labels)
    num_samples = labels.shape[0]
    present = torch.zeros(num_samples, num_labels, dtype=torch.bool)

    if labels.numel() > 0:
        flags = labels != 0
        width = flags.shape[1]
        if width > num_labels:
            # Extra all-zero columns are harmless; a set label there has no output slot.
            if flags[:, num_labels:].any():
                raise IndexError(
                    "label column out of range for {} classes".format(num_labels))
            flags = flags[:, :num_labels]
            width = num_labels
        present[:, :width] = flags

    # Vectorised fill instead of per-element writes: even columns = present, odd = absent.
    summarized_labels = torch.zeros(num_samples, num_labels * 2)
    summarized_labels[:, 0::2] = present.float()
    summarized_labels[:, 1::2] = (~present).float()
    return summarized_labels


# Diagnosis codes scored by this model (presumably SNOMED CT, as returned by
# get_labels). A code's position in this list is its class index in the label
# tensors and in the network outputs.
relevant_classes = [
    "270492004", "164889003", "164890007", "426627000", "713427006",
    "713426002", "445118002", "39732003", "164909002", "251146004",
    "698252002", "10370003", "284470004", "427172004", "164947007",
    "111975006", "164917005", "47665007", "59118001", "427393009",
    "426177001", "426783006", "427084000", "63593006", "164934002",
    "59931005", "17338001",
]

# Lead names used by each supported lead configuration, keyed by lead count.
leads_dict = {
    12: ('I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'),
    6: ('I', 'II', 'III', 'aVR', 'aVL', 'aVF'),
    4: ('I', 'II', 'III', 'V2'),
    3: ('I', 'II', 'V2'),
    2: ('I', 'II'),
}
# Lead counts in training order, and the matching lead tuples.
lead_ints = [12, 6, 4, 3, 2]
lead_sets = [leads_dict[count] for count in lead_ints]
twelve_leads = leads_dict[12]

num_classes = len(relevant_classes)
device = "cuda" if torch.cuda.is_available() else "cpu"
ann_dir = "annotations"
# Input width of the (disabled) stack network: num_classes CNN outputs plus the
# four scalar features assembled in the disabled get_stack_input.
feature_dims_stack = 31

"""
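# Heuristic Functions -- DISABLED: everything from the opening triple quote above
# down to the closing triple quote before the "Training function" banner is a
# string literal and never runs. Re-enabling it requires the commented-out imports
# at the top of the file (wfdb, neurokit2, pickle, wfdb.processing, math). It also
# requires a `run_cnn_alone` function, which is not defined in the visible code
# (presumably the current run_model).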

def get_lqt(header):
    patient = header.split()[0]
    freq = get_frequency(header)
    try:
        with open(ann_dir+"/" + patient + "I" + "wfdb.ann", "rb") as f:
            r_peaks = pickle.load(f)
            r_peaks = r_peaks.qrs_inds
            f.close()

        with open(ann_dir+"/" + patient + "I" + "nk2.ann", "rb") as f:
            waves_peak = pickle.load(f)
            t_peaks = waves_peak['ECG_T_Peaks']
            p_peaks = waves_peak['ECG_P_Peaks']
            q_peaks = waves_peak['ECG_Q_Peaks']
            s_peaks = waves_peak['ECG_S_Peaks']

            f.close()
    except:
        return 0

    # q -> r -> t (max. dist q,t)
    valid_distances = []
    for q in q_peaks:
        if math.isnan(q): continue
        try:
            t = min([t for t in t_peaks if t > q])
            number_r = len([1 for r in r_peaks if t > r > q])
            if number_r == 1:
                valid_distances.append(t - q)
        except:
            print("No t found.")

    try:
        # 350< qtzeit < 550
        qtzeit = ((max(valid_distances)) / freq) * 1000

        rr = (np.mean([(r_peaks[i + 1] - r_peaks[i]) for i in range(len(r_peaks) - 1)]) / freq)

        if 0.6 < rr < 1:
            qtc = qtzeit / math.sqrt(rr)
        else:
            qtc = qtzeit / (rr ** (1 / 3))

        if qtc > 440:
            return 1
        else:
            return 0

    except ValueError:
        print("No valid dist found")
        return 0


def get_lqrsv(header, recording):
    leads = get_leads(header)

    patient = header.split()[0]
    freq = get_frequency(header)
    try:
        with open(ann_dir+"/" + patient + "I" + "wfdb.ann", "rb") as f:
            r_peaks = pickle.load(f)
            r_peaks = r_peaks.qrs_inds
            f.close()

        with open(ann_dir+"/" + patient + "I" + "nk2.ann", "rb") as f:
            waves_peak = pickle.load(f)
            s_peaks = waves_peak['ECG_S_Peaks']

            f.close()

        ans = []
        # r, s
        for lead_index, lead_name in zip(range(len(recording.shape[0])), leads):
            ecg_signal = recording[lead_index, :]
            max_dist = int((1 / 20) * freq)
            y_axis_distances = []
            for r in r_peaks:
                close_s = [s for s in s_peaks if 0 < s - r < max_dist]
                if len(close_s) < 1: continue
                y_axis_distances.append(ecg_signal[r] - ecg_signal[close_s[0]])
            if len(y_axis_distances) > 0:
                mvqrs = max(y_axis_distances)
            else:
                mvqrs = 50000
            treshold = 500 if twelve_leads.index(lead_name) <= 5 else 1000
            if mvqrs < treshold:
                ans.append(1)
            else:
                ans.append(0)

        return sum(ans) == len(ans)

    except:
        return 0


####
# Function that is responsible to generate input for the stack network
####

def get_stack_labels(batch, header_files):
    labels = np.zeros((len(batch), num_classes), dtype=np.float32)
    for i, file_index in enumerate(batch):

        header = load_header(header_files[file_index])

        current_labels = get_labels(header)
        for label in current_labels:
            if label in relevant_classes:
                j = relevant_classes.index(label)
                labels[i, j] = 1
    labels = torch.tensor(labels)
    return get_binary_labels(labels)


def get_stack_input(model, batch, recording_files, header_files, loaded=False, lead_counter=12):
    data = torch.zeros([len(batch), feature_dims_stack])

    for index, file_index in enumerate(batch):
        if not loaded:
            recording = load_recording(recording_file=recording_files[file_index])
            header = load_header(header_files[file_index])

        else:
            header = header_files[0]
            recording = recording_files[0]

        if lead_counter != 12:
            lead_indices = [twelve_leads.index(lead) for lead in leads_dict[lead_counter]]
            recording = recording[lead_indices, :]
        cnn_out = run_cnn_alone(model, header, recording)
        cnn_out = cnn_out[1]
        age, sex = get_age_and_sex(header)
        if math.isnan(age): age = 60
        if math.isnan(sex):sex = 2
        age = age/100
        lqt = get_lqt(header)
        lqrsv = get_lqrsv(header, recording)
        ans = torch.vstack([torch.tensor(cnn_out).unsqueeze(dim=1), torch.tensor(lqt), torch.tensor(lqrsv),
                            torch.tensor(sex),torch.tensor(age)
                            ]) #
        #ans = torch.tensor(cnn_out)
        #ans = torch.vstack([torch.tensor(cnn_out).unsqueeze(dim=1), torch.tensor(lqt), torch.tensor(lqrsv),])
        data[index] = ans.T
    data = data.to(device)
    return data


def create_annotations(data_directory):
    print("creating annotations")
    header_files, recording_files = find_challenge_files(data_directory)
    output_directory = ann_dir
    if not os.path.isdir(output_directory):
        os.mkdir(output_directory)

    for lead_index, lead_name in zip([0, 1], ["I", "II"]):
        print("Curr. lead: " + lead_name)
        counter = 0
        for header_file, recording_file in zip(header_files, recording_files):
            try:
                counter += 1
                print(counter)
                header = load_header(header_file)
                ecg_signal = load_recording(recording_file)
                ecg_signal = ecg_signal[lead_index, :]
                ecg_signal = np.concatenate([np.zeros(500), ecg_signal])
                print("Starting nk...")
                # Extract R-peaks locations
                _, rpeaks = nk.ecg_peaks(ecg_signal, sampling_rate=get_frequency(header))

                # # Delineate the ECG signal
                _, waves_peak = nk.ecg_delineate(ecg_signal, rpeaks, sampling_rate=get_frequency(header))
                #
                print("Starting wfdb...")
                sig, fields = wfdb.rdsamp(recording_file[:-4], channels=[lead_index])
                sig = np.concatenate([np.zeros([500]), sig.squeeze()])
                xqrs = processing.XQRS(sig=sig, fs=fields['fs'])
                xqrs.detect(verbose=False)

                file = header.split()[0]

                with open(output_directory + "/" + file + lead_name + "wfdb" + ".ann", "wb") as f:
                    pickle.dump(xqrs, f)
                    f.close()

                with open(output_directory + "/" + file + lead_name + "nk2" + ".ann", "wb") as f:
                    pickle.dump(waves_peak, f)
                    f.close()
            except:
                pass
                print("Non processable patient found.")

"""
################################################################################
#
# Training function
#
################################################################################


# Train your model. This function is **required**. Do **not** change the arguments of this function.
def training_code(data_directory, model_directory):
    """Train and save one CNNBinaryOutput model per lead configuration in ``lead_sets``.

    Every recording found in ``data_directory`` is decimated and cut into
    fixed-length chunks (see preprocess_recordings), and each chunk inherits its
    recording's labels. For every lead set, a model is trained on the matching
    lead channels. Its ``state_dict`` is written to
    ``<model_directory>/<num_leads>_lead_ecg_model.pt``.

    :param data_directory: directory containing the challenge header/recording files.
    :param model_directory: output directory for the trained weights; created if missing.
    :raises ValueError: if the dataset yields fewer chunks than one mini-batch.
    """
    # create_annotations(data_directory)  # only needed by the disabled heuristic features

    header_files, recording_files = find_challenge_files(data_directory)

    # Extract features and labels from dataset.
    print('Extracting features and labels...')
    recordings, frequencies, labels, aux_data = _load_training_set(header_files, recording_files)

    normed_recordings = preprocess_recordings(recordings, frequencies)
    data, labels, aux_data = build_ts_data(normed_recordings, labels, aux_data)

    # data is (num_chunks, num_leads, time_series_length); labels become
    # (num_chunks, 2 * num_classes) present/absent pairs.
    labels = get_binary_labels(labels)

    batch_size = 128
    # Without a full batch every epoch would be empty; fail loudly instead of
    # dividing by zero (or indexing a missing dimension on an empty dataset).
    if len(data) < batch_size:
        raise ValueError(
            "Need at least {} training chunks for one mini-batch, got {}".format(batch_size, len(data)))

    for leads in lead_sets:
        net = _train_lead_model(data, labels, leads, batch_size)
        os.makedirs(model_directory, exist_ok=True)
        torch.save(net.state_dict(), model_directory + "/" + str(len(leads)) + "_lead_ecg_model.pt")


def _load_training_set(header_files, recording_files):
    """Read every recording with its sampling frequency, multi-hot labels and (age, sex).

    :return: ``(recordings, frequencies, labels, aux_data)``.
        ``labels`` is a float32 multi-hot array of shape
        ``(num_recordings, num_classes)`` over ``relevant_classes``.
        ``aux_data`` is a float32 ``(num_recordings, 2)`` array of age and sex
        (NaN when unknown).
    """
    num_recordings = len(recording_files)
    class_index = {code: j for j, code in enumerate(relevant_classes)}
    labels = np.zeros((num_recordings, num_classes), dtype=np.float32)
    aux_data = np.zeros((num_recordings, 2), dtype=np.float32)
    recordings = []
    frequencies = []

    for i, recording_file in enumerate(recording_files):
        print('    {}/{}...'.format(i + 1, num_recordings))

        header = load_header(header_files[i])
        recordings.append(load_recording(recording_file))
        frequencies.append(get_frequency(header))
        aux_data[i] = get_age_and_sex(header)

        # Labels outside relevant_classes are ignored.
        for label in get_labels(header):
            if label in class_index:
                labels[i, class_index[label]] = 1

    return recordings, frequencies, labels, aux_data


def _train_lead_model(data, labels, leads, batch_size=128, num_epochs=55, learning_rate=0.0003):
    """Train a fresh CNNBinaryOutput on the channels named in ``leads`` and return it.

    ``data`` is indexed as ``[chunk, lead, sample]`` with leads in ``twelve_leads``
    order. ``labels`` holds the matching present/absent pairs from get_binary_labels.
    """
    print("Starting training for the " + str(len(leads)) + " lead model...")
    lead_indices = [twelve_leads.index(lead) for lead in leads]

    net = CNNBinaryOutput(input_size=data.shape[2], channels_input=len(leads), output_size=labels.shape[1])
    net.to(device)
    optim = Adam(net.parameters(), lr=learning_rate)
    criterion = CrossEntropyLoss()

    for epoch in range(num_epochs):
        mini_batches = get_mini_batches(len(data), batch_size)
        acc_loss = 0

        for batch in mini_batches:
            batch = batch.type(torch.LongTensor)
            batch_data = data[batch][:, lead_indices, :].type(torch.FloatTensor).to(device)
            output = net(batch_data)

            # One (present, absent) row per sample and class; argmax 0 means "present".
            targets = labels[batch].reshape(-1, 2).argmax(dim=1).to(device)

            loss = criterion(output, targets)
            acc_loss += loss.item()
            optim.zero_grad()
            loss.backward()
            optim.step()

        if mini_batches:
            print("Acc. Loss over Epoch {}: {}".format(epoch, acc_loss / len(mini_batches)))

    return net


################################################################################
#
# File I/O functions
#
################################################################################

def _load_cnn_model(model_directory, num_leads):
    """Rebuild a CNNBinaryOutput for ``num_leads`` leads and load its trained weights.

    Reads ``<model_directory>/<num_leads>_lead_ecg_model.pt``, the file name
    written by training_code. ``map_location=device`` lets weights that were
    saved from a GPU be loaded on a CPU-only machine. The model is returned in
    eval mode on ``device``.
    """
    # 300 samples per lead = one 3 s window at 100 Hz; two logits per relevant class.
    model = CNNBinaryOutput(300, num_leads, num_classes * 2)
    path = model_directory + "/" + str(num_leads) + "_lead_ecg_model.pt"
    model.load_state_dict(torch.load(path, map_location=device))
    model.eval()
    model.to(device)
    return model


def load_model(model_directory, leads):
    """Load the trained model matching the number of leads in ``leads``.

    :raises ValueError: if no model exists for that lead count (only 12, 6, 4, 3
        and 2 leads are supported).
    """
    num_leads = len(leads)
    if num_leads == 12:
        return load_twelve_lead_model(model_directory)
    if num_leads == 6:
        return load_six_lead_model(model_directory)
    if num_leads == 4:
        return load_four_lead_model(model_directory)
    if num_leads == 3:
        return load_three_lead_model(model_directory)
    if num_leads == 2:
        return load_two_lead_model(model_directory)
    raise ValueError(
        "Unsupported number of leads: {} (expected one of 12, 6, 4, 3, 2)".format(num_leads))


# Load your trained 12-lead ECG model. This function is **required**. Do **not** change the arguments of this function.
def load_twelve_lead_model(model_directory):
    return _load_cnn_model(model_directory, 12)


# Load your trained 6-lead ECG model. This function is **required**. Do **not** change the arguments of this function.
def load_six_lead_model(model_directory):
    return _load_cnn_model(model_directory, 6)


# Load your trained 3-lead ECG model. This function is **required**. Do **not** change the arguments of this function.
def load_three_lead_model(model_directory):
    return _load_cnn_model(model_directory, 3)


# Load your trained 2-lead ECG model. This function is **required**. Do **not** change the arguments of this function.
def load_two_lead_model(model_directory):
    return _load_cnn_model(model_directory, 2)


# Load your trained 4-lead ECG model.
def load_four_lead_model(model_directory):
    return _load_cnn_model(model_directory, 4)


################################################################################
#
# Running trained model functions
#
################################################################################

"""
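# Generic function for running a trained model with the corresponding stack model.
# DISABLED: this whole variant is a string literal and never runs; the active
# run_model below uses the CNN alone.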
def run_model(model, header, recording): #run_model simple version: run_model_real
    num_leads = len(get_leads(header))
    data = get_stack_input(model, [0], [recording], [header], loaded=True)
    stack_model = FFNN(feature_dims_stack,27)
    stack_model.load_state_dict(torch.load("model/" + str(num_leads) + "_lead_stack_model.pt"))
    stack_model.eval()
    stack_model.to(device)
    pred = stack_model(data)

    #Check if data flows correctly through the stack_input
    #pred = data
    #return relevant_classes, [int(pre) for pre in pred[0].tolist()], [int(pre) for pre in pred[0].tolist()]
    #pred = pred.reshape(len(relevant_classes), 2)

    labels = [1 - x for x in pred.argmax(dim=1).tolist()]
    if any([False if label == net_pred else True for label,net_pred in zip(list(labels),data[0].tolist())]):
        print("Potential problem")
        pass
    preds = [F.softmax(row, dim=0)[0].item() for row in pred]

    # Lower clip
    predictions = []
    for p in preds:
        if p >= 0.1:
            predictions.append(p)
        else:
            predictions.append(0)

    return relevant_classes, labels, predictions
"""

def run_model(model, header, recording):  # CNN-only variant (the stack-model variant above is disabled)
    """Predict the relevant classes for one recording with a trained CNN.

    The recording is decimated and cut into chunks exactly as during training
    (preprocess_recordings / build_ts_data). The model's logits are averaged
    over all chunks and turned into per-class present/absent decisions.

    :param model: trained model mapping ``(num_chunks, leads, 300)`` input to
        ``(num_chunks * num_classes, 2)`` logits; assumed to be on ``device``
        and in eval mode.
    :param header: header text of the recording (used for its sampling frequency).
    :param recording: array indexed as ``[lead, sample]``.
    :return: ``(classes, labels, probabilities)``.
        ``labels[i]`` is 1 when class ``i`` is predicted present.
        ``probabilities[i]`` is its softmax probability, set to 0 below 0.1.
        If the recording yields no chunk (shorter than 3 s, or sampled below
        100 Hz), all labels and probabilities are 0.
    """
    frequency = get_frequency(header)
    normed_recording = preprocess_recordings([recording], [frequency], verbose=False)
    # Labels/aux data are irrelevant at inference; numeric placeholders keep build_ts_data happy.
    data, _, _ = build_ts_data(normed_recording, [0], [0], verbose=False)

    if data.shape[0] == 0:
        # Averaging over zero chunks would produce NaN logits; report "nothing detected" instead.
        return relevant_classes, [0] * num_classes, [0] * num_classes

    data = data.type(torch.FloatTensor).to(device)
    with torch.no_grad():
        logits = model(data)

    # (num_chunks * num_classes, 2) -> per-class logit pair averaged over chunks.
    logits = logits.reshape(data.shape[0], num_classes, 2).mean(dim=0)

    # Index 0 is the "present" neuron, so argmax 0 means label 1.
    labels = [1 - x for x in logits.argmax(dim=1).tolist()]
    probabilities = F.softmax(logits, dim=1)[:, 0].tolist()

    # Lower clip
    predictions = [p if p >= 0.1 else 0 for p in probabilities]

    return relevant_classes, labels, predictions

