#!/usr/bin/env python

import numpy as np, os
np.set_printoptions(linewidth=np.inf)
from sklearn.metrics import average_precision_score
from model import *
from dataset import dataset
from torch.utils.data import DataLoader
from tqdm import tqdm
import torch.optim as optim
import copy
from datetime import datetime
from evaluate_12ECG_score import compute_challenge_metric,load_weights
from scipy.optimize import differential_evolution
def _challenge_metric(labels, probs_outputs):
    """Grid-search a single global decision threshold for the challenge metric.

    Every class is binarised with the same cut-off ``probs_outputs >= cut`` for
    cut in 0.01, 0.02, ..., 0.98, and each binarisation is scored with
    ``compute_challenge_metric`` using the weights in ``weights.csv``.

    Returns:
        A ``(score, threshold)`` tuple for the best cut-off. On ties the
        smallest threshold wins, because ``max`` keeps the first maximum.
    """
    normal_class = '426783006'
    weights = load_weights('weights.csv', dataset.CLASSES)

    candidate_cuts = [step / 100 for step in range(1, 99)]
    scored = [
        (
            compute_challenge_metric(weights, labels, probs_outputs >= cut,
                                     dataset.CLASSES, normal_class),
            cut,
        )
        for cut in candidate_cuts
    ]
    return max(scored, key=lambda pair: pair[0])


class opt:
    """Objective for ``differential_evolution`` over per-class thresholds.

    Calling an instance with a threshold vector ``x`` binarises the stored
    probabilities as ``probs_outputs >= x`` and returns the *negated*
    challenge metric, so minimising the objective maximises the score.
    """

    def __init__(self, probs_outputs, labels):
        self.probs_outputs = probs_outputs
        self.labels = labels
        self.normal_class = '426783006'
        self.weights = load_weights('weights.csv', dataset.CLASSES)

    def __call__(self, x):
        predicted = self.probs_outputs >= x
        score = compute_challenge_metric(self.weights, self.labels, predicted,
                                         dataset.CLASSES, self.normal_class)
        return -score

def _challenge_metric_GENopt(labels, probs_outputs, threshold, popsize=24):
    """Refine per-class decision thresholds with differential evolution.

    The initial population is seeded from a known-good threshold:
    row 0 is ``threshold`` itself, the next (up to) 11 rows are ``threshold``
    plus Gaussian noise (sigma 0.1), and the remaining rows are uniform in
    [0, 1]. Every value is clipped into [0, 1].

    Args:
        labels: ground-truth label matrix passed through to ``opt``.
        probs_outputs: predicted probabilities; its second dimension gives the
            number of thresholds to optimise (presumably (n_samples, n_classes)).
        threshold: scalar or per-class starting threshold(s).
        popsize: population multiplier; the population holds
            ``popsize * n_classes`` members.

    Returns:
        The ``scipy.optimize.OptimizeResult``. ``result.x`` holds the best
        thresholds and ``result.fun`` the negated challenge score.
    """
    n_classes = probs_outputs.shape[1]
    n_members = popsize * n_classes

    x0 = np.random.rand(n_members, n_classes)
    x0[0, :] = threshold
    n_jitter = min(11, n_members - 1)
    x0[1:1 + n_jitter] = threshold + 0.1 * np.random.randn(n_jitter, n_classes)
    x0 = np.clip(x0, 0, 1)

    result = differential_evolution(
        opt(probs_outputs, labels),
        [(0, 1)] * n_classes,
        workers=-1,
        # Parallel evaluation (workers != 1) requires deferred updating.
        # Say so explicitly instead of letting scipy override it with a warning.
        updating='deferred',
        disp=True,
        popsize=popsize,
        init=x0,
    )
    print(result.x, result.fun)
    return result

class ModelSaver:
    """Track the best model seen so far (by ``metrics``) and write it to disk.

    ``append`` keeps a CPU deep copy of the model whenever its metric is
    strictly higher than the best so far. ``dump`` saves that copy, together
    with its metadata, via ``torch.save``.
    """

    def __init__(self, model_directory):
        self.model = None
        # Start at -inf (not 0) so the first appended model is always kept,
        # even when the score is <= 0. Otherwise dump() could have nothing to save.
        self.metrics = float("-inf")
        self.epoch = 0
        self.classes = None
        self.auprc = None
        self.threshold = None
        self.model_directory = model_directory
        # Filenames are built by string concatenation, so ensure a trailing "/".
        # An empty string means "current directory".
        if self.model_directory and not self.model_directory.endswith("/"):
            self.model_directory += "/"
        if self.model_directory:
            os.makedirs(self.model_directory, exist_ok=True)

    def append(self, model, metrics, epoch, classes, auprc, threshold, dump=False):
        """Keep *model* if *metrics* beats the current best; optionally dump it."""
        if not metrics > self.metrics:
            return
        print("updating best model")
        self.model = copy.deepcopy(model).cpu()
        self.epoch = epoch
        self.classes = classes
        self.metrics = metrics
        self.auprc = auprc
        self.threshold = threshold
        if dump:
            self.dump()

    def dump(self, custom_name=True):
        """Save the best model to ``<dir>/model_<epoch:03d>`` or ``<dir>/model``.

        Raises:
            RuntimeError: if no model has been accepted by ``append`` yet.
        """
        if self.model is None:
            raise RuntimeError(
                "ModelSaver.dump() called before any model was accepted by append()")

        if custom_name:
            modelname = self.model_directory + "model_{}".format(str(self.epoch).zfill(3))
        else:
            modelname = self.model_directory + "model"

        torch.save({"state_dict": self.model.state_dict(),
                    "classes": self.classes,
                    "thresholds": self.threshold,
                    "metrics": self.metrics,
                    "auprc": self.auprc,
                    "epoch": self.epoch,
                    "date": datetime.now().strftime("%m-%d-%Y %H:%M:%S")},
                   modelname)




