import scipy
import numpy as np
from matplotlib import pyplot as plt
# import prox_tv as ptv
import lib.data.process.input_lib as input_lib
import scipy.signal
def moving_average(signal, w_size=5):
    """Returns the trailing moving average of a 1-d signal.

    Element ``i`` of the result is the mean of the ``w_size`` samples ending
    at ``i``. For the first ``w_size`` positions, where a full window is not
    available yet, the mean of all samples seen so far is used instead.

    Args:
        signal: 1-d array (or array-like) of signal values.
        w_size: window length in samples, must be >= 1.

    Returns:
        A new floating point array with the same length as signal. Integer
        input is promoted to float so that the averages are not truncated.

    Raises:
        ValueError: if w_size is smaller than 1.
    """
    if w_size < 1:
        raise ValueError(f"w_size must be >= 1, got {w_size}")

    values = np.asarray(signal)
    if not np.issubdtype(values.dtype, np.inexact):
        # Averages of integers are generally not integers.
        values = values.astype(np.float64)

    # Windowed sums via cumulative sums: sum(i - w_size + 1 .. i) is
    # running[i] - running[i - w_size].
    running = np.cumsum(values)
    window_sum = running.copy()
    window_sum[w_size:] -= running[:-w_size]

    # Number of samples that actually contributed to every position.
    counts = np.minimum(np.arange(1, values.size + 1), w_size)
    return window_sum / counts.astype(values.dtype)
def select_best_lead(signals, topk=5,
                     percentile_low=0,
                     percentile_high=50):
    """Returns the index of the lead with the most pronounced upstrokes.

    For every lead the positive first differences are sorted. The smallest
    of the topk largest differences serves as the peak reference, and the
    differences between the percentile_low and percentile_high positions of
    the sorted array serve as the baseline. A lead is scored by the peak
    reference divided by the mean of the normalized baseline; the lead with
    the highest score wins.

    Args:
        signals: iterable of 1-d arrays, one per lead.
        topk: number of largest positive differences treated as peaks.
        percentile_low: lower bound (in percent) of the baseline range.
        percentile_high: upper bound (in percent) of the baseline range.

    Returns:
        Index of the best lead, or 0 if no lead produced a usable score.
    """
    best_ratio = 0
    best_ratio_idx = 0
    for idx, signal in enumerate(signals):
        rises = np.diff(np.asarray(signal))
        rises = np.sort(rises[rises > 0])
        if rises.size == 0:
            # Lead never increases, nothing to score.
            continue

        topk_peaks_min = np.min(rises[-topk:])
        normalized = rises / topk_peaks_min

        # The percentile bounds index the sorted positive differences, so
        # they must be derived from that array's length and not from the
        # (longer) raw signal.
        start = normalized.size * percentile_low // 100
        end = normalized.size * percentile_high // 100
        if end <= start:
            # Too few positive differences to form a baseline window.
            continue

        range_average = np.mean(normalized[start:end])
        if not range_average:
            continue

        # NOTE(review): topk_peaks_min is not normalized while range_average
        # is, so this score grows with the lead's amplitude - confirm that
        # this is intended.
        ratio = topk_peaks_min / range_average
        if ratio > best_ratio:
            best_ratio = ratio
            best_ratio_idx = idx
    return best_ratio_idx
          

