import numpy as np, os
from scipy.io import loadmat
from numpy.random import randint
from numpy.random import rand

def challenge_metric(y_true,y_pred,dx_weights,inactive_index,threshold=0.5):
    """Return the normalized challenge score of thresholded predictions.

    Predictions are binarized with ``y_pred > threshold``. Three weighted scores
    are then computed from a "modified confusion matrix":
    - observed: the binarized predictions;
    - correct: perfect predictions (the labels themselves);
    - inactive: every sample predicts only the ``inactive_index`` column(s).

    The result is (observed - inactive) / (correct - inactive). It is 0.0 when
    correct and inactive coincide.

    params:
    y_true : (N, C) array-like of 0/1 labels
    y_pred : (N, C) array-like of scores
    dx_weights : weights multiplied elementwise with the (C, C) matrix
    inactive_index : column index (or indices) used by the inactive baseline
    threshold : scalar or per-class decision threshold
    """
    labels = np.array(y_true, np.float32)
    predicted = np.array(y_pred, np.float32) > threshold

    def weighted_score(outputs):
        # Each sample's contribution is split evenly over the union of its
        # true and predicted classes (at least 1, to avoid division by zero).
        union_size = np.sum(np.minimum(labels + outputs, 1), axis=1, keepdims=True)
        modified_cm = labels.T @ (outputs / np.maximum(union_size, 1))
        return np.sum(modified_cm * dx_weights)

    baseline_outputs = np.zeros_like(labels)
    baseline_outputs[:, inactive_index] = 1

    observed = weighted_score(predicted)
    correct = weighted_score(labels)
    inactive = weighted_score(baseline_outputs)

    if correct == inactive:
        return 0.0
    return (observed - inactive) / (correct - inactive)

def get_p(header_file,leads,**kwargs):
    """Load one recording from disk and wrap it in a Patient.

    Reads the text header at ``header_file`` and loads the ``'val'`` array from
    the ``.mat`` file with the same base name. Returns
    ``Patient(header_lines, recording, leads, **kwargs)``, where the recording
    is cast to float32.
    """
    with open(header_file, 'r') as f:
        header = f.read().split('\n')
    # Replace only the extension. str.replace('.hea', ...) would also rewrite a
    # '.hea' substring earlier in the path (e.g. a '.heart/' directory).
    mat_file = os.path.splitext(header_file)[0] + '.mat'
    x = loadmat(mat_file)['val']
    recording = np.asarray(x, dtype=np.float32)
    return Patient(header,recording,leads,**kwargs)

def get_y(label,y_map):
    """Build a (1, num_classes) multi-hot row vector for one recording.

    Each code in ``label`` that is a key of ``y_map`` sets column
    ``y_map[code]`` to 1; unknown codes are ignored. The vector width is
    ``max(y_map.values()) + 1``.
    """
    width = max(y_map.values()) + 1
    hits = [y_map[code] for code in label if code in y_map]
    y = np.zeros((1, width))
    if hits:
        y[0, hits] = 1
    return y