def load_input_data(input_directory):
    """List the header files found directly inside *input_directory*.

    A directory entry is kept when its lower-cased name does not start with
    '.', ends with 'hea', and is a regular file (sub-directories are skipped).
    Paths are returned joined onto *input_directory*, in ``os.listdir`` order.
    """
    # This part of code was updated for our purpose
    print('Loading data from {}'.format(input_directory))

    def _looks_like_header(name):
        lowered = name.lower()
        return not lowered.startswith('.') and lowered.endswith('hea')

    return [
        path
        for path in (os.path.join(input_directory, name)
                     for name in os.listdir(input_directory)
                     if _looks_like_header(name))
        if os.path.isfile(path)
    ]


def my_collate(batch):
    """Stack a list of ``(signal, label)`` items into padded torch tensors.

    The channel count is read from ``signal.shape[0]`` of the first item and the
    length from ``signal.shape[-1]``. Every signal is right-aligned (zeros are
    placed at the start) in a float64 array of shape
    ``(len(batch), channels, L)``, where ``L = max(30000, longest signal)``.
    The 30000 is written as 250*60*2, presumably two minutes at 250 Hz — confirm.

    Returns:
        ``(X, t)``: the padded signals and the stacked labels, both as tensors
        created with ``torch.from_numpy``.
    """
    signals = [item[0] for item in batch]
    n_channels = signals[0].shape[0]
    width = max(250 * 60 * 2, max(sig.shape[-1] for sig in signals))

    padded = np.zeros((len(batch), n_channels, width))
    for row, sig in zip(padded, signals):
        row[:, -sig.shape[-1]:] = sig

    labels = np.array([item[1] for item in batch])
    return torch.from_numpy(padded), torch.from_numpy(labels)


def _train_one_epoch(model, loader, loss_fn, optimizer, device):
    """Run one optimisation pass over *loader* with gradient-norm clipping at 1."""
    model.train()
    for x, t in tqdm(loader):
        x = x.unsqueeze(2).to(device).float()
        # Cast the target to float32 so its dtype presumably matches the float32
        # logits. This is a no-op if the dataset already yields float32 labels.
        t = t.to(device).float()
        optimizer.zero_grad()
        loss = loss_fn(input=model(x), target=t)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1)
        optimizer.step()


def _predict_probs(model, loader, device):
    """Return ``(sigmoid probabilities, targets)`` concatenated over *loader*."""
    model.eval()
    outputs = []
    targets = []
    # Inference only: without no_grad every batch would build an autograd graph.
    with torch.no_grad():
        for x, t in tqdm(loader):
            x = x.unsqueeze(2).to(device).float()
            outputs.append(torch.sigmoid(model(x)).cpu().numpy())
            targets.append(t.numpy())
    return np.concatenate(outputs, axis=0), np.concatenate(targets, axis=0)


def train_12ECG_classifier(input_directory, output_directory):
    """Train the ECG classifier on the header files in *input_directory*.

    Each of the 75 epochs does the following:
      * trains ``NN`` with BCE-with-logits loss and Adam;
      * evaluates on the validation split (per-class AUPRC, plus the
        challenge metric under the best global threshold);
      * steps ReduceLROnPlateau on the challenge score;
      * offers the model to ``ModelSaver``.

    The best model is written to ``<output_directory>/model`` at the end.
    """
    header_files = load_input_data(input_directory)

    # dataset
    train_set, valid_set, cweights = dataset.train_test_split(header_files, test_eval=False)

    n_workers = 8
    train_loader = DataLoader(dataset=train_set,
                              batch_size=64,
                              shuffle=True,
                              num_workers=n_workers,
                              collate_fn=my_collate)
    # Metrics are computed over the whole validation set, so shuffling it buys
    # nothing and only makes per-epoch results less reproducible.
    valid_loader = DataLoader(dataset=valid_set,
                              batch_size=64,
                              shuffle=False,
                              num_workers=n_workers,
                              collate_fn=my_collate)

    # model
    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    model = NN(device=device, nOUT=24)
    saver = ModelSaver(output_directory)
    loss_fn = nn.BCEWithLogitsLoss()
    # Fail fast if weights.csv is missing, rather than after a full training epoch.
    load_weights('weights.csv', dataset.CLASSES)
    # Named `optimizer` so it does not shadow the module-level `opt` class.
    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=1e-6)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'max', factor=0.5, patience=7,
                                                     threshold=0.02, verbose=True)

    for epoch in range(75):
        _train_one_epoch(model, train_loader, loss_fn, optimizer, device)
        outputs, targets = _predict_probs(model, valid_loader, device)

        # metrics
        auprc = average_precision_score(y_true=targets, y_score=outputs, average=None)
        auprc_mean = np.nanmean(auprc)
        auprc = np.around(auprc, 2)

        score, threshold = _challenge_metric(labels=targets, probs_outputs=outputs)
        print(epoch, "\tAUPRC:", auprc_mean, "\tCHALLENGE:", score, "\tTHRESHOLD:", threshold)
        print(auprc)
        print(cweights)

        # decrease lr when the challenge score plateaus
        scheduler.step(score)

        # keep the best model in memory
        saver.append(model, score, epoch, dataset.CLASSES, auprc, threshold, dump=False)
    saver.dump(custom_name=False)

