import os
from scipy.io import loadmat
import numpy as np
import pandas as pd
import keras
from keras.models import Sequential, load_model
from keras.layers import LSTM, GRU, TimeDistributed, Bidirectional, LeakyReLU, BatchNormalization
from keras.layers import Dense, Dropout, Activation, Flatten,  Input, Reshape, GRU, CuDNNGRU, MaxPooling1D
from keras.layers import Convolution1D, MaxPool1D, GlobalAveragePooling1D,concatenate,AveragePooling1D,Conv1D
from keras.models import Model
from sklearn.preprocessing import MinMaxScaler
from keras.models import model_from_json
from scipy.signal import resample
from sklearn.model_selection import train_test_split
from keras.callbacks import EarlyStopping, ModelCheckpoint

def ResNet_model(WINDOW_SIZE,INPUT_FEAT,OUTPUT_CLASS,LAYER):
    """Build and compile a 1-D convolutional residual-network classifier.

    Architecture, in construction order:
      * stem: Conv1D -> BatchNorm -> ReLU;
      * one residual block (Conv1D -> BN -> ReLU -> Dropout -> Conv1D ->
        MaxPool) whose shortcut is max-pooled so both branches can be added;
      * ``LAYER`` residual blocks in the main loop with BN -> ReLU -> Dropout
        placed before each of the two Conv1D layers. From the 5th loop block
        on, every 4th block (l = 4, 8, ...) widens the filters by another
        ``convfilt`` (64) and projects the shortcut with a 1x1 Conv1D. Every
        other block (odd ``l``) max-pools both branches by 2;
      * head: BN -> ReLU -> Flatten -> Dense(OUTPUT_CLASS, softmax).

    Args:
        WINDOW_SIZE: number of time steps per input example.
        INPUT_FEAT: number of features/channels per time step.
        OUTPUT_CLASS: number of softmax output units.
        LAYER: number of residual blocks built in the main loop.

    Returns:
        A keras ``Model`` compiled with the Adam optimizer, categorical
        cross-entropy loss and the accuracy metric.

    NOTE(review): the time axis is pooled (size 2, stride 2) 1 + LAYER // 2
    times in total, so WINDOW_SIZE must be large enough to survive that many
    halvings before Flatten. Confirm this for the chosen LAYER.
    """
    # Add CNN layers left branch (higher frequencies)
    # Parameters from paper
    # NOTE(review): the paper is not identified in this file.
    k = 1    # increment every 4th residual block
    p = True # pool toggle every other residual block (end with 2^8)
    convfilt = 64
    convstr = 1
    ksize = 16
    poolsize = 2
    poolstr  = 2
    drop = 0.5

    # Modelling with Functional API
    #input1 = Input(shape=(None,1), name='input')
    input1 = Input(shape=(WINDOW_SIZE,INPUT_FEAT), name='input')

    ## First convolutional block (conv,BN, relu)
    x = Conv1D(filters=convfilt,
               kernel_size=ksize,
               padding='same',
               strides=convstr,
               kernel_initializer='he_normal')(input1)
    x = BatchNormalization()(x)
    x = Activation('relu')(x)

    ## Second convolutional block (conv, BN, relu, dropout, conv) with residual net
    # Left branch (convolutions)
    x1 =  Conv1D(filters=convfilt,
               kernel_size=ksize,
               padding='same',
               strides=convstr,
               kernel_initializer='he_normal')(x)
    x1 = BatchNormalization()(x1)
    x1 = Activation('relu')(x1)
    x1 = Dropout(drop)(x1)
    x1 =  Conv1D(filters=convfilt,
               kernel_size=ksize,
               padding='same',
               strides=convstr,
               kernel_initializer='he_normal')(x1)
    x1 = MaxPooling1D(pool_size=poolsize,
                      strides=poolstr)(x1)
    # Right branch, shortcut branch pooling
    # Pooled with the same size/stride as the left branch so the shapes match for add().
    x2 = MaxPooling1D(pool_size=poolsize,
                      strides=poolstr)(x)
    # Merge both branches
    x = keras.layers.add([x1, x2])
    del x1,x2

    ## Main loop
    # The block above already pooled, so the first loop block (l=0) does not.
    p = not p
    for l in range(LAYER):

        if (l%4 == 0) and (l>0): # increment k on every fourth residual block
            k += 1
             # increase depth by 1x1 Convolution case dimension shall change
            # (the left branch below now outputs convfilt*k channels, so the
            # shortcut must be projected to the same width before add()).
            xshort = Conv1D(filters=convfilt*k,kernel_size=1)(x)
        else:
            xshort = x
        # Left branch (convolutions)
        # notice the ordering of the operations has changed
        # (BN -> ReLU -> Dropout now come before each Conv1D).
        x1 = BatchNormalization()(x)
        x1 = Activation('relu')(x1)
        x1 = Dropout(drop)(x1)
        x1 =  Conv1D(filters=convfilt*k,
               kernel_size=ksize,
               padding='same',
               strides=convstr,
               kernel_initializer='he_normal')(x1)
        x1 = BatchNormalization()(x1)
        x1 = Activation('relu')(x1)
        x1 = Dropout(drop)(x1)
        x1 =  Conv1D(filters=convfilt*k,
               kernel_size=ksize,
               padding='same',
               strides=convstr,
               kernel_initializer='he_normal')(x1)
        if p:
            x1 = MaxPooling1D(pool_size=poolsize,strides=poolstr)(x1)

        # Right branch: shortcut connection
        if p:
            x2 = MaxPooling1D(pool_size=poolsize,strides=poolstr)(xshort)
        else:
            x2 = xshort  # pool or identity
        # Merging branches
        x = keras.layers.add([x1, x2])
        # change parameters
        p = not p # toggle pooling


    # Final bit
    x = BatchNormalization()(x)
    x = Activation('relu')(x)
    x = Flatten()(x)
    #x = Dense(1000)(x)
    #x = Dense(1000)(x)
    out = Dense(OUTPUT_CLASS, activation='softmax')(x)
    model = Model(inputs=input1, outputs=out)
    model.compile(optimizer='adam',
                  loss='categorical_crossentropy',
                  metrics=['accuracy'])
    #model.summary()
    return model

