# Jan Pavlus

import os
from typing import Dict, List, Optional, Tuple

import numpy as np
from tqdm import tqdm
from skmultilearn.model_selection import iterative_train_test_split
import torch as th

from .utils import *

'''MODES
    - All    - if the murmur is present, the murmur ground truth (GT)
               is set to Present for the recordings from all locations.

    - Forget - if the murmur is present, recordings from locations
               where the murmur is not audible are dropped and
               will not be used in training.

    - Locate - if the murmur is present, recordings from locations
               where the murmur is not audible are labelled Absent.
'''
# Annotation modes (see the MODES description above): mode name -> mode id.
MODES = dict(All=0, Forget=1, Locate=2)

# Murmur annotation name <-> class index.
CLASS_MAP = {"Present": 0, "Unknown": 1, "Absent": 2}
REVERS_CLASS_MAP = {index: name for name, index in CLASS_MAP.items()}

# Patient outcome name <-> class index.
OUTCOME_CLASS_MAP = {"Abnormal": 0, "Normal": 1}
REVERS_OUTCOME_CLASS_MAP = {index: name for name, index in OUTCOME_CLASS_MAP.items()}

def get_outcome(data: str) -> str:
    """Get the outcome label of a patient.

    Reads the value of the ``#Outcome:`` line of the patient data. If the
    line occurs more than once, the last occurrence wins.

    Args:
        data (str): Patient data.

    Raises:
        ValueError: If no ``#Outcome:`` line with a value is present.

    Returns:
        str: Outcome label (e.g. ``Abnormal`` or ``Normal``) with surrounding
        whitespace, such as a carriage return from CRLF line endings, removed.
    """
    label = None
    for line in data.split('\n'):
        if line.startswith('#Outcome:'):
            try:
                value = line.split(': ')[1]
            except IndexError:
                # The tag is present but carries no value; ignore this line.
                continue
            # strip() so "Normal\r" (CRLF files) still matches OUTCOME_CLASS_MAP keys.
            label = value.strip()
    if label is None:
        raise ValueError('No outcome available. Is your code trying to load outcomes from the hidden data?')
    return label

def get_murmur_locations(data: str) -> List[str]:
    """Get the locations where the murmur is audible.

    Reads the ``#Murmur locations:`` line of the patient data, whose value is
    a ``+``-separated list (e.g. ``AV+PV``). If the line occurs more than
    once, the last occurrence wins.

    Args:
        data (str): Patient data.

    Raises:
        ValueError: If no ``#Murmur locations:`` line with a value is present.

    Returns:
        List[str]: Murmur audible locations, each with surrounding whitespace
        (such as a carriage return from CRLF line endings) removed.
    """
    locations = None
    for line in data.split('\n'):
        if line.startswith('#Murmur locations:'):
            try:
                value = line.split(': ')[1]
            except IndexError:
                # The tag is present but carries no value; ignore this line.
                continue
            # Strip each entry so e.g. "TV\r" still matches the recording
            # location names compared against in get_wav_tsv_files.
            locations = [location.strip() for location in value.split('+')]
    if locations is None:
        raise ValueError(
            'No label available. Is your code trying to load labels from the hidden data?')
    return locations


def load_segmentation(tsv_file: str, wav_file: str) -> th.Tensor:
    """Build a per-sample segmentation label tensor for one recording.

    Each line of the tsv file with at least three whitespace-separated fields
    is read as ``<start seconds> <end seconds> <label>``; the samples in
    ``[round(start * fs), round(end * fs))`` are set to ``int(label)``.
    Samples not covered by any line stay 0. Shorter lines are skipped.

    Args:
        tsv_file (str): Path to tsv segmentation file.
        wav_file (str): Path to the wav file the segmentation belongs to. It
            is loaded to obtain the signal length, dtype and sampling rate.

    Returns:
        th.Tensor: Tensor with the same shape and dtype as the loaded wav
        signal, holding the segmentation label of every sample.
    """
    # Load segmentation marks
    with open(tsv_file, "r") as file:
        seg_data = file.read()
    # Load the signal only to get its shape/dtype and sampling frequency.
    wav, fs = load_wav_file(wav_file)
    segmentation = th.zeros_like(th.tensor(wav))
    for line in seg_data.split('\n'):
        # split() without a separator treats tabs, repeated blanks and a
        # trailing carriage return uniformly and never yields empty fields.
        entries = line.split()
        # Skip empty or incomplete lines.
        if len(entries) < 3:
            continue
        from_sample = round(float(entries[0]) * fs)
        to_sample = round(float(entries[1]) * fs)
        segmentation[from_sample:to_sample] = int(entries[2])
    return segmentation