class Patient:
    '''
    One ECG recording: header metadata plus the selected, scaled lead signals.

    params:
    header : list of header lines. Line 0 is the record line, whose
             space-separated fields 0..3 are parsed as id, number of leads,
             sampling frequency and number of samples. Lines 1..num_leads
             each describe one lead (last field = lead name, field 2 =
             gain/units, field 4 = baseline). Lines starting with '#Age',
             '#Sex' and '#Dx' carry metadata.
    recording : array indexed as recording[lead, sample], in header lead order
    leads : lead names to keep, in the order they should appear in self.ecg
    training : when True, the '#Dx' line is parsed into self.label

    self.ecg is (recording[selected] - baseline) / gain, computed per lead.
    '''
    def __init__(self,header,recording,leads,training=True):
        first_line=header[0].split(' ')
        self.id=first_line[0]
        self.fs=int(first_line[2]) #sampling frequency
        self.integration_window=int(.15*self.fs) # 0.15 s in samples, for qrs detection
        self.length=int(first_line[3])

        # These fields are read from header[2], i.e. the second lead line.
        self.gain=int(header[2].split(' ')[2].split('/')[0])
        self.adc_resolution = int(header[2].split(' ')[3])
        self.signal_length = int(first_line[3])
        self.baseline_offset = int(header[2].split(' ')[4])

        # Note: age / sex / label are only set when the matching line exists.
        for iline in header:
            if iline.startswith('#Age'):
                tmp_age = iline.split(': ')[1].strip()
                # A 'NaN' age is replaced by the fixed default 57.
                self.age = int(tmp_age if tmp_age != 'NaN' else 57)
            elif iline.startswith('#Sex'):
                self.sex = 1 if iline.split(': ')[1].strip()=='Female' else 0
            elif iline.startswith('#Dx') and training:
                # Strip each code so '\r' / stray spaces don't corrupt labels.
                self.label = [code.strip() for code in iline.split(': ')[1].split(',')]

        # Collect lead names plus the gain and baseline of every requested lead.
        available_leads = list()
        adc_gains = np.zeros(len(leads))
        baselines = np.zeros(len(leads))
        for i,iline in enumerate(header):
            # rstrip so trailing '\r' or spaces don't end up in the lead name.
            entries = iline.rstrip().split(' ')
            if i==0:
                num_leads = int(entries[1])
            elif i<=num_leads:
                current_lead = entries[-1]
                available_leads.append(current_lead)
                if current_lead in leads:
                    j = leads.index(current_lead)
                    try:
                        adc_gains[j] = float(entries[2].split('/')[0])
                    except (ValueError, IndexError):
                        # Missing or unparsable gain: fall back to 1000.
                        adc_gains[j] = 1000
                    try:
                        baselines[j] = float(entries[4].split('/')[0])
                    except (ValueError, IndexError):
                        baselines[j] = 0
            else:
                break

        # Reorder/reselect leads in recordings; ValueError if a lead is absent.
        indices = [available_leads.index(lead) for lead in leads]
        ecg = recording[indices, :]
        ecg = ecg - baselines.reshape(-1,1)
        ecg = ecg / adc_gains.reshape(-1,1)
        self.ecg = ecg

    def filter(self):
        '''Band-pass self.ecg in place: order-4 Butterworth, 1-30 Hz, applied with filtfilt along the last axis.'''
        # Local import: the module header does not import scipy.signal.
        from scipy.signal import butter, filtfilt
        order=4;lowcut=1;highcut=30
        b, a = butter(order, [lowcut, highcut], btype='band',fs=self.fs)
        self.ecg=filtfilt(b, a, self.ecg)

    def add_noise(self,mean_snr=10,sd_snr=5/3):
        '''
        Add Gaussian noise to the ecg recording (in place).

        A target SNR is drawn from N(mean_snr, sd_snr). The noise variance is
        derived from the mean power of self.ecg[1] (the second selected lead).
        The same noise level is then applied to every lead.
        params:
        mean_snr : target snr mean in dB
        sd_snr : target snr standard deviation in dB
        '''
        target_snr_db = np.random.normal(mean_snr,sd_snr)
        x = self.ecg[1]
        signal_avg_db = 10 *np.log10(np.mean(x ** 2))
        noise_avg_db = signal_avg_db - target_snr_db
        noise_avg_watts = 10 ** (noise_avg_db / 10)
        noise = np.random.normal(0, np.sqrt(noise_avg_watts), self.ecg.shape)
        self.ecg = self.ecg+noise

# Code from evaluate_model.py from https://github.com/physionetchallenges/evaluation-2021
# Load weights.
def load_weights(weight_file):
    """Load the challenge scoring-weight table from a CSV file.

    The CSV has a header row of class labels (first cell empty). It is followed
    by one row per class: the class label, then that row's weights. A label may
    list equivalent classes separated by '|'. Blank lines are ignored.

    Returns:
        classes: list of sets of class codes, one set per row/column.
        weights: (num_classes, num_classes) float64 array; cells that are not
            finite numbers become NaN.

    Raises:
        Exception: if the table is empty or its lines have different lengths.
        AssertionError: if the header-row labels differ from the first-column labels.
    """
    def is_number(x):
        try:
            float(x)
            return True
        except (ValueError, TypeError):
            return False

    def is_finite_number(x):
        return is_number(x) and np.isfinite(float(x))

    def load_table(table_file):
        # Expected layout:
        #
        # ,    a,   b,   c
        # a, 1.2, 2.3, 3.4
        # b, 4.5, 5.6, 6.7
        # c, 7.8, 8.9, 9.0
        #
        table = list()
        with open(table_file, 'r') as f:
            for l in f:
                # Blank lines (e.g. extra newlines at end of file) are not rows;
                # counting them would break the row/column label lookup below.
                if not l.strip():
                    continue
                table.append([arr.strip() for arr in l.split(',')])

        # Define the numbers of rows and columns and check for errors.
        num_rows = len(table)-1
        if num_rows<1:
            raise Exception('The table {} is empty.'.format(table_file))
        # Every line, header and last row included, must have the same width.
        row_lengths = set(len(row)-1 for row in table)
        if len(row_lengths)!=1:
            raise Exception('The table {} has rows with different lengths.'.format(table_file))
        num_cols = min(row_lengths)
        if num_cols<1:
            raise Exception('The table {} is empty.'.format(table_file))

        # Labels from the header row and from the first column.
        rows = [table[0][j+1] for j in range(num_rows)]
        cols = [table[i+1][0] for i in range(num_cols)]

        # Find the entries of the table; non-finite / non-numeric cells are NaN.
        values = np.zeros((num_rows, num_cols), dtype=np.float64)
        for i in range(num_rows):
            for j in range(num_cols):
                value = table[i+1][j+1]
                values[i, j] = float(value) if is_finite_number(value) else float('nan')

        return rows, cols, values

    # Load the table with the weight matrix.
    rows, cols, values = load_table(weight_file)

    # Split the equivalent classes.
    rows = [set(row.split('|')) for row in rows]
    cols = [set(col.split('|')) for col in cols]
    assert(rows == cols)

    # Identify the classes and the weight matrix.
    classes = rows
    weights = values

    return classes, weights

