#!/usr/bin/env python

################################################################################
# Import libraries and functions. You can change or remove them.
################################################################################
import os
import numpy as np
#import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from sklearn.model_selection import train_test_split, KFold
from collections import defaultdict
from collections import Counter

from MyModel import *
from MyDataset import *
from model_prediction import *
from Data_Preprocess import *
from helper_code import *


# 每个epoch训练函数
def train(epoch, model, optimizer, train_loader, device, criterion):
    """Run one training epoch over ``train_loader`` and update ``model`` in place.

    Args:
        epoch: Epoch number; only used in the progress-bar description.
        model: Network to train; switched to training mode.
        optimizer: Optimizer stepped once per batch.
        train_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        device: Device each batch is moved to before the forward pass.
        criterion: Loss called as ``criterion(output, target)``.

    Returns:
        ``(mean_loss, labels, probabilities)``: the per-sample average loss,
        the concatenated targets, and the row-wise softmax (over dim 1) of the
        model outputs for every batch. Both tensors are on the CPU.
    """
    model.train()
    total_loss = 0.0
    n_entries = 0
    all_probs = []
    all_labels = []
    train_desc = "Epoch {:2d}: train - Loss: {:.6f}"
    train_bar = tqdm(initial=0, leave=True, total=len(train_loader), desc=train_desc.format(epoch, 0), position=0)
    for data, target in train_loader:
        data = data.to(device)
        target = target.to(device)
        model.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        bs = target.size(0)
        # Assumes ``criterion`` returns the batch *mean* (PyTorch's default
        # reduction) -- TODO confirm for the focal losses used here. Weighting
        # by the batch size makes total_loss / n_entries a true per-sample
        # average, including when the final batch is smaller.
        total_loss += loss.item() * bs
        n_entries += bs
        # Softmax over the class dimension (dim=1), so each row sums to 1.
        all_probs.append(torch.softmax(output.detach(), dim=1).cpu())
        all_labels.append(target.cpu())
        train_bar.desc = train_desc.format(epoch, total_loss / n_entries)
        train_bar.update(1)
    train_bar.close()

    return total_loss / n_entries, torch.cat(all_labels), torch.cat(all_probs)


# 每个epoch验证函数
def validate(epoch, model, valid_loader, device, criterion):
    """Evaluate ``model`` on ``valid_loader`` for one epoch without updating it.

    Args:
        epoch: Epoch number; only used in the progress-bar description.
        model: Network to evaluate; switched to eval mode.
        valid_loader: Iterable of ``(data, target)`` batches that supports ``len()``.
        device: Device each batch is moved to before the forward pass.
        criterion: Loss called as ``criterion(logits, target)``.

    Returns:
        ``(mean_loss, labels, probabilities)``: the per-sample average loss,
        the concatenated targets, and the element-wise sigmoid of the model
        outputs. Both tensors are on the CPU.
    """
    model.eval()
    total_loss = 0.0
    n_entries = 0
    all_probs = []
    all_labels = []
    eval_desc = "Epoch {:2d}: valid - Loss: {:.6f}"
    eval_bar = tqdm(initial=0, leave=True, total=len(valid_loader),
                    desc=eval_desc.format(epoch, 0), position=0)
    # Gradients are never needed during evaluation.
    with torch.no_grad():
        for data, target in valid_loader:
            data = data.to(device)
            target = target.to(device)
            logits = model(data)
            loss = criterion(logits, target)
            bs = data.size(0)
            # Assumes ``criterion`` returns the batch mean (PyTorch default
            # reduction) -- TODO confirm. Weight by batch size so the
            # division below yields a per-sample average.
            total_loss += loss.item() * bs
            n_entries += bs
            # NOTE(review): ``train`` reports a softmax over dim 1 while this
            # uses an element-wise sigmoid; confirm which matches the loss.
            all_probs.append(torch.sigmoid(logits).cpu())
            all_labels.append(target.cpu())
            eval_bar.desc = eval_desc.format(epoch, total_loss / n_entries)
            eval_bar.update(1)
    eval_bar.close()
    return total_loss / n_entries, torch.cat(all_labels), torch.cat(all_probs)



