function  model = team_training_code(input_directory,output_directory) % train_ECG_leads_classifier
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Purpose: Train ECG leads and obtain classifier models
% for 12-lead, 6-lead, 4-lead, 3-lead and 2-lead ECG sets
% Inputs:
% 1. input_directory  : folder scanned for recording files whose name
%                       ends in 'mat'
% 2. output_directory : folder the trained models are saved to (created
%                       when it does not exist yet)
%
% Outputs:
% model: network returned by CNN_BILSTM_network_training for the last
%        lead set that was trained. The network of every lead set is
%        passed to save_ECGleads_model_mod1 together with output_directory.
%
% Errors:
% Raises 'No valid data found in the directory.' when no class could be
% read or when no recording carries one of the valid labels.
%
% Author: Erick Andres Perez Alday, PhD, <perezald@ohsu.edu>
% Version 1.0 Aug-2020
% Author2: Jonathan Torres, PhD, <jonathanrtc@gmail.com>
% Version 1.0 Apr-2021
addpath('functions')

if ~exist(output_directory, 'dir')
    mkdir(output_directory)
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Define lead sets (e.g. 12, 6, 4, 3 and 2 lead ECG sets)
twelve_leads = (1:12);
six_leads    = (1:6);
four_leads   = [1:3,8];
three_leads  = [1:2,8];
two_leads    = [1,2];
lead_sets = {twelve_leads, six_leads, four_leads, three_leads, two_leads};

disp('Loading data...')

% Load into a struct so no variable is silently injected into this
% function's workspace.
label_tables = load('valid_unvalid_labels');
valid_labels = label_tables.valid_labels;

% Find files: regular, non-hidden files whose name ends in 'mat'.
input_files_tmp = {};
for f = dir(input_directory)'
    is_regular_file = exist(fullfile(input_directory, f.name), 'file') == 2;
    % endsWith is false (instead of raising an index error) for names
    % shorter than three characters.
    if is_regular_file && f.name(1) ~= '.' && endsWith(f.name, 'mat')
        input_files_tmp{end + 1} = f.name; %#ok<AGROW>
    end
end

% Extract classes from dataset.
[~,classes_all_tmp]  = get_all_classes(input_directory,input_files_tmp);

if isempty(classes_all_tmp)
   error('No valid data found in the directory.')
end

%% Mark every entry of classes_all_tmp that equals one of the valid labels
tmp_valid = false(size(classes_all_tmp));
for i=1:size(valid_labels,1)
    tmp_valid(strcmp(classes_all_tmp,valid_labels(i,:))) = true;
end

% select only valid data: rows with at least one valid label
% NOTE(review): assumes row r of classes_all_tmp belongs to
% input_files_tmp{r} -- confirm against get_all_classes.
tmp_data_valid = find(any(tmp_valid,2));
input_files = input_files_tmp(tmp_data_valid);

classes   = unique(classes_all_tmp(tmp_valid));
num_files = length(input_files);

if num_files == 0
    % Fail here with a clear message rather than later inside training.
    error('No valid data found in the directory.')
end

%% Load data recordings and header files
% Iterate over files.
disp('Training model..')

first_label = zeros(num_files,1); % index into classes of the first matching Dx code
num_sec = 15;  % number of seconds per signal
newfs   = 200; % new Fs for all signals
Xnew    = cell(num_files,1);

tic
for i = 1:num_files
    if ~mod(i,1000)
        disp(['    ', num2str(i), '/', num2str(num_files), '...'])
    end

    % Load data. fileparts strips only the extension, so record names
    % that contain additional dots are kept intact.
    [~, record_name] = fileparts(input_files{i});
    tmp_input_file = fullfile(input_directory, record_name);
    [data,header_data] = load_challenge_data(tmp_input_file);

    %% Check the number of available ECG leads
    % Second space-separated token of the first header line.
    % str2double is used instead of str2num because str2num evaluates
    % its input as MATLAB code.
    tmp_hea = strsplit(header_data{1},' ');
    num_leads = str2double(tmp_hea{2});
    [~, leads_idx] = get_leads(header_data,num_leads);

    %% data preprocessing
    Xnew{i} = preprocessing(data,header_data,leads_idx,num_sec,newfs);

    %% Extract labels
    % Only the first '#Dx' header line is read. Its comma-separated
    % codes are scanned in order and the first one found in classes
    % becomes the training label of this recording.
    for j = 1 : length(header_data)
        if startsWith(header_data{j},'#Dx')
            tmp = strsplit(header_data{j},': ');
            if numel(tmp) >= 2
                % Extract more than one label if available
                tmp_c = strsplit(tmp{2},',');
                for k = 1:length(tmp_c)
                    idx = find(strcmp(classes,tmp_c{k}),1);
                    if ~isempty(idx)
                        first_label(i) = idx;
                        break
                    end
                end
            end
            break
        end
    end
end
toc

%% select leads

% NOTE(review): categorical() only creates categories for values that
% occur in first_label, while the full classes list is saved with the
% model -- confirm that the consumer of the saved model handles classes
% that never occur as a first label.
YTrain = categorical(first_label);
X_Ts=1; YTest=1; % placeholders handed to the training routine

%% CNN-BILSTM hyperparameters (identical for every lead set)
miniBatchSize = 32;
maxEpochs = 20;
HU=150; dol=0.5;  % bilstmLayer1 parameters
HU2=50; dol2=0.3; % bilstmLayer2 parameters
vf_= 0;  % 1 - validationfrequency on
vf = 2;  % validationfrequency/epoch
vp = 3;  % validation patience
plot_=0; % 1 - plot on

for i=1:length(lead_sets)

    num_leads=length(lead_sets{i});
    % Train ECG model
    disp(['Training ',num2str(num_leads),'-lead ECG model...'])

    % Xnew is reduced in place from one lead set to the next, so the row
    % indices below refer to the rows left by the PREVIOUS iteration and
    % depend on the order of lead_sets defined above.
    switch i
        case 2
            % Keep row 8 aside: it is dropped from the 6-lead set but is
            % needed again for the 4-lead set.
            lead8 = cellfun(@(X)X(8,:),Xnew, 'UniformOutput', false);
            Xnew  = cellfun(@(X)X(lead_sets{i},:),Xnew, 'UniformOutput', false);
        case 3
            % Rows 1-3 of the 6-lead data plus the stored row 8.
            Xnew = cellfun(@(X,Y)[X(1:3,:);Y],Xnew,lead8, 'UniformOutput', false);
            clearvars lead8
        case 4
            % Row 4 of the 4-lead data is the former row 8.
            Xnew = cellfun(@(X)X([1:2,4],:),Xnew, 'UniformOutput', false);
        case 5
            Xnew = cellfun(@(X)X(1:2,:),Xnew, 'UniformOutput', false);
    end

    % gpurng(0, "threefry")
    net=CNN_BILSTM_network_training(Xnew,YTrain,X_Ts,...
        YTest,miniBatchSize,maxEpochs,HU,dol,HU2,dol2,vf_,vf,vp,plot_);

    save_ECGleads_model_mod1(net,output_directory,classes,num_leads,newfs,num_sec,miniBatchSize)

    % Assign the declared output so callers that request a return value
    % do not fail with "Output argument not assigned".
    model = net;
end
end
