from __future__ import unicode_literals, print_function, division
from io import open
import os
from scipy.io import loadmat
import numpy as np
import torch
from torch.utils import data
from torch.utils.data import DataLoader
import torch.nn as nn
from sklearn import preprocessing

from Scoring_file import load_weights
from metrics import f2_loss

# Lead count and the matching index range (0 .. n_leads - 1).
n_leads = 12
list_lead = range(n_leads)

# Presumably the fraction of the data assigned to training -- confirm against
# the code that performs the split.
train_test_ratio = 0.8

# Training-loop settings; how each one is consumed is not visible here.
epochs_number = 26
print_every = 1
clip_grad = 1

# Run on the GPU whenever CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# 24 per-class weights (float32, created on the CPU). They are handed to
# nn.BCEWithLogitsLoss as `pos_weight` where `third_criterion` is built.
loss_weights = torch.tensor([
    2.0, 18.0, 18.0, 3.0, 4.0, 1.0, 4.0, 6.0,
    16.0, 10.0, 4.0, 6.0, 3.0, 19.0, 10.0, 6.0,
    13.0, 2.0, 5.0, 3.0, 3.0, 2.0, 5.0, 1.0,
])

# Model-related sizes; names suggest the final CNN dropout rate, the final
# hidden width and the maximum signal length -- TODO confirm with the model.
dropout_cnn_final = 0.3
hidden_size_final = 256
max_length = 5000

# Primary criterion: `f2_loss`, imported from the project's `metrics` module.
# Alternatives that were tried and are currently disabled:
#   nn.NLLLoss(), nn.MultiLabelSoftMarginLoss(), nn.CrossEntropyLoss()
criterion = f2_loss

# Secondary criterion: multi-label soft-margin loss with default settings.
second_criterion = nn.MultiLabelSoftMarginLoss()

# Third criterion: BCE-with-logits summed over all elements, with the positive
# term of each class scaled by `loss_weights`.
# A disabled variant used nn.BCEWithLogitsLoss(reduction='none') instead.
# NOTE(review): `loss_weights` is created on the CPU; confirm that this loss
# (or its pos_weight) is moved to `device` before it receives CUDA tensors.
third_criterion = nn.BCEWithLogitsLoss(reduction='sum', pos_weight=loss_weights)

# Earlier experimental weights, kept for reference only:
# loss_weights = torch.tensor([0.0, 180.0, 180.0, 0.0, 0.0, 0.0, 0.0, 0.0, 160.0, 10.0, 400.0, 600.0, 300.0, 190.0,
# 100.0, 600.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0])
# Optimisation hyperparameters.
learning_rate = 5e-4
# The original note placed here reads "not using it for now"; it most likely
# refers to weight_decay -- confirm against the optimiser construction.
weight_decay = 3e-3

# Name of the saved model file to load (presumably a PyTorch checkpoint).
# A previously used alternative was "rnn_0.4037399206288249.pt".
load_model_name = "rnn_final.pt"

# Single table of (annotation code, pathology abbreviation) pairs from which
# the three public names below are derived. Row order is significant: it is
# the order of `test_pathologies`. Three abbreviations (RBBB, PAC, PVC) are
# each reached from two different codes.
_CODE_ABBREVIATION_PAIRS = (
    ('164889003', 'AF'),
    ('164890007', 'AFL'),
    ('426627000', 'Brady'),
    ('713427006', 'RBBB'),
    ('270492004', 'IAVB'),
    ('713426002', 'IRBBB'),
    ('445118002', 'LAnFB'),
    ('39732003', 'LAD'),
    ('164909002', 'LBBB'),
    ('251146004', 'LQRSV'),
    ('698252002', 'NSIVCB'),
    ('10370003', 'PR'),
    ('284470004', 'PAC'),
    ('427172004', 'PVC'),
    ('164947007', 'LPR'),
    ('111975006', 'LQT'),
    ('164917005', 'QAb'),
    ('47665007', 'RAD'),
    ('59118001', 'RBBB'),
    ('427393009', 'SA'),
    ('426177001', 'SB'),
    # ('426783006', 'SNR'),  # deliberately left out of all three names
    ('427084000', 'STach'),
    ('63593006', 'PAC'),
    ('164934002', 'TAb'),
    ('59931005', 'TInv'),
    ('17338001', 'PVC'),
)

# The 26 annotation codes, in table order.
test_pathologies = [code for code, _abbreviation in _CODE_ABBREVIATION_PAIRS]

# Annotation code -> pathology abbreviation.
annot_to_patho = dict(_CODE_ABBREVIATION_PAIRS)

# All 26 abbreviations in sorted (code-point) order; duplicates are kept, so
# RBBB, PAC and PVC each appear twice.
list_pathologies = sorted(annot_to_patho.values())

# The 24 annotation codes passed to `load_weights` further down. Compared with
# `test_pathologies`, this list omits '713427006', '63593006' and '17338001'
# and adds '426783006' (SNR). Row comments give the abbreviation of each code.
compute_score = [
    '164889003', '164890007', '426627000', '270492004',  # AF, AFL, Brady, IAVB
    '713426002', '445118002', '39732003', '164909002',   # IRBBB, LAnFB, LAD, LBBB
    '251146004', '698252002', '10370003', '284470004',   # LQRSV, NSIVCB, PR, PAC
    '427172004', '164947007', '111975006', '164917005',  # PVC, LPR, LQT, QAb
    '47665007', '59118001', '427393009', '426177001',    # RAD, RBBB, SA, SB
    '426783006', '427084000', '164934002', '59931005',   # SNR, STach, TAb, TInv
]
# Number of pathology classes; per the original note this value was changed
# in the competition.
n_pathologies = 23

# Decision thresholds. The "treshold" spelling is kept because other modules
# may refer to these names. Presumably `treshold_23` applies to the 23 classes
# counted above and `treshold_snr` to the separate SNR class -- TODO confirm.
# A single tensor threshold, torch.Tensor([0.4]), was used at an earlier stage.
treshold_snr = 0.4
treshold_23 = 0.5

# Fresh (unfitted) scikit-learn label encoder shared through this module.
encodeur = preprocessing.LabelEncoder()

# Read "weights.csv" through the project's `load_weights` helper for the codes
# in `compute_score`. This runs at import time; the returned structure is
# defined by `Scoring_file.load_weights` and is not visible here.
weights = load_weights("weights.csv", compute_score)

# History containers, one independent list per tracked quantity. The names
# suggest train/test losses and challenge scores collected once per epoch.
train_losses_all_epoch = []
test_losses_all_epoch = []
train_challenge_all_epoch = []
test_challenge_all_epoch = []

# Run flags; names suggest "resume from the last saved state" and a running
# figure index -- TODO confirm against the training script.
load_from_last = False
name_fig = 1

# AF	AFL	Brady	IAVB	IRBBB	LAD	LAnFB	LBBB	LPR	LQRSV	LQT	NSIVCB	PAC	PAC	PR	PVC	PVC	QAb	RAD	RBBB
# RBBB	SA	SB	SNR	STach	TAb	TInv AF	AFL	Brady	IAVB	IRBBB	LAD	LAnFB	LBBB	LPR	LQRSV	LQT	NSIVCB	PAC
# PR	PVC	QAb	RAD	RBBB	SA	SB	STach	TAb	TInv SNR