class ChallengeMetricOpt:
    '''
    Objective function to optimize: the negative normalized challenge metric.

    Calling an instance with a threshold (scalar or one value per class)
    binarizes the stored predictions with ``y_pred > threshold``. It returns
    -(observed - inactive) / (correct - inactive), or -0.0 when correct equals
    inactive, so lower is better. The threshold-independent "correct" and
    "inactive" scores are computed once in __init__.

    params:
    y_true : (N, C) array-like of 0/1 labels
    y_pred : (N, C) array-like of scores
    dx_weights : weights multiplied elementwise with the (C, C) modified confusion matrix
    inactive_index : column index (or indices) predicted for every sample by the inactive baseline
    '''
    def __init__(self,y_true,y_pred,dx_weights,inactive_index):
        self.y_true = np.array(y_true,np.float32)
        self.y_pred = np.array(y_pred,np.float32)
        self.dx_weights = dx_weights
        # Use the converted array throughout: with raw list input, `+` would
        # concatenate lists and `.T` would not exist.
        inactive_outputs = np.zeros_like(self.y_true)
        inactive_outputs[:,inactive_index] = 1
        correct_score = self._compute_modified_cm(self.y_true,self.y_true)
        inactive_score = self._compute_modified_cm(self.y_true,inactive_outputs)
        self.correct_score = np.sum(correct_score*self.dx_weights)
        self.inactive_score = np.sum(inactive_score*self.dx_weights)

    def _compute_modified_cm(self,labels, outputs):
        # Each sample's contribution is split over the union of its true and
        # predicted classes (at least 1, to avoid division by zero).
        normalization = np.maximum(np.sum(np.minimum(labels+outputs,1),axis=1,keepdims=True), 1)
        return labels.T@(outputs/normalization)

    def __call__(self,threshold=0.5):
        threshold = np.array(threshold)
        y = self.y_pred>threshold

        observed_score = self._compute_modified_cm(self.y_true,y)
        observed_score = np.sum(observed_score*self.dx_weights)

        if self.correct_score != self.inactive_score:
            normalized_score = (observed_score - self.inactive_score) / (self.correct_score - self.inactive_score)
        else:
            normalized_score = 0.0
        return -normalized_score

# genetic algorithm search for continuous function optimization
# Code for GA from https://machinelearningmastery.com/simple-genetic-algorithm-from-scratch-in-python/
# decode bitstring to numbers
def decode(bounds, n_bits, bitstring):
    """Decode a concatenated bitstring into one real value per bound.

    Chunk ``idx`` (bits ``idx*n_bits`` .. ``(idx+1)*n_bits - 1``) is read as an
    unsigned big-endian integer ``v``. It maps to
    ``low + (v / 2**n_bits) * (high - low)``, where low/high are
    ``bounds[idx][0]`` and ``bounds[idx][1]``.
    """
    scale = 2**n_bits
    values = []
    for idx, bound in enumerate(bounds):
        bits = bitstring[idx * n_bits:idx * n_bits + n_bits]
        as_int = int(''.join(map(str, bits)), 2)
        values.append(bound[0] + (as_int/scale) * (bound[1] - bound[0]))
    return values
 