def get_approx_early_peak(signal,
                          window_start,
                          window_end,
                          max_tolerance=0.4,
                          epsi=50,
                          verbose=False):
    """Gets the approximate earliest peak in the window.

    Every non-negative sample whose value is at least max_tolerance times
    the window maximum is a candidate. The candidate with the smallest index
    is picked and then moved to the largest sample within epsi samples
    around it (clipped to the window).

    Args:
        signal: 1-d array.
        window_start: start of the window (inclusive, non-negative).
        window_end: end of the window (exclusive).
        max_tolerance: minimum ratio to the window maximum for a sample to
            count as a candidate peak.
        epsi: min-window of [peak - window, peak + window) to get the local
            maxima.
        verbose: if True, plots the signal and prints the index before and
            after the local adjustment.

    Returns:
        index of the earliest approximate peak. If the window maximum is not
        positive, or no sample reaches max_tolerance, the search starts from
        the position of the window maximum instead.

    Raises:
        ValueError: if the window contains no samples.
    """
    window = np.asarray(signal[window_start:window_end])
    if window.size == 0:
        raise ValueError(
            f"Empty search window [{window_start}, {window_end}) for a "
            f"signal of length {len(signal)}")

    window_max = np.max(window)
    candidates = np.array([], dtype=int)
    if window_max > 0:
        # Same test the sorted scan would apply, evaluated for all samples
        # at once: non-negative and close enough to the maximum.
        candidates = np.flatnonzero(
            (window >= 0) & (window / window_max >= max_tolerance))

    if candidates.size:
        min_max_idx = candidates[0] + window_start
    else:
        # No usable candidate: fall back to the maximum itself.
        min_max_idx = np.argmax(window) + window_start

    mini_window_start = max(window_start, min_max_idx - epsi)
    mini_window_end = min(window_end, min_max_idx + epsi)
    min_max_idx_adjusted = np.argmax(
        signal[mini_window_start: mini_window_end]
    ) + mini_window_start

    if verbose:
        plt.plot(signal)
        plt.show()
        print(f"Before adjusting {min_max_idx}")
        print(f"After adjusting {min_max_idx_adjusted}")
    return min_max_idx_adjusted