################################################################################
#
# Required functions. Edit these functions to add your code, but do not change the arguments.
#
################################################################################
def train_challenge_model(data_folder, model_folder, verbose):
    """Train the murmur and outcome classifiers and save them to ``model_folder``.

    Both tasks train a ``VGG_11`` network on the same normalised features.
    Each model is written as ``<name>.pt`` (whole module via ``torch.save``)
    plus ``<name>.json`` holding its class list; this is the layout
    ``load_challenge_model`` reads back.

    Args:
        data_folder: Directory containing the Challenge patient files.
        model_folder: Output directory; created if it does not exist.
        verbose: Progress messages are printed when ``verbose >= 1``.

    Raises:
        Exception: If ``data_folder`` contains no patient files.
    """
    # Find data files.
    if verbose >= 1:
        print('Finding data files...')

    # Find the patient data files.
    patient_files = find_patient_files(data_folder)
    if len(patient_files) == 0:
        raise Exception('No data was provided.')

    # Create a folder for the model if it does not already exist.
    os.makedirs(model_folder, exist_ok=True)

    # Extract the features and labels.
    if verbose >= 1:
        print('Extracting features and labels from the Challenge data...')
    totalFeatures, totalMurmurs, totalOutcomes = getDataFeatures(data_folder, patient_files, verbose=2)

    # Report how many samples carry each one-hot label. The column meaning
    # (Present/Unknown/Absent, Abnormal/Normal) is assumed from the original
    # counter names -- TODO confirm against getDataFeatures.
    murmur_rows = [row.tolist() for row in totalMurmurs]
    print('Count_types')
    print(murmur_rows.count([1, 0, 0]), murmur_rows.count([0, 1, 0]), murmur_rows.count([0, 0, 1]))
    outcome_rows = [row.tolist() for row in totalOutcomes]
    print(outcome_rows.count([1, 0]), outcome_rows.count([0, 1]))

    # Normalise the features. theMean/theStd are returned but not saved;
    # run_challenge_model normalises each recording with its own mean/std.
    uniform_features, uniform_Murmurs, uniform_Outcomes, theMean, theStd = data_uniform(totalFeatures, totalMurmurs, totalOutcomes)
    print(uniform_features.shape, uniform_Murmurs.shape, uniform_Outcomes.shape)

    # Fall back to the CPU instead of crashing when no GPU is available.
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(device)

    # ---- Murmur model ----
    murmur_model = _train_vgg_classifier(
        uniform_features,
        uniform_Murmurs,
        num_classes=len(murmur_classes),
        loss_func=FocalLoss(logits=True),
        batch_size=10,
        learning_rate=0.00015,
        num_epochs=50,
        device=device,
    )
    _save_classifier(murmur_model, model_folder, 'Murmur_model', {'Murmur_classes': murmur_classes})
    print('Murmur model Done.')

    # ---- Outcome model ----
    outcome_model = _train_vgg_classifier(
        uniform_features,
        uniform_Outcomes,
        num_classes=len(outcome_classes),
        loss_func=Focal_Loss(alpha=0.6, num_classes=2, gamma=1),
        batch_size=8,
        learning_rate=0.0002,
        num_epochs=50,
        device=device,
    )
    _save_classifier(outcome_model, model_folder, 'Outcome_model', {'Outcome_classes': outcome_classes})
    print('Outcome model Done.')


def _train_vgg_classifier(features, labels, num_classes, loss_func, batch_size,
                          learning_rate, num_epochs, device):
    """Build, initialise and train a ``VGG_11`` classifier; return the trained model."""
    dataset_train = PCG_Dataset(features, labels)
    train_loader = DataLoader(dataset_train, batch_size=batch_size, shuffle=True, num_workers=0)

    model = VGG_11(in_channel=1, num_classes=num_classes)
    # Xavier-uniform init for every parameter with more than one dimension
    # (weight matrices / kernels); 1-D parameters keep their default init.
    for param in model.parameters():
        if param.dim() > 1:
            nn.init.xavier_uniform_(param)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.98), eps=1e-9)
    # Multiply the learning rate by 0.7 at each milestone epoch.
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 25, 32, 40, 45], gamma=0.7)

    for epoch in range(num_epochs):
        train(epoch + 1, model, optimizer, train_loader, device, loss_func)
        scheduler.step()
    return model


def _save_classifier(model, model_folder, root, metadata):
    """Save ``model`` as ``<root>.pt`` and ``metadata`` as ``<root>.json`` inside ``model_folder``."""
    print('Saving model...')
    torch.save(model, os.path.join(model_folder, root + '.pt'))
    # Context manager guarantees the JSON is flushed and the handle closed.
    with open(os.path.join(model_folder, root + '.json'), 'w') as f:
        json.dump(metadata, f)

# Load your trained model. This function is *required*. You should edit this function to add your code,
# but do *not* change the arguments of this function.
def load_challenge_model(model_folder, verbose):
    """Load the murmur and outcome models written by ``train_challenge_model``.

    Args:
        model_folder: Directory containing ``Murmur_model.pt``/``.json`` and
            ``Outcome_model.pt``/``.json``.
        verbose: Unused; kept for the required Challenge signature.

    Returns:
        ``[murmur_classes, murmur_model, outcome_classes, outcome_model]``,
        the layout ``run_challenge_model`` unpacks.

    Raises:
        FileNotFoundError: If any of the four model files is missing.
    """
    # Without a GPU, remap CUDA tensors to the CPU; otherwise keep the saved
    # device (map_location=None is torch.load's default behaviour).
    map_location = None if torch.cuda.is_available() else torch.device('cpu')

    def _load(root, key):
        # NOTE: the .pt file is a whole pickled module, so only load model
        # folders you trust. On torch versions where ``weights_only`` defaults
        # to True this may need ``weights_only=False`` -- confirm for your torch.
        model = torch.load(os.path.join(model_folder, root + '.pt'), map_location=map_location)
        with open(os.path.join(model_folder, root + '.json'), 'r') as f:
            classes = json.load(f)[key]  # auxiliary info: the class list
        return classes, model

    Murmur_classes, Murmur_model = _load('Murmur_model', 'Murmur_classes')
    Outcome_classes, Outcome_model = _load('Outcome_model', 'Outcome_classes')

    return [Murmur_classes, Murmur_model, Outcome_classes, Outcome_model]