def get_wav_tsv_files(data_folder: str,
                      data: str,
                      annotation: str,
                      outcome: str,
                      locations: Optional[List[str]] = None,
                      inverse: bool = False) -> List[Dict]:
    """Get annotated wav recordings that have a tsv segmentation file.

    Reads the recording lines that follow the first line of the patient data
    (their count comes from ``get_num_locations``). Each line is split into
    ``<location> <unused> <wav file> <tsv file>``; lines without a tsv file
    are skipped. If ``locations`` is given, only recordings from those
    locations are used (or, with ``inverse=True``, only recordings from the
    other locations).

    Args:
        data_folder (str): Folder containing patient data files.
        data (str): Patient data.
        annotation (str): Murmur annotation assigned to every returned
            recording; a key of ``CLASS_MAP`` (Present, Unknown or Absent).
        outcome (str): Outcome assigned to every returned recording; a key of
            ``OUTCOME_CLASS_MAP`` (Abnormal or Normal).
        locations (List[str], optional): Locations from where to get wav files.
            ``None`` or an empty list means all locations. Defaults to None.
        inverse (bool, optional): If locations are set and inverse is True,
            then get recordings not present in locations list.
            Defaults to False.

    Returns:
        List[Dict]: One dict per recording with keys ``wav`` (wav file path),
        ``classnum``/``class`` (murmur annotation index/name),
        ``outcomenum``/``outcome`` (outcome index/name) and ``tsv`` (the
        per-sample segmentation tensor returned by ``load_segmentation``).
    """
    num_locations = get_num_locations(data)
    recording_information = data.split('\n')[1:num_locations + 1]
    annotated_recordings = list()
    for line in recording_information:
        # split() without a separator also drops a trailing carriage return
        # and stray blanks, which would otherwise leak into the file names.
        entries = line.split()
        # Skip recordings without a tsv segmentation file (also guards the
        # entries[0] access below against blank lines).
        if len(entries) < 4:
            continue
        location = entries[0]
        # Keep the recording only when it is inside `locations` (normal mode)
        # or outside it (inverse mode).
        if locations and (location in locations) == inverse:
            continue
        wav_file = os.path.join(data_folder, entries[2])
        tsv_file = os.path.join(data_folder, entries[3])
        annotated_recordings.append({"wav": wav_file,
                                     "classnum": CLASS_MAP[annotation],
                                     "class": annotation,
                                     "outcomenum": OUTCOME_CLASS_MAP[outcome],
                                     "outcome": outcome,
                                     "tsv": load_segmentation(tsv_file, wav_file)})
    return annotated_recordings


