import logging
import tensorflow as tf
import tensorflow.keras.backend as K
from tensorflow.keras import Model,callbacks
from tensorflow.keras.layers import Input
from custom_layers import *

from utils import *
import numpy as np

def build_model(class_count, lead_count=12):
    """Assemble the ECG classification network (returned uncompiled).

    Pipeline: ``Masking(mask_value=0.0)`` -> ``ChannelSelfAttention1D(32, 5)``
    -> five residual stages with widths 12, 24, 24, 48, 48 ->
    ``AttentionPooling(class_count, 32)`` -> ``Predict(class_count)``.

    Args:
        class_count: number of output classes, forwarded to the pooling and
            prediction heads.
        lead_count: number of input channels; the time axis is variable
            length (``None``).

    Returns:
        A ``tf.keras.Model`` mapping ``(batch, T, lead_count)`` inputs to the
        output of ``Predict``.
    """
    stage_widths = [12, 24, 24, 48, 48]

    def residual_stage(inputs, units, kernel_size):
        """Three stacked Conv1DInception layers plus a conv shortcut, then
        ReLU, max-pooling (size 3, stride 3) and channel self-attention."""
        # Shortcut width is units*3 so it can be summed with the Inception
        # branch; presumably Conv1DInception emits `units` channels per
        # kernel size (3 sizes here) -- TODO confirm against custom_layers.
        skip = MaskedConv1D(units=units * 3, kernel_size=3, strides=1,
                            activation='relu', use_batch_norm=False)(inputs)

        branch = inputs
        for _ in range(3):
            branch = Conv1DInception(units=units, kernel_size=kernel_size)(branch)

        merged = tf.keras.layers.Add()([skip, branch])
        merged = tf.keras.layers.Activation('relu')(merged)
        pooled = MaskedMaxPooling1D(pool_size=3, strides=3, padding='same')(merged)
        # Alternative tried previously: ChannelSelfAttention1D_v2(32, 5, 8).
        return ChannelSelfAttention1D(32, 5)(pooled)

    signal_in = Input(shape=(None, lead_count))
    # Timesteps whose values are all 0.0 are masked; MakeDataset zero-pads
    # each batch to its longest recording.
    hidden = tf.keras.layers.Masking(mask_value=0.0)(signal_in)
    hidden = ChannelSelfAttention1D(32, 5)(hidden)
    for width in stage_widths:
        hidden = residual_stage(hidden, units=width, kernel_size=[3, 4, 5])
    hidden = AttentionPooling(class_count, 32)(hidden)
    prediction = Predict(class_count)(hidden)

    return Model(inputs=signal_in, outputs=prediction)