# tournament selection
def selection(pop, scores, k=3):
    """Pick one individual by k-way tournament; the lowest score wins.

    One random index and k-1 further random indices are drawn. Ties keep the
    earliest-drawn contender.
    """
    contenders = [randint(len(pop))]
    contenders.extend(randint(0, len(pop), k-1))
    # min() only replaces on strictly smaller keys, so the first minimum wins.
    champion = min(contenders, key=lambda ix: scores[ix])
    return pop[champion]
 
# crossover two parents to create two children
def crossover(p1, p2, r_cross):
    """Single-point crossover of two list bitstrings.

    With probability ``r_cross``, a cut point ``pt`` is drawn uniformly from
    1..len(p1)-1, so both pieces are non-empty. The children are
    ``p1[:pt] + p2[pt:]`` and ``p2[:pt] + p1[pt:]``. Otherwise, or when the
    parents are shorter than 2, the children are copies of the parents.
    Returns ``[c1, c2]``.
    """
    c1, c2 = p1.copy(), p2.copy()
    # rand() is always drawn first, so RNG consumption doesn't depend on length.
    if rand() < r_cross and len(p1) > 1:
        # randint's upper bound is exclusive: this allows every interior cut.
        pt = randint(1, len(p1))
        c1 = p1[:pt] + p2[pt:]
        c2 = p2[:pt] + p1[pt:]
    return [c1, c2]
 
# mutation operator
def mutation(bitstring, r_mut):
    """Flip each bit of ``bitstring`` in place with probability ``r_mut``; returns None."""
    for pos, bit in enumerate(bitstring):
        # One rand() draw per position, in order.
        if rand() < r_mut:
            bitstring[pos] = 1 - bit
 
# genetic algorithm
def genetic_algorithm(objective, bounds, n_bits, n_iter, n_pop, r_cross, r_mut):
    """Minimize ``objective`` with a bitstring genetic algorithm.

    Each individual is a list of ``n_bits * len(bounds)`` bits, decoded with
    ``decode``. Each generation evaluates the whole population, records any
    strictly better individual, and then builds the next population. Parents
    are chosen by tournament selection, and each pair yields two children via
    crossover and mutation. An odd ``n_pop`` is supported; ``n_pop`` must be
    at least 1.

    Returns ``[best_bitstring, best_score]``.
    """
    pop = [randint(0, 2, n_bits*len(bounds)).tolist() for _ in range(n_pop)]
    # Seed the incumbent with a real individual so `best` is always a valid
    # bitstring, even if nothing later strictly improves on pop[0].
    best, best_eval = pop[0], objective(decode(bounds, n_bits, pop[0]))
    for _ in range(n_iter):
        decoded = [decode(bounds, n_bits, p) for p in pop]
        scores = [objective(d) for d in decoded]
        for i in range(n_pop):
            if scores[i] < best_eval:
                best, best_eval = pop[i], scores[i]
        selected = [selection(pop, scores) for _ in range(n_pop)]
        children = list()
        for i in range(0, n_pop, 2):
            # For odd n_pop, the last parent is paired with the first one.
            p1, p2 = selected[i], selected[(i+1) % n_pop]
            for c in crossover(p1, p2, r_cross):
                mutation(c, r_mut)
                children.append(c)
        # An odd n_pop produces one surplus child; keep the population size fixed.
        pop = children[:n_pop]
    return [best, best_eval]

def get_optim_thresholds(objective_function,class_count,*,n_iter=100,n_bits=16,n_pop=100,r_cross=0.9):
    """Search per-class decision thresholds in [0, 1) with the genetic algorithm.

    params:
    objective_function : callable taking a list of ``class_count`` floats and
                         returning a score to minimize (e.g. a ChallengeMetricOpt)
    class_count : number of thresholds (one per class) to optimize
    n_iter, n_bits, n_pop, r_cross : GA generations, bits per threshold,
                                     population size, crossover rate
    The mutation rate is 1 / (n_bits * class_count), i.e. about one flipped bit
    per individual.

    Returns a 1-D numpy array of the best thresholds found.
    Raises ValueError if class_count < 1.
    """
    n_classes = int(class_count)
    if n_classes < 1:
        raise ValueError('class_count must be at least 1, got {}'.format(class_count))
    # One [0, 1] search range per class (previously hard-coded to 26).
    bounds = [[0.0, 1.0] for _ in range(n_classes)]
    r_mut = 1.0 / (float(n_bits) * len(bounds))
    best, score = genetic_algorithm(objective_function, bounds, n_bits, n_iter, n_pop, r_cross, r_mut)
    print('Done!')
    decoded = decode(bounds, n_bits, best)
    print('f(%s) = %f' % (decoded, score))
    return np.array(decoded)
