import scipy
import numpy as np
from matplotlib import pyplot as plt
# import prox_tv as ptv
from lib.data.process import peak_lib
from lib.data.process import fft_lib



def get_header_feats(header_data):
    """Parse demographics and per-lead acquisition fields from a header.

    Args:
        header_data: sequence of header lines. Line 0 is split on single
            spaces and read as ``<record id> <num leads> <sampling freq> ...``.
            The following ``num leads`` lines each describe one lead.
            Lines starting with ``#Age`` / ``#Sex`` carry the demographics.

    Returns:
        Tuple ``(age, sex, num_leads, sample_Fs, gain_lead, resolution_lead)``:
            age: int; 57 when the age is missing, ``NaN`` or not an integer.
            sex: 1 when the header says ``Female``, otherwise 0 (also when
                the ``#Sex`` line is missing).
            num_leads: int, number of leads declared on line 0.
            sample_Fs: int, sampling frequency declared on line 0.
            gain_lead: float array of length ``num_leads`` holding the
                number in front of the ``/`` in field 2 of each lead line.
            resolution_lead: float array of length ``num_leads`` holding
                field 4 of each lead line.
                NOTE(review): in WFDB-style headers this column looks like
                the ADC zero/baseline rather than a resolution -- confirm.

    Raises:
        IndexError, ValueError: if line 0 or a lead line is malformed.
    """
    # For testing, the mean age of 57 stands in for an unknown age.
    # This value will change as more data is being released.
    default_age = 57

    first_line = header_data[0].split(' ')
    num_leads = int(first_line[1])
    sample_Fs = int(first_line[2])
    gain_lead = np.zeros(num_leads)
    resolution_lead = np.zeros(num_leads)

    for ii in range(num_leads):
        lead_fields = header_data[ii + 1].split(' ')
        # Field 2 looks like "<gain>/<unit>"; only the number is kept.
        gain_lead[ii] = int(lead_fields[2].split('/')[0])
        resolution_lead[ii] = int(lead_fields[4])

    # Defaults keep the return well-defined when a line is absent
    # (previously this raised UnboundLocalError).
    age = default_age
    sex = 0
    for iline in header_data:
        if iline.startswith('#Age'):
            raw_age = iline.partition(':')[2].strip()
            try:
                age = int(raw_age)
            except ValueError:
                # Covers 'NaN' as well as empty / non-integer values.
                age = default_age
        elif iline.startswith('#Sex'):
            raw_sex = iline.partition(':')[2].strip()
            sex = 1 if raw_sex == 'Female' else 0

    return age, sex, num_leads, sample_Fs, gain_lead, resolution_lead


def get_record_labels(header_data):
    """Return the diagnosis labels listed on the header's ``#Dx`` line.

    Args:
        header_data: iterable of header lines.

    Returns:
        List of label strings with surrounding whitespace stripped, taken
        from the comma-separated value of the ``#Dx`` line. If several
        ``#Dx`` lines are present the last one wins. An empty list is
        returned when there is no ``#Dx`` line (previously this raised
        UnboundLocalError).
    """
    labels = []
    for iline in header_data:
        if iline.startswith('#Dx'):
            # Everything after the first ':' is the comma-separated label list.
            labels = iline.partition(':')[2].split(',')
    return [label.strip() for label in labels]


def get_all_feats(signals, header, max_num_beats=10):
    """Build the feature dictionary for one recording.

    Returns None if we cannot extract features.

    Args:
        signals: array indexed as [lead, sample]. The caller's array is
            left untouched; normalization happens on a private copy.
        header: header lines, forwarded to get_header_feats.
        max_num_beats: keeps at most this many beat intervals (that is,
            max_num_beats + 1 peaks) for the per-beat features.

    Returns:
        A dict with the keys 'header_feats', 'peak_feats',
        'peak_delta_feats' and 'fft_feats', or None when fewer than two
        peaks are found or when peak_lib.get_peak_feats /
        fft_lib.get_fft_feats returns None.
    """
    # First extract header information.
    age, sex, _, sample_Fs, gain_leads, resolution_leads = (
        get_header_feats(header)
    )

    # Limit signals to max of 10 sec.
    max_pts = 10 * sample_Fs
    n_samples = signals[0].size
    if n_samples > max_pts:
        # Start from the middle of the signal to avoid noisy early / late 0s.
        w_start = (n_samples - max_pts) // 2
        signals = signals[:, w_start:]
    signals = signals[:, :max_pts]

    # Normalize on a private copy: the slices above are views of the caller's
    # array, so in-place arithmetic on them would rewrite the caller's data
    # (and fail outright for integer input). Float/complex dtypes are kept
    # as-is; anything else is promoted to float64.
    if np.issubdtype(signals.dtype, np.inexact):
        work_dtype = signals.dtype
    else:
        work_dtype = np.float64
    signals = signals.astype(work_dtype)  # astype copies by default
    signals -= resolution_leads.reshape(-1, 1)
    signals /= gain_leads.reshape(-1, 1)

    # Peaks are detected on the first lead only.
    signal_peaks = peak_lib.get_peaks(signals[0],
                                      fs=sample_Fs,
                                      min_hr=30,
                                      max_hr=240,
                                      peak_width_to_hr_frac=0.5,
                                      autocorr_max_tolerance=0.5,
                                      topk=5,
                                      verbose=False)

    if signal_peaks.size < 2:
        print("No peaks")
        return None

    # Calculation in get_peak_feats is done on per sec basis, and uses all
    # detected peaks (the beat limit is applied only afterwards).
    peak_feats = peak_lib.get_peak_feats(signal_peaks / sample_Fs)

    if signal_peaks.size > max_num_beats + 1:
        signal_peaks = signal_peaks[: max_num_beats + 1]

    # Interval between consecutive peaks, in seconds.
    peak_delta_feats = (signal_peaks[1:] - signal_peaks[:-1]) / sample_Fs
    fft_feats = fft_lib.get_fft_feats(signals,
                                      peaks=signal_peaks,
                                      fs=sample_Fs)

    # Only the library calls can come back empty-handed.
    if peak_feats is None or fft_feats is None:
        return None

    header_feats = np.array([age, sex, sample_Fs])
    n_beats = peak_delta_feats.size

    all_feats = {
        # shape = [1, 3]
        'header_feats': header_feats.reshape(1, -1),
        # shape = [1, ??]
        'peak_feats': peak_feats.reshape(1, -1),
        # shape = [1, n_beats, 1]
        'peak_delta_feats': peak_delta_feats.reshape(1, n_beats, 1),
        # shape = [1, n_beats, n_fft_feats]
        'fft_feats': fft_feats.reshape(1, n_beats, -1)
    }
    return all_feats