function  model = train_12ECG_classifier(input_directory,output_directory)
%TRAIN_12ECG_CLASSIFIER Train one bagged-tree classifier per scored diagnosis.
%   MODEL = TRAIN_12ECG_CLASSIFIER(INPUT_DIRECTORY, OUTPUT_DIRECTORY)
%   1. collects every regular, non-hidden *.mat file in INPUT_DIRECTORY,
%   2. loads each recording and its .hea header (LOAD_CHALLENGE_DATA),
%   3. computes one feature row per recording with GET_12ECG_FEATURES,
%   4. maps the header's diagnosis codes (from EXTRACT_HEADER_DATA) onto the
%      27 scored SNOMED CT codes listed below plus one "other" class,
%   5. trains a compact TreeBagger for each of the 27 scored classes, and
%   6. saves the models together with every code seen in the headers
%      (GET_CLASSES) to OUTPUT_DIRECTORY via SAVE_12_ECG_MODEL.
%
%   MODEL is a 1-by-27 cell array; MODEL{k} is the classifier for snomed{k}.
%   Errors if INPUT_DIRECTORY contains no .mat recordings.

disp('Loading data...')

% Find recording files: regular files (exist == 2), not hidden, ending in .mat.
input_files = {};
for f = dir(input_directory)'
    if exist(fullfile(input_directory, f.name), 'file') == 2 && f.name(1) ~= '.' ...
            && endsWith(f.name, '.mat')
        input_files{end + 1} = f.name; %#ok<AGROW>
    end
end

num_files = length(input_files);
if num_files == 0
    error('train_12ECG_classifier:noData', ...
        'No .mat recordings found in %s', input_directory);
end

% All diagnosis codes present in the headers; stored alongside the model.
classes = get_classes(input_directory,input_files);

% The 27 scored SNOMED CT diagnosis codes, in model order. Any code not in
% this list sets the extra "other" (junk drawer) class at index num_classes.
snomed = { ...
    '270492004', ... % 1st degree AV block
    '164889003', ... % atrial fibrillation
    '164890007', ... % atrial flutter
    '426627000', ... % bradycardia
    '713427006', ... % complete right bundle branch block
    '713426002', ... % incomplete right bundle branch block
    '445118002', ... % left anterior fascicular block
    '39732003',  ... % left axis deviation
    '164909002', ... % left bundle branch block
    '251146004', ... % low qrs voltages
    '698252002', ... % nonspecific intraventricular conduction disorder
    '10370003',  ... % pacing rhythm
    '284470004', ... % premature atrial contraction
    '427172004', ... % premature ventricular contractions
    '164947007', ... % prolonged pr interval
    '111975006', ... % prolonged qt interval
    '164917005', ... % q wave abnormal
    '47665007',  ... % right axis deviation
    '59118001',  ... % right bundle branch block
    '427393009', ... % sinus arrhythmia
    '426177001', ... % sinus bradycardia
    '426783006', ... % sinus rhythm
    '427084000', ... % sinus tachycardia
    '63593006',  ... % supraventricular premature beats
    '164934002', ... % t wave abnormal
    '59931005',  ... % t wave inversion
    '17338001',  ... % ventricular premature beats (VPB)
    };
num_scored = numel(snomed);
num_classes = num_scored + 1; % 27 graded classes and 1 junk drawer

features = [];
labels = zeros(num_files, num_classes);

disp('Training model..')

for i = 1:num_files
    disp(['    ', num2str(i), '/', num2str(num_files), '...'])

    % Strip only the extension so record names containing dots still resolve.
    [~, record_name] = fileparts(input_files{i});
    tmp_input_file = fullfile(input_directory, record_name);
    [data, header_data] = load_challenge_data(tmp_input_file);

    % Make sure it is oriented the right way for processing (leads as rows).
    if size(data,1) > 12
        data = data';
    end

    tmp_features = get_12ECG_features(data, header_data);
    if i == 1
        % Size the feature matrix from the extractor's actual output instead
        % of a hard-coded width that silently goes stale.
        features = zeros(num_files, numel(tmp_features));
    end
    features(i,:) = tmp_features;

    [~,~,~,~,~,~,~,all_labels] = extract_header_data(header_data);
    labels(i,:) = codes_to_classes(split(all_labels, ','), snomed);
end

M = size(features, 2);
% Predictors sampled per split; capped at M because 16*sqrt(M) exceeds M
% whenever M < 256.
nfeat = min(M, round(16*sqrt(M)));
ntrees = 150; % Increase this number for actual testing. We've got time