class MakeDataset:
    """Build a batched ``tf.data.Dataset`` of ECG recordings from header files.

    Calling an instance with a list of header-file paths returns a dataset of
    ``(x, y)`` batches. Each recording is loaded with ``get_p``, decimated
    towards ``freq`` Hz, optionally noise-augmented, transposed to
    time-major ``(T, leads)``, cropped to ``120 * freq`` samples and
    zero-padded to the longest recording in its batch.
    """

    def __init__(self, y_map, freq=125, leads=None):
        """Store the preprocessing configuration.

        Args:
            y_map: passed to ``get_y`` together with each recording's
                ``p.label`` to build the target row.
            freq: target sampling rate in Hz. Recordings are decimated by the
                integer factor ``round(p.fs / freq)`` (never less than 1).
            leads: sequence of leads forwarded to ``get_p``; its length fixes
                the channel dimension of the produced ``x`` tensors.

        Raises:
            TypeError: if ``leads`` is None, because the channel count cannot
                be determined.
        """
        if leads is None:
            raise TypeError('MakeDataset requires `leads` (a sequence of leads); got None')
        self.y_map = y_map
        self.leads = leads
        self.num_leads = len(leads)
        self.freq = freq  # target sampling rate in Hz (not a decimation factor)

    def __call__(self, header_files, batch_size=128, shuffle=False, add_noise=False, prefetch_size=None):
        """Create the input pipeline.

        Args:
            header_files: sequence of header-file paths (strings).
            batch_size: recordings per batch; the final batch may be smaller.
            shuffle: if True, shuffle file order (buffer 50000), reshuffling
                on every iteration.
            add_noise: if True, ``p.add_noise`` is applied to each recording
                with probability 2/3.
            prefetch_size: prefetch buffer size; ``None`` selects AUTOTUNE.

        Returns:
            A ``tf.data.Dataset`` of ``(x, y)``: ``x`` is float32 with shape
            ``(batch, T_max, num_leads)`` (zero-padded), ``y`` is a 2-D bool
            tensor stacked from the ``get_y`` rows.
        """

        def map_decorator(func):
            """Run ``func`` eagerly via ``tf.py_function`` and restore static shapes."""
            def wrapper(*args):
                # Use a tf.py_function to prevent auto-graph from compiling the method
                x, label = tf.py_function(func, inp=[*args], Tout=(tf.float32, tf.bool))
                x.set_shape([None, None, self.num_leads])
                label.set_shape([None, None])
                return x, label
            return wrapper

        @map_decorator
        def get_XY(header_files):
            """Load, preprocess and zero-pad one batch of recordings."""
            max_length = int(120 * self.freq)  # crop signals at 120 s (target rate)
            X = []
            Y = []
            for header_file in header_files.numpy():
                p = get_p(header_file=header_file.decode("utf-8"), leads=self.leads, training=True)
                # Integer decimation factor. Clamp to 1: a recording sampled
                # below ~freq/2 would otherwise give step 0 and a ValueError.
                # NOTE: plain striding (no anti-alias filter) and rounded
                # ratios, so the resulting rate is only approximately `freq`.
                dr = max(1, int(round(p.fs / self.freq)))
                p.ecg = p.ecg[:, ::dr]
                if add_noise and np.random.rand() > 1/3:
                    p.add_noise(mean_snr=10, sd_snr=5/3)
                # (leads, T) -> (T, leads), then crop overly long recordings.
                p.ecg = p.ecg.T[:max_length, :]
                # Small offset so no genuine timestep is exactly all-zero;
                # presumably so the Masking(mask_value=0.0) layer only masks
                # the zero padding added below.
                X.append(p.ecg + 1e-05)
                Y.append(get_y(p.label, self.y_map))
            lengths = [xi.shape[0] for xi in X]
            # Allocate as float32 directly: py_function's Tout is float32.
            x = np.zeros((len(X), max(lengths), self.num_leads), dtype=np.float32)
            for i, (xi, length) in enumerate(zip(X, lengths)):
                x[i, :length, :] = xi
            y = np.vstack(Y)
            return x, y

        AUTOTUNE = tf.data.experimental.AUTOTUNE
        dataset = tf.data.Dataset.from_tensor_slices(header_files)
        if shuffle:
            dataset = dataset.shuffle(50000, reshuffle_each_iteration=True)
        # Batch file names first so get_XY pads each batch to its own max length.
        dataset = dataset.batch(batch_size, drop_remainder=False)
        dataset = dataset.map(get_XY, num_parallel_calls=AUTOTUNE)
        if prefetch_size is None:
            prefetch_size = AUTOTUNE
        dataset = dataset.prefetch(buffer_size=prefetch_size)
        return dataset

class StopIfTrainingFails(tf.keras.callbacks.Callback):
    """Abort a run at one fixed epoch if a monitored value is too low.

    At the end of epoch ``cutoff_epoch`` (Keras passes 0-based epoch indices,
    so the default 10 is the 11th epoch), the value of ``monitor`` is read
    from the epoch logs. If it is strictly below ``min_value``,
    ``model.stop_training`` is set. The comparison is made only at that one
    epoch; the metric lookup (and its missing-metric warning) happens every
    epoch.

    NOTE(review): "below min_value means failure" fits higher-is-better
    metrics (accuracy, AUC, a challenge score). With the default
    ``monitor='val_loss'`` it would stop runs whose loss is *good*; callers
    presumably pass a score metric -- confirm.

    Args:
        monitor: key of the epoch ``logs`` dict to inspect.
        cutoff_epoch: 0-based epoch index at which the check is made.
        min_value: threshold; values strictly below it stop training.
    """

    def __init__(self, monitor='val_loss', cutoff_epoch=10, min_value=0.2):
        super().__init__()

        self.monitor = monitor
        self.min_value = min_value
        self.cutoff_epoch = cutoff_epoch

    def on_epoch_end(self, epoch, logs=None):
        """Set ``model.stop_training`` if the cutoff-epoch check fails."""
        current = self.get_monitor_value(logs)
        if current is None:
            return
        if epoch == self.cutoff_epoch and current < self.min_value:
            self.model.stop_training = True

    def get_monitor_value(self, logs):
        """Return ``logs[self.monitor]``, or None with a warning if absent."""
        logs = logs or {}
        monitor_value = logs.get(self.monitor)
        if monitor_value is None:
            # The template carries %s placeholders, so it must go through
            # logging's lazy %-formatting; print() would not substitute them.
            logging.warning('Early stopping conditioned on metric `%s` '
                            'which is not available. Available metrics are: %s',
                            self.monitor, ','.join(logs))
        return monitor_value