def get_peaks(signal,
              prom_to_max_ratio=0.6,
              fs=500,
              min_hr=30,
              max_hr=240,
              peak_width_to_hr_frac=0.5,
              autocorr_max_tolerance=0.5,
              topk=5,
              verbose=False,
              gamma=1):
    """Returns a numpy array of peak indices.

    Peaks are searched on the first difference of the mean-removed signal.
    The minimum spacing between peaks is estimated from the autocorrelation
    of that difference signal, restricted to lags allowed by min_hr/max_hr.

    Args:
        signal: 1-d array of signal values.
        prom_to_max_ratio: allowed prominence as a ratio w.r.t the smallest
            of the topk largest difference values.
        fs: sampling frequency in samples per second.
        min_hr: minimum acceptable HR, also affects peak search.
        max_hr: maximum acceptable HR, also affects peak search.
        peak_width_to_hr_frac: This allows us to reduce the window between
            peaks as a percentage of HR to take care of the case when heart
            rate varies.
        autocorr_max_tolerance: minimum ratio to the autocorrelation maximum
            for a lag to be considered as the beat period; forwarded to
            get_approx_early_peak as max_tolerance.
        topk: number of largest difference values used to estimate the
            prominence.
        verbose: if True, prints the estimated period and the lag bounds.
        gamma: unused; kept for backward compatibility (it parameterized a
            total-variation smoothing step that is currently disabled).

    Returns:
        peak indices.
    """
    signal_zero_mean = signal - np.mean(signal)
    # First difference: index i holds signal[i + 1] - signal[i].
    signal_calib = signal_zero_mean[1:] - signal_zero_mean[: -1]

    # Approximate prominence from the weakest of the topk largest slopes.
    topk_peaks_min = np.min(np.sort(signal_calib)[-topk:])
    prominence = topk_peaks_min * prom_to_max_ratio

    # Autocorrelate the signal with itself, then take only the non-negative
    # lags: with mode='same' the zero lag sits at index size // 2.
    signal_autocorr = np.correlate(signal_calib,
                                   signal_calib,
                                   mode='same')
    signal_autocorr = signal_autocorr[signal_autocorr.size // 2:]

    # Fix heart rate between min_hr and max_hr per minute; with the defaults
    # (30 and 240) a peak should appear every 0.25 sec to 2 sec.
    min_shift = fs * 60 // max_hr
    max_shift = fs * 60 // min_hr

    # The tolerance argument used to be dropped here, which made the helper
    # always run with its own default instead of the caller's value.
    hr_hertz_approx = get_approx_early_peak(
        signal_autocorr,
        window_start=min_shift,
        window_end=max_shift,
        max_tolerance=autocorr_max_tolerance)

    # find_peaks rejects distances below 1.
    distance = max(1, min_shift, hr_hertz_approx * peak_width_to_hr_frac)
    peaks, _ = scipy.signal.find_peaks(
        signal_calib,
        prominence=prominence,
        distance=distance)
    # Shift from difference indices back to indices of the original signal.
    peaks = peaks + 1

    if verbose:
        print("Hr adjusted: ", hr_hertz_approx)
        print("Min shift: ", min_shift)
        print("Max shift: ", max_shift)

    return peaks


def get_peak_feats(peaks,
                   min_hr=30,
                   max_hr=240):
    """Computes heart-rate features from peak times.

    Peaks are measured in seconds (already divided by sampling freq).

    Args:
        peaks: 1-d array of peak times in increasing order.
        min_hr: minimum acceptable heart rate in beats per minute.
        max_hr: maximum acceptable heart rate in beats per minute.

    Returns:
        None if there are fewer than two peaks, the peaks span no time, or
        the resulting heart rate is outside [min_hr, max_hr]. Otherwise an
        array [beats_per_minute, intervals_mean, intervals_std,
        intervals_diff_mean, intervals_diff_std], where the last two entries
        describe the change between consecutive intervals and stay 0 when
        there are fewer than four peaks.
    """
    peaks = np.asarray(peaks)
    if peaks.size < 2:
        # A rate needs at least one interval.
        return None

    duration = peaks[-1] - peaks[0]
    if duration <= 0:
        return None

    beats_per_minute = (peaks.size - 1) * 60 / duration
    if beats_per_minute < min_hr or beats_per_minute > max_hr:
        return None

    intervals = np.diff(peaks)
    intervals_mean = np.mean(intervals)
    intervals_std = np.std(intervals)

    intervals_diff_mean = 0
    intervals_diff_std = 0
    if peaks.size >= 4:
        intervals_diff = np.diff(intervals)
        intervals_diff_mean = np.mean(intervals_diff)
        intervals_diff_std = np.std(intervals_diff)

    return np.array([beats_per_minute, intervals_mean, intervals_std,
                     intervals_diff_mean, intervals_diff_std])
        
    
    
#############################################################################
####################### Code below is used for testing ######################
#############################################################################

def validate_on_random_samples(data, header, n_samples, samples=None):
    """Plots the detected peaks of a few recordings for visual inspection.

    For every selected recording the per-lead offset and gain read from the
    header are applied, the recording is cropped to a centered window of at
    most 5000 samples, the best lead is selected and its peaks are plotted
    on top of the signal.

    Args:
        data: sequence of recordings, each indexable as [lead, sample].
        header: sequence of headers aligned with data, in the format
            accepted by input_lib.get_header_feats.
        n_samples: number of recordings to draw at random.
        samples: optional explicit recording indices; when given and
            non-empty they are used instead of a random draw.
    """
    window_size = 5000

    if samples is not None and len(samples) > 0:
        idxs = samples
    else:
        idxs = np.random.permutation(len(data))[0: n_samples]

    for idx in idxs:
        print("Signal index: ", idx)
        age, sex, num_leads, sample_Fs, gain_leads, resolution_leads = (
            input_lib.get_header_feats(header[idx])
        )
        # Per-lead calibration: subtract resolution_leads, divide by gain.
        data_idx = data[idx] - resolution_leads.reshape(-1, 1)
        data_idx = data_idx / gain_leads.reshape(-1, 1)

        # Crop long recordings to a window centered in the recording.
        w_start = 0
        if data_idx[0].size > window_size:
            over = data_idx[0].size - window_size
            w_start = over // 2
        data_idx = data_idx[:, w_start: w_start + window_size]

        best_idx = select_best_lead(data_idx)
        print("Best lead: ", best_idx)

        y = data_idx[best_idx]
        peaks = get_peaks(y, gamma=0)

        plt.plot(y)
        plt.plot(peaks, y[peaks], "xr")
        plt.show()