% Only train the 27 scored classes; the junk drawer needs no model.
model = cell(1, num_scored);
parfor (i = 1:num_scored, 8)
    model{i} = compact(TreeBagger(ntrees, features, labels(:,i), ...
        'InBagFraction', 1, 'Method', 'classification', ...
        'NumPredictorsToSample', nfeat));
end

save_12_ECG_model(model,output_directory,classes);

end

function class_vec = codes_to_classes(codes, snomed)
% CODES_TO_CLASSES Encode diagnosis codes as a 0/1 row over snomed + "other".
%   Entry k of the 1-by-(numel(snomed)+1) result is 1 when snomed{k} occurs
%   in CODES; the last entry is 1 when any code is not in SNOMED.
%   CODES may be a char vector, cell array of char, or string array;
%   surrounding whitespace on each code is ignored.
class_vec = zeros(1, numel(snomed) + 1);
codes = strtrim(cellstr(codes));
[is_scored, loc] = ismember(codes, snomed);
class_vec(loc(is_scored)) = 1;
if any(~is_scored)
    class_vec(end) = 1;
end
end

function save_12_ECG_model(model,output_directory,classes)
% SAVE_12_ECG_MODEL Write the trained models and the class list to disk.
%   SAVE_12_ECG_MODEL(MODEL, OUTPUT_DIRECTORY, CLASSES) stores the variables
%   MODEL and CLASSES in OUTPUT_DIRECTORY/finalized_model.mat (MAT-file
%   version 7.3). OUTPUT_DIRECTORY is created first if it does not exist.
%   Prints 'Done.' when finished.

% save() does not create missing folders, so make sure the target exists.
if ~isempty(output_directory) && ~exist(output_directory, 'dir')
    mkdir(output_directory);
end

filename = fullfile(output_directory, 'finalized_model.mat');
save(filename, 'model', 'classes', '-v7.3');

disp('Done.')
end


% Find the unique diagnosis codes listed across all header files.
function classes = get_classes(input_directory,files)
% GET_CLASSES Collect the distinct diagnosis codes listed in the headers.
%   CLASSES = GET_CLASSES(INPUT_DIRECTORY, FILES) opens, for each name in the
%   cell array FILES, the matching header (".mat" replaced by ".hea") in
%   INPUT_DIRECTORY and reads its first line starting with '#Dx'. The text
%   after the first ':' is split on commas, and each code is trimmed.
%   Returns the sorted, de-duplicated codes as a row cell array of char.
%   Headers without a '#Dx' line contribute nothing.
%   Errors if a header file cannot be opened.

codes = {};
for i = 1:numel(files)
    header_file = fullfile(input_directory, strrep(files{i}, '.mat', '.hea'));
    fid = fopen(header_file);
    if fid < 0
        error('get_classes:openFailed', 'error in opening file %s', header_file);
    end

    % Check ischar before startsWith: fgetl returns -1 at end of file, and
    % startsWith(-1, ...) throws, which used to crash on headers with no #Dx.
    tline = fgetl(fid);
    while ischar(tline)
        if startsWith(tline, '#Dx')
            dx_text = strtrim(extractAfter(tline, ':'));
            codes = [codes, strtrim(strsplit(dx_text, ','))]; %#ok<AGROW>
            break
        end
        tline = fgetl(fid);
    end

    fclose(fid);
end

if isempty(codes)
    classes = {};
else
    classes = unique(codes); % unique() also sorts
end
end

function [data,tlines] = load_challenge_data(filename)
% LOAD_CHALLENGE_DATA Read one recording's header lines and signal matrix.
%   [DATA, TLINES] = LOAD_CHALLENGE_DATA(FILENAME) reads FILENAME.hea line by
%   line into the column cell array TLINES. DATA is the variable VAL stored
%   in FILENAME.mat. FILENAME is given without an extension.
%   Errors if the header cannot be opened or the .mat file has no VAL.

fid = fopen([filename '.hea']);
if fid < 0
    % Fail here with a clear message; continuing would only error later in
    % fgetl with an unhelpful "invalid file identifier".
    error('load_challenge_data:openFailed', 'error in opening file %s', filename);
end
closer = onCleanup(@() fclose(fid)); %#ok<NASGU> closes even if reading fails

tlines = cell(0,1);
tline = fgetl(fid);
while ischar(tline)
    tlines{end+1,1} = tline; %#ok<AGROW>
    tline = fgetl(fid);
end

f = load([filename '.mat']);
data = f.val;
end