function model = team_training_code(input_directory,output_directory) % train_ECG_leads_classifier
% TEAM_TRAINING_CODE Train one classifier per lead subset and save each one.
%
%   model = team_training_code(input_directory, output_directory)
%
%   Inputs:
%     input_directory  - folder scanned for '*.mat' recordings; the matching
%                        '.hea' header of every recording is read as well.
%     output_directory - folder handed to save_ECGleads_model for the models.
%
%   Output:
%     model - the model fitted in the last loop iteration (the 2-lead set).
%
%   Depends on helpers that are not defined in this file: get_leads,
%   get_features, parameterCLS, buildFeatureVec and fit_dt_main.

% Only the number of leads in each set is used below.
twelve_leads = {'I', 'II', 'III', 'aVR', 'aVL', 'aVF', 'V1', 'V2', 'V3', 'V4', 'V5', 'V6'};
six_leads    = {'I', 'II', 'III', 'aVR', 'aVL', 'aVF'};
four_leads   = {'I', 'II', 'III', 'V2'};
three_leads  = {'I', 'II', 'V2'};
two_leads    = {'I', 'II'};
lead_sets = {twelve_leads, six_leads, four_leads, three_leads, two_leads};

disp('Loading data...')
addpath(genpath('dev/'))

% Find files: regular, non-hidden files with a '.mat' extension.
% endsWith is safe for names shorter than the extension, unlike indexing
% with end-2:end.
input_files = {};
for f = dir(input_directory)'
    is_regular_file = exist(fullfile(input_directory, f.name), 'file') == 2;
    if is_regular_file && f.name(1) ~= '.' && endsWith(f.name, '.mat')
        input_files{end + 1} = f.name;
    end
end

% Extract classes from dataset.
% Only the scored classes (second column of the mapping file) are considered.
dx_codes_scores = readmatrix('dx_mapping_scored.csv', 'OutputType', 'string');
scored_labels   = dx_codes_scores(:,2);
classes         = get_classes(input_directory,input_files,scored_labels);
%classes         = get_classes(input_directory,input_files);
num_classes     = length(classes);     % number of classes
num_files       = length(input_files);

%% Load data recordings and header files
% Iterate over files.
disp('Training model..')

% One row per recording, one column per class; 1 marks an assigned class.
label=zeros(num_files,num_classes);
has_features = false;   % becomes true once any recording yields features

for i = 1:num_files
    disp(['    ', num2str(i), '/', num2str(num_files), '...'])
    % Load data. fileparts strips only the final extension, matching the
    % '.mat' -> '.hea' substitution used by get_classes.
    [~, record_name] = fileparts(input_files{i});
    tmp_input_file = fullfile(input_directory, record_name);
    [data,header_data] = load_challenge_data(tmp_input_file);

    %% Check the number of available ECG leads
    % The first header line is split on spaces; fields 2, 3 and 4 are read
    % below (presumably lead count, sampling rate and sample count - TODO
    % confirm against the header format).
    record_line = strsplit(header_data{1},' ');
    num_leads = str2double(record_line{2});   % str2double does not eval its input
    [leads, leads_idx] = get_leads(header_data,num_leads);

    % Extract features/labels only for ECGs shorter than 150 s
    duration = str2double(record_line{4})/str2double(record_line{3});
    if duration < 150
        %% Extract features
        [tmp_features,lea] = get_features(data,header_data,leads_idx,leads);
        features(i,:) = tmp_features;
        has_features = true;

        %% Extract labels
        for j = 1 : length(header_data)
            if startsWith(header_data{j},'#Dx')
                tmp = strsplit(header_data{j},': ');
                % Extract more than one label if available
                tmp_c = strsplit(tmp{2},',');
                for k=1:length(tmp_c)
                    idx=find(strcmp(classes,tmp_c{k}));
                    label(i,idx)=1;
                end
                break
            end
        end
    end
end

% Without at least one usable recording neither 'features' nor 'lea' exists.
if ~has_features
    error('team_training_code:noUsableRecords', ...
        'No recording shorter than 150 s was found in %s.', input_directory);
end

