function  model = team_training_code(input_directory,output_directory) % train_ECG_leads_classifier
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Purpose: Train per-class ECG classifiers for the 12-, 6-, 4-, 3- and
% 2-lead feature sets and save the trained models for each lead set.
% Inputs:
% 1. input_directory  : folder holding the recordings (*.mat) and their
%                       header files (*.hea)
% 2. output_directory : folder passed to save_ECGleads_model for every
%                       lead set
%
% Outputs:
% model: struct with the fields model_12, model_6, model_4, model_3 and
%        model_2. Each field is a cell array of size
%        (number of scored classes) x (number of bags) holding one compact
%        AdaBoostM1 tree ensemble per scored class.
%
% Errors:
% team_training_code:noTrainingData is raised when no recording in
% input_directory is selected for training.
%
% Author: Erick Andres Perez Alday, PhD, <perezald@ohsu.edu>
% Version 1.0 Aug-2020
% Revision History
% By: Nadi Sadr, PhD, <nadi.sadr@dbmi.emory.edu>
% Version 2.0 1-Dec-2020
% Version 2.2 25-Jan-2021
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

global dsc_n_catch;
global dsc_n_missing;
global dsc_n_ok;

% Reset the global counters that are printed in the progress line of the
% feature-extraction loop.
% NOTE(review): presumably they are incremented inside the feature
% extraction code - confirm.
dsc_n_catch = 0;
dsc_n_missing = 0;
dsc_n_ok = 0;

% Lead configurations that get their own set of classifiers, and the width
% of the flattened feature vector of each configuration (same order).
lead_sets      = [12, 6, 4, 3, 2];
feature_widths = [895, 835, 754, 715, 491];
num_lead_sets  = numel(lead_sets);

disp('Loading data...')

% Find files: regular, non-hidden files with a .mat extension.
% endsWith is safe for names shorter than the extension, which the former
% fixed-width index expression name(end-2:end) was not.
input_files = {};
for f = dir(input_directory)'
    is_regular_file = exist(fullfile(input_directory, f.name), 'file') == 2;
    if is_regular_file && f.name(1) ~= '.' && endsWith(f.name, '.mat')
        input_files{end + 1} = f.name; %#ok<AGROW>
    end
end

% Extract classes from dataset.
% read number of unique classes
classes = get_classes(input_directory,input_files);
num_classes = length(classes);     % number of classes
num_files   = length(input_files);
Total_data  = cell(1,num_files);
Total_header= cell(1,num_files);

scored_classes = get_scored_classes();
num_scored_classes = length(scored_classes);

true_labels  = zeros(num_files, num_scored_classes);
process_file = false(num_files, 1);

%% Load data recordings and header files
% Iterate over files.
for i = 1:num_files
    disp(['    ', num2str(i), '/', num2str(num_files), '...'])

    % Strip only the extension; record names containing further dots stay
    % intact.
    [~, record_name] = fileparts(input_files{i});
    tmp_input_file = fullfile(input_directory, record_name);

    % Load data.
    [data,hea_data] = load_challenge_data(tmp_input_file);
    Total_data{i}=data;
    Total_header{i}=hea_data;

    [recording_label,~,single_recording_labels]=get_true_labels_gn([tmp_input_file '.hea'], scored_classes);
    true_labels(i,:) = single_recording_labels;

    % Recordings whose label starts with 'I' or 'JS' are excluded from
    % training.
    % For now, do not use Shaoxing and Ningbo (JS) datasets
    % Using them would require us to re-fit all thresholds and we do not
    % have time for that
    process_file(i,:) = ~startsWith(recording_label,'I') && ~startsWith(recording_label,'JS');

end

% Fail early with a clear message instead of an obscure error raised from
% inside the training calls further down.
if ~any(process_file)
    error('team_training_code:noTrainingData', ...
        'No recordings selected for training in "%s".', input_directory);
end

disp('Training model..')

[Ts_relevant_leads, T_relevant_leads_00, lead_names] = create_feature_table();

% Passed through to the feature extraction / flattening helpers.
sw_train = true;

% NOTE(review): 'label' (one-hot over all classes found in the dataset) is
% filled below but is not consumed by the training code in this function.
label=zeros(num_files,num_classes);