def generator(input_files, final_labels, batch_size, segment_length=2500):
    """Yield an endless stream of ``(X, y)`` batches of fixed-length segments.

    Each record in ``input_files`` is loaded with ``load_challenge_data``.
    It is then cut along axis 1 into consecutive windows of
    ``segment_length`` samples. If the last window of a record would be
    short, the final ``segment_length`` samples are used instead, which
    overlaps the previous window. A non-empty record shorter than
    ``segment_length`` is zero-padded at the end, so every segment has the
    same shape and the batch stacks into a regular array. Empty records are
    skipped.

    Each segment is transposed before batching and paired with the label
    ``final_labels[k]`` of the record it came from. A batch is yielded as
    soon as ``batch_size`` segments have been collected, so a batch may span
    records. The file list is cycled forever.

    Args:
        input_files: re-iterable sequence of record paths accepted by
            ``load_challenge_data``.
        final_labels: indexable per-record labels aligned with ``input_files``.
        batch_size: number of segments per yielded batch; must be positive.
        segment_length: window length in samples (default 2500); must be
            positive.

    Yields:
        ``(X, y)`` where ``X = np.array`` of the transposed segments and
        ``y = np.array`` of their labels. Assuming records are
        ``(n_channels, n_samples)`` (TODO confirm), ``X`` has shape
        ``(batch_size, segment_length, n_channels)``.

    Raises:
        ValueError: on the first ``next()`` if ``input_files`` is empty,
            because the loop would otherwise spin forever, or if
            ``batch_size`` or ``segment_length`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    if segment_length <= 0:
        raise ValueError("segment_length must be positive")
    if len(input_files) == 0:
        raise ValueError("input_files must not be empty")

    batch_x = []
    batch_y = []
    while True:
        for k, f in enumerate(input_files):
            data = load_challenge_data(f)
            n_samples = data.shape[1]
            if 0 < n_samples < segment_length:
                # Without padding, data[:, -segment_length:] would return the
                # whole short record and np.array(batch_x) would be ragged.
                data = np.pad(data, ((0, 0), (0, segment_length - n_samples)))
                n_samples = segment_length
            for j in range(0, n_samples, segment_length):
                segment = data[:, j:j + segment_length]
                if segment.shape[1] < segment_length:
                    # Trailing partial window: take the last full window instead.
                    segment = data[:, -segment_length:]
                batch_x.append(segment.T)
                batch_y.append(final_labels[k])
                if len(batch_x) == batch_size:
                    X = np.array(batch_x)
                    y = np.array(batch_y)
                    batch_x = []
                    batch_y = []
                    yield X, y

def save_the_model(filename, model):
    """Persist a model's architecture and weights into the directory ``filename``.

    Two files are written inside ``filename``, which must already exist
    because ``open`` does not create directories:
      * ``model.json``: the string returned by ``model.to_json()``;
      * ``model.h5``: whatever ``model.save_weights`` writes to that path.
    """
    architecture = model.to_json()
    json_path = f"{filename}/model.json"
    with open(json_path, "w") as handle:
        handle.write(architecture)
    weights_path = f"{filename}/model.h5"
    model.save_weights(weights_path)


def load_challenge_data(filename):
    """Load the ``'val'`` variable of a record's MATLAB file as float64.

    ``filename`` may name either the record's ``.hea`` file or its ``.mat``
    file. A trailing ``.hea`` extension is swapped for ``.mat`` before
    loading. Only the extension is touched: the previous
    ``str.replace('.hea', '.mat')`` also rewrote any ``.hea`` inside
    directory names or stems (e.g. ``/data/x.heap/A0001.hea``).

    Args:
        filename: path (str or os.PathLike) to the ``.hea`` or ``.mat`` file.

    Returns:
        ``np.ndarray`` of dtype float64 holding the ``'val'`` variable.

    Raises:
        KeyError: if the MATLAB file has no ``'val'`` variable.
        Any error ``scipy.io.loadmat`` raises for a missing or unreadable file.
    """
    root, ext = os.path.splitext(filename)
    if ext == '.hea':
        filename = root + '.mat'
    contents = loadmat(filename)
    return np.asarray(contents['val'], dtype=np.float64)

def resample_signal(data, original_frequency, final_frequency=1000, n_leads=12):
    """Resample the first ``n_leads`` rows of ``data`` to a new sampling rate.

    Each row is resampled independently with ``scipy.signal.resample``
    (Fourier method). The output length is
    ``floor(final_frequency * n_samples / original_frequency)``.

    The length is computed by multiplying first. The old
    ``final / original * n`` ordering can land just below an exact integer
    (e.g. ``58 / 100 * 100 == 57.99999999999999``) and drop one sample when
    truncated.

    Args:
        data: 2-D array indexed as ``data[row, sample]``; must have at least
            ``n_leads`` rows.
        original_frequency: sampling rate of ``data``.
        final_frequency: target sampling rate (default 1000).
        n_leads: number of leading rows to resample (default 12, the
            previously hard-coded value).

    Returns:
        ``np.ndarray`` of shape ``(n_leads, final_samples)``.

    Raises:
        IndexError: if ``data`` has fewer than ``n_leads`` rows.
    """
    n_samples = data.shape[1]
    final_samples = int(final_frequency * n_samples / original_frequency)
    return np.array([resample(data[lead, :], final_samples) for lead in range(n_leads)])

def write_history(filename, my_dict):
    """Write ``my_dict`` to ``<filename>/History.csv``, one ``key,value`` line per entry.

    Entries appear in the mapping's iteration order, and each key and value
    is rendered with ``str()``. An existing file is overwritten.

    NOTE(review): values that are themselves sequences (e.g. per-epoch
    lists) contain commas, so the output is not strictly two-column CSV.
    """
    with open(f"{filename}/History.csv", 'w') as out:
        out.writelines(f"{key},{value}\n" for key, value in my_dict.items())