% Skipped recordings leave empty cells in their feature row (turned into NaN
% further down). When the skipped recordings are the last ones, the row was
% never created, so pad to keep 'features' and 'label' the same height.
if size(features,1) < num_files
    features(num_files,:) = {[]};
end

% merge jointly assessed classes
classes_joined = {'164909002','733534002';...
    '59118001','713427006';...
    '63593006','284470004';...
    '17338001','427172004'};

for i = 1:size(classes_joined,1)
    mask1 = ismember(classes,classes_joined{i,1});
    mask2 = ismember(classes,classes_joined{i,2});

    % Merge only when both codes of the pair are present.
    if or(sum(mask1)==0,sum(mask2)==0)
        continue
    end

    % Fold the second code into the first, then drop the second.
    label(:,mask1) = or(label(:,mask1),label(:,mask2));
    label(:,mask2) = [];
    classes(:,mask2) = [];
end


%% get feature names
% parameterCLS takes no arguments; presumably a script that defines the
% features_* variables used on the next line - TODO confirm.
parameterCLS;
[feat_names, varnames_global] = buildFeatureVec(features_numerical, features_numerical_global, features_categorical, features_categorical_global, leads);
feat_names = [feat_names, varnames_global];

%% save feature vector and labels
% empties = cellfun('isempty',features);
% features(empties) = {NaN};
% features = cell2mat(features);
% str_leads = string(lea);
% save([output_directory+"\trainingData.mat"],'features','classes','label','str_leads','-v7.3')


%% eliminate labels that only occur once (only relevant for small subsets)
idx_single = find(sum(label,1)==1);
classes(idx_single) = [];
label(:,idx_single) = [];


%% train 5 models for 5 different sets of leads (2, 3, 4, 6, 12-lead ECG)
for i = 1:length(lead_sets)
    % Train ECG model
    num_leads = length(lead_sets{i});
    disp(['Training ',num2str(num_leads),'-lead ECG model...'])
    % NOTE(review): header_data is the header of the last recording loaded above.
    [leads, leads_idx] = get_leads(header_data,num_leads);
    % Features = [1:12] features from [2,3,4,6,12] ECG leads + Age + TotalTime + NumWindows + Sex
    % Keep columns whose 'lea' entry is one of the selected lead indices or "0".
    Features_leads_idx = ismember(string(lea),string(leads_idx)) | ismember(string(lea),"0");
    Features_leads = features(:,Features_leads_idx);
    empties = cellfun('isempty',Features_leads);
    Features_leads(empties) = {NaN};
    Feat_names_leads = feat_names(Features_leads_idx);
    %leads_temp = lea(Features_leads_idx);
    % model = mnrfit(cell2mat(Features_leads),label,'model','hierarchical'); %%%%%%%%%%%%
    %model = fit_final_nn(cell2mat(Features_leads), label, classes, num_leads, leads_temp);
    model = fit_dt_main(cell2mat(Features_leads), label, classes, Feat_names_leads);
    save_ECGleads_model(model,output_directory,classes,num_leads);
    % 'model' is deliberately not cleared: it is overwritten on the next
    % iteration and must still exist on return, since it is the output.
end
end

function save_ECGleads_model(model,output_directory,classes,num_leads) %save_ECG_model
% SAVE_ECGLEADS_MODEL Write a trained model and its class list to a MAT-file.
%   The file is named '<num_leads>_lead_ecg_model.mat', is placed inside
%   output_directory, and holds the variables 'model' and 'classes'
%   (saved with the '-v7.3' option). Prints 'Done.' afterwards.
target = fullfile(output_directory, [num2str(num_leads),'_lead_ecg_model.mat']);
save(target,'model','classes','-v7.3');

disp('Done.')
end


function save_ECGleads_features(features,output_directory) %save_ECG_model
% SAVE_ECGLEADS_FEATURES Store the variable 'features' in 'features.mat'
%   inside output_directory.
save(fullfile(output_directory,'features.mat'),'features');
end