% One pre-allocated feature matrix per lead set; rows of recordings that
% are not processed stay zero and are dropped before training.
feature_matrix = cell(1, num_lead_sets);
for s = 1:num_lead_sets
    feature_matrix{s} = zeros(num_files, feature_widths(s));
end

for i = 1:num_files

    disp(['    ', num2str(i), '/', num2str(num_files), '... ' 'catch: ' num2str(dsc_n_catch) ' | ' 'ok: ' num2str(dsc_n_ok) ' | ' ' missing: ' num2str(dsc_n_missing)]);

    header_data = Total_header{i};

    if process_file(i,1)

        data = double(Total_data{i});

        %% Check the number of available ECG leads
        % The second space-separated token of the first header line is
        % read as the lead count. str2double is used instead of str2num
        % because str2num evaluates its input as code.
        tmp_hea = strsplit(header_data{1},' ');
        num_leads = str2double(tmp_hea{2});
        current_lead_names = get_leads(header_data,num_leads);

        %% Extract features
        tmp_features = get_12ECG_features(data, header_data, sw_train, current_lead_names, Ts_relevant_leads, T_relevant_leads_00, lead_names);

        % Flatten the same extracted features once per lead set, always in
        % the order 12, 6, 4, 3, 2.
        flattened = cell(1, num_lead_sets);
        for s = 1:num_lead_sets
            T_relevant_leads = Ts_relevant_leads{lead_sets(s)};
            flattened{s} = flatten_12ECG_features(tmp_features, T_relevant_leads_00, T_relevant_leads, sw_train);
        end

        for s = 1:num_lead_sets
            feature_matrix{s}(i,:) = flattened{s}.X;
        end

    end

    %% Extract labels
    % Only the first '#Dx' header line is used.
    for j = 1 : length(header_data)
        if startsWith(header_data{j},'#Dx')
            tmp = strsplit(header_data{j},': ');
            % Extract more than one label if available
            tmp_c = strsplit(tmp{2},',');
            for k=1:length(tmp_c)
                idx=find(strcmp(classes,tmp_c{k}));
                label(i,idx)=1;
            end
            break
        end
    end

end

%% Train one classifier per scored class for each of the lead sets

% Keep only the rows of the recordings selected for training.
for s = 1:num_lead_sets
    feature_matrix{s} = feature_matrix{s}(process_file, :);
end

% NOTE(review): only the 2-lead features are forced to be real; the other
% lead sets are left untouched exactly as before - confirm whether they
% can become complex as well.
is_two_lead = (lead_sets == 2);
feature_matrix{is_two_lead} = real(feature_matrix{is_two_lead});

scored_labels = true_labels(process_file, :);
n_instances = size(scored_labels, 1);

R = 1.0;   % bootstrap sample size as a fraction of the training set
N = 1;     % number of bags

numTrees = 300;  %100
t = templateTree('MaxNumSplits', 5);   % loop-invariant weak learner template

classifiers = cell(1, num_lead_sets);
for s = 1:num_lead_sets
    classifiers{s} = cell(num_scored_classes, N);
end

for n = 1 : N

    % Bootstrap samples with replacement; the same sample is shared by all
    % classes and lead sets of this bag.
    k = round(n_instances * R);
    train_idx = datasample(1:n_instances, k);

    for label_idx = 1:num_scored_classes

        disp(['bag_' num2str(n) 'lbl_' num2str(label_idx)]);

        labels = scored_labels(:, label_idx);

        for s = 1:num_lead_sets
            % Restrict the predictors to the features selected for this
            % class within the current lead set.
            relevant_feature_idx = features_for_class(Ts_relevant_leads{lead_sets(s)}, scored_classes{label_idx});

            model_i = fitcensemble(feature_matrix{s}(train_idx, relevant_feature_idx), labels(train_idx, :), 'Method', 'AdaBoostM1', 'Learners', t, 'NumLearningCycles', numTrees, 'NPrint', 'off');

            % Store the compact form of the trained ensemble.
            classifiers{s}{label_idx, n} = compact(model_i);
        end

    end

end

%% Save the models and assemble the return value
model = struct();
for s = 1:num_lead_sets
    lead_model = struct();
    lead_model.model = classifiers{s};
    save_ECGleads_model(lead_model,output_directory,scored_classes,lead_sets(s));

    % Fields are created in the order model_12, model_6, model_4,
    % model_3, model_2.
    model.(sprintf('model_%d', lead_sets(s))) = classifiers{s};
end

end