# Run your trained model. This function is *required*. You should edit this function to add your code,
# but do *not* change the arguments of this function (model, data, recordings, verbose).
# It is called once per patient.
def run_challenge_model(model, data, recordings, verbose):
    """Predict murmur and outcome labels for a single patient.

    Args:
        model: ``[murmur_classes, murmur_classifier, outcome_classes,
            outcome_classifier]`` as returned by ``load_challenge_model``.
        data: Patient data, passed to ``get_frequency``/``get_PCG_features``.
        recordings: The patient's recordings, passed to ``get_PCG_features``.
        verbose: Per-patient probabilities are printed when ``verbose >= 2``.

    Returns:
        ``(classes, labels, probabilities)``: murmur classes followed by
        outcome classes; one one-hot label vector per task, concatenated; and
        the per-class softmax probabilities, concatenated.
    """
    murmur_classes = model[0]
    murmur_classifier = model[1]
    outcome_classes = model[2]
    outcome_classifier = model[3]

    # Load and extract features.
    current_frequencies = get_frequency(data)
    single_features, demographic_feature = get_PCG_features(data, recordings, current_frequencies)
    single_recording_features = np.asarray(single_features, dtype='float32')

    # Standardise with this recording's own mean/std; a zero std yields
    # NaN/inf, which are replaced by 0.01 below.
    theMean = np.mean(single_recording_features)
    theStd = np.std(single_recording_features)
    feats_uniform = (single_recording_features - theMean) / theStd
    feats_uniform[np.isnan(feats_uniform)] = 0.01
    feats_uniform[np.isinf(feats_uniform)] = 0.01

    # Add a leading batch dimension of size 1.
    feats_tensor = torch.from_numpy(feats_uniform).unsqueeze(0)

    # Slice as many outputs as there are classes (matching how the models were
    # built in training), then softmax them into probabilities.
    murmur_probabilities_list = _classifier_probabilities(murmur_classifier, feats_tensor, len(murmur_classes))
    outcome_probabilities_list = _classifier_probabilities(outcome_classifier, feats_tensor, len(outcome_classes))

    if verbose >= 2:
        print(murmur_probabilities_list)
        print(outcome_probabilities_list)

    murmur_labels = np.zeros(len(murmur_classes), dtype=np.int_)
    murmur_labels[_murmur_label_index(murmur_probabilities_list)] = 1

    outcome_labels = np.zeros(len(outcome_classes), dtype=np.int_)
    outcome_labels[_outcome_label_index(outcome_probabilities_list)] = 1

    # Concatenate classes, labels, and probabilities.
    classes = murmur_classes + outcome_classes
    labels = np.concatenate((murmur_labels, outcome_labels))
    probabilities = np.concatenate((murmur_probabilities_list, outcome_probabilities_list))

    return classes, labels, probabilities


def _classifier_probabilities(classifier, feats, num_classes):
    """Return softmax probabilities (a list of floats) over the first ``num_classes`` outputs.

    ``feats`` is moved to the classifier's own device, and only the first
    sample of the batch is used.
    """
    classifier.eval()
    device = next(classifier.parameters()).device
    outputs = myPredict(classifier, feats.to(device)).to('cpu').detach().numpy()
    scores = np.asarray(outputs, dtype=np.float32)[:, :num_classes].tolist()[0]
    return torch.softmax(torch.tensor(scores), dim=0).tolist()


def _murmur_label_index(probabilities, absent_threshold=0.75):
    """Return the murmur class index to flag.

    Indices 0 and 1 win on plain argmax. Index 2 (presumably ``Absent`` --
    TODO confirm class order) is kept only when its probability exceeds
    ``absent_threshold``; otherwise the more likely of indices 0/1 is used.
    """
    best = int(np.argmax(probabilities))
    if best == 2 and not probabilities[2] > absent_threshold:
        return int(np.argmax(probabilities[:2]))
    return best


def _outcome_label_index(probabilities, normal_threshold=0.51):
    """Return the outcome class index to flag.

    Index 1 (presumably ``Normal`` -- TODO confirm class order) is kept only
    when its probability exceeds ``normal_threshold``; otherwise index 0 is used.
    """
    best = int(np.argmax(probabilities))
    if best == 1 and not probabilities[1] > normal_threshold:
        return 0
    return best