def parse(data_folder: str,
          files: List,
          mode: int = 0,
          verbose: int = 1) -> Tuple[List[Dict], Dict[str, int]]:
    """Parse annotation files for the patients given by ``files``.

    For patients with an Absent or Unknown murmur, and for Present patients
    in ``All`` mode, every recording gets the patient-level murmur label.
    For Present patients in ``Forget`` mode only recordings from audible
    locations are kept; in any other mode (``Locate``) the recordings from
    inaudible locations are additionally kept and labelled Absent.

    Args:
        data_folder (str): Folder with wav and other files.
        files (List): Patient data files to be parsed.
        mode (int, optional): Annotation mode, one of the ``MODES`` values.
            Defaults to 0 (All).
        verbose (int, optional): Show a tqdm progress bar if set higher than
            zero. Defaults to 1.

    Returns:
        Tuple[List[Dict], Dict[str, int]]: Annotated recordings (one dict per
        recording, as built by ``get_wav_tsv_files``) and the number of
        recordings per murmur annotation and per outcome.
    """
    stats = {"Present": 0, "Absent": 0, "Unknown": 0, "Abnormal": 0, "Normal": 0}
    annotated_recordings = list()

    def collect(data, annotation, outcome, **selection):
        # Fetch the recordings for one annotation/location selection, append
        # them to the result and count them under annotation and outcome.
        recordings = get_wav_tsv_files(data_folder=data_folder,
                                       data=data,
                                       annotation=annotation,
                                       outcome=outcome,
                                       **selection)
        stats[annotation] += len(recordings)
        stats[outcome] += len(recordings)
        annotated_recordings.extend(recordings)

    for file in tqdm(files) if verbose >= 1 else files:
        data = load_patient_data(file)
        status = get_label(data)
        outcome = get_outcome(data)

        if status == "Absent" or status == "Unknown" \
           or (status == "Present" and mode == MODES["All"]):
            # Every location gets the patient-level murmur label.
            collect(data, status, outcome)
            continue

        # Murmur present in Forget/Locate mode: locations where the murmur is
        # audible keep the patient-level label.
        present_murmur_locations = get_murmur_locations(data)
        collect(data, status, outcome, locations=present_murmur_locations)
        if mode != MODES["Forget"]:
            # Locate: the remaining (inaudible) locations are labelled Absent.
            collect(data, "Absent", outcome,
                    locations=present_murmur_locations, inverse=True)
    return annotated_recordings, stats


def prepare(data_folder: str,
            mode: int = 0,
            split_ratio: float = 0.2,
            verbose: int = 1) -> Tuple[List[Dict],
                                       List[Dict]]:
    """Prepare and split training data.

    Patients are split into train and validation parts with iterative
    stratification on their murmur class; each part is then parsed into
    annotated recordings with ``parse``.

    Args:
        data_folder (str): Folder with wav and other files.
        mode (int, optional): Annotation mode, one of the ``MODES`` values.
            Defaults to 0.
        split_ratio (float, optional): Fraction of patients put into the
            validation part. Defaults to 0.2.
        verbose (int, optional): Show progress bars and print split
            statistics if set higher than zero. Defaults to 1.

    Raises:
        ValueError: If ``mode`` is not one of the ``MODES`` values.

    Returns:
        Tuple[List[Dict], List[Dict]]: Train and validation recordings.
    """
    # Explicit check instead of `assert` (stripped under -O, and it let
    # negative modes through).
    if mode not in MODES.values():
        raise ValueError(f"Invalid mode {mode!r}; expected one of {MODES}")
    # load files from folder
    files = find_patient_files(data_folder)

    statuses = list()
    for file in tqdm(files) if verbose >= 1 else files:
        data = load_patient_data(file)
        statuses.append(CLASS_MAP[get_label(data)])

    if verbose >= 1:
        print(f"Iterative split with ratio: {split_ratio}")
    files = np.array(files).reshape(-1, 1)
    # iterative_train_test_split treats y as a label-indicator matrix
    # (non-zero = label assigned). A single column of class indices would
    # merge Unknown and Absent into one label and leave Present unlabelled,
    # so one-hot encode the murmur class: shape (n_patients, n_classes).
    labels = np.eye(len(CLASS_MAP), dtype=int)[statuses]

    train_files, _, val_files, _ =\
        iterative_train_test_split(files,
                                   labels,
                                   test_size=split_ratio)

    if verbose >= 1:
        print(
            f"Train patients: {len(train_files)}\n Val patients: {len(val_files)}\n")
    train, train_stats = parse(data_folder, train_files.reshape(-1).tolist(), mode, verbose)
    val, val_stats = parse(data_folder, val_files.reshape(-1).tolist(), mode, verbose)
    if verbose >= 1:
        print(f"Train recordings: {len(train)}\n Val recordings: {len(val)}\n")
        print(f"Train stats: {train_stats}\n Val stats: {val_stats}\n")
    return train, val