% Collect the sorted list of unique, scored diagnosis codes found in the header files
function classes = get_classes(input_directory,files,scored_labels)
% GET_CLASSES Collect the unique scored diagnosis codes found in the headers.
%
%   classes = get_classes(input_directory, files, scored_labels)
%
%   Inputs:
%     input_directory - folder that contains the header files.
%     files           - cell array of '.mat' file names; the header of each is
%                       looked up by replacing '.mat' with '.hea'.
%     scored_labels   - codes to keep; any code not listed here is ignored.
%
%   Output:
%     classes - sorted row cell array of the distinct codes that appear on
%               the first '#Dx' line of any header and are in scored_labels.
%
%   Headers without a '#Dx' line contribute nothing. An unreadable header
%   raises 'get_classes:cannotOpenHeader'.

classes = {};
for i = 1:length(files)
    header_path = fullfile(input_directory, strrep(files{i},'.mat','.hea'));
    fid = fopen(header_path);
    if fid < 0
        error('get_classes:cannotOpenHeader', ...
            'Cannot open header file %s', header_path);
    end

    % Scan for the first '#Dx' line. fgetl returns -1 (not text) at end of
    % file, so ischar must be tested before startsWith is applied.
    dx_line = '';
    tline = fgetl(fid);
    while ischar(tline)
        if startsWith(tline,'#Dx')
            dx_line = tline;
            break
        end
        tline = fgetl(fid);
    end
    fclose(fid);

    if isempty(dx_line)
        continue
    end

    % Expected form: '#Dx: code1,code2,...'
    tmp = strsplit(dx_line,': ');
    if numel(tmp) < 2
        continue
    end
    codes = strsplit(tmp{2},',');
    for j = 1:length(codes)
        code = codes{j};
        is_scored = any(strcmp(scored_labels(:),code));
        is_known  = any(strcmp(classes,code));
        if is_scored && ~is_known
            classes{end+1} = code;
        end
    end
end
classes = sort(classes);
end

% Earlier variant of get_classes without the scored-label filter (disabled, kept for reference)
% function classes = get_classes(input_directory,files)
% 
% classes={};
% num_files = length(files);
% k=1;
% for i = 1:num_files
%     g = strrep(files{i},'.mat','.hea');
%     input_file = fullfile(input_directory, g);
%     fid=fopen(input_file);
%     tline = fgetl(fid);
%     tlines = cell(0,1);
% 
%     while ischar(tline)
%         tlines{end+1,1} = tline;
%         tline = fgetl(fid);
%         if startsWith(tline,'#Dx')
%             tmp = strsplit(tline,': ');
%             tmp_c = strsplit(tmp{2},',');
%             for j=1:length(tmp_c)
%                 idx2 = find(strcmp(classes,tmp_c{j}));
%                 if isempty(idx2)
%                     classes{k}=tmp_c{j};
%                     k=k+1;
%                 end
%             end
%             break
%         end
%     end
% 
%     fclose(fid);
% 
% end
% classes=sort(classes);
% end

function [data,tlines] = load_challenge_data(filename)
% LOAD_CHALLENGE_DATA Read one recording: its header lines and its signal.
%
%   [data, tlines] = load_challenge_data(filename)
%
%   Input:
%     filename - path of the recording without extension; '.hea' and '.mat'
%                are appended to locate the two files.
%
%   Outputs:
%     data   - contents of the variable 'val' stored in the MAT-file.
%     tlines - column cell array holding every line of the header file.
%
%   Raises 'load_challenge_data:cannotOpenHeader' when the header cannot be
%   opened and 'load_challenge_data:missingVal' when the MAT-file holds no
%   variable named 'val'.

% Opening header file
header_path = [filename '.hea'];
fid = fopen(header_path);
if fid < 0
    % Stop here: continuing would call fgetl on an invalid file identifier.
    error('load_challenge_data:cannotOpenHeader', ...
        'error in opening file %s', header_path);
end

% Read the header line by line; fgetl returns -1 (non-char) at end of file.
tlines = cell(0,1);
tline = fgetl(fid);
while ischar(tline)
    tlines{end+1,1} = tline;
    tline = fgetl(fid);
end
fclose(fid);

f = load([filename '.mat']);
if ~isfield(f,'val')
    error('load_challenge_data:missingVal', ...
        'File %s.mat does not contain a variable named ''val''.', filename);
end
data = f.val;

end
