%% Train boosted decision tree ensemble
function classifier = fit_boosted_decision_tree_ensemble(X, ...
                                                         Y, ...
                                                         scored_labels, ...
                                                         use_cost, ...
                                                         categorical_var_names, ...
                                                         var_names, ...
                                                         dte_parameters)
% FIT_BOOSTED_DECISION_TREE_ENSEMBLE  Fit one binary tree ensemble per scored label.
%
%   classifier = FIT_BOOSTED_DECISION_TREE_ENSEMBLE(X, Y, scored_labels, ...
%       use_cost, categorical_var_names, var_names, dte_parameters)
%
%   Inputs
%     X                     - predictor data, passed unchanged to fit_to_one_class.
%     Y                     - label matrix; column i holds the binary (0/1)
%                             targets for scored_labels{i}.
%     scored_labels         - cell array of label names, one per used column of Y.
%     use_cost              - if true, misclassifying a positive example costs
%                             (N / #positives) * dte_parameters.cost_amp,
%                             otherwise the symmetric cost [0 1; 1 0] is used.
%                             A label with no positives falls back to the
%                             symmetric cost (with a warning) instead of
%                             producing an infinite cost.
%     categorical_var_names - names of predictors to treat as categorical.
%     var_names             - names of all predictors; assumed to be in the
%                             column order of X (TODO confirm with caller).
%     dte_parameters        - settings struct forwarded to fit_to_one_class;
%                             cost_amp is read here when use_cost is true.
%
%   Output
%     classifier - struct array with one element per label and fields
%                  classificationCosts, branch (the model from
%                  fit_to_one_class), dx (label name) and result (metrics).

    % Logical mask, aligned with var_names, marking categorical predictors.
    isCategoricalPredictor = ismember(var_names, categorical_var_names);

    classifier = struct();
    n_labels = length(scored_labels);

    for i = 1:n_labels
        % Binary target vector for this label.
        y = Y(:, i);
        n_pos = sum(y);
        fprintf('fitting model for dx (%i of %i): %s comprising of %i labels...\n', i, n_labels, scored_labels{i}, n_pos);

        if use_cost && n_pos > 0
            % Penalise missed positives in inverse proportion to their
            % frequency, amplified by cost_amp.
            classificationCosts = [0 1; (length(y) / n_pos) * dte_parameters.cost_amp 0];
        else
            if use_cost
                % length(y)/0 would put Inf into the cost matrix.
                warning('fit_boosted_decision_tree_ensemble:noPositives', ...
                        'No positive examples for %s; using unit misclassification costs.', ...
                        scored_labels{i});
            end
            classificationCosts = [0 1; 1 0];
        end

        [model, result] = fit_to_one_class(X, y, classificationCosts, isCategoricalPredictor, dte_parameters);

        classifier(i).classificationCosts = classificationCosts;
        classifier(i).branch = model;
        classifier(i).dx = scored_labels{i};
        classifier(i).result = result;
    end
end



function [trainedClassifier, result] = fit_to_one_class(X, ...
                                              y, ...
                                              cost, ...
                                              isCategoricalPredictor, ...
                                              dte_parameters)
% FIT_TO_ONE_CLASS  Cross-validate a tree ensemble for a single binary label.
%
%   [trainedClassifier, result] = FIT_TO_ONE_CLASS(X, y, cost, ...
%       isCategoricalPredictor, dte_parameters)
%
%   Inputs
%     X                      - predictor data.
%     y                      - binary target vector (class names 0 and 1).
%     cost                   - 2x2 misclassification cost matrix; entry (i,j)
%                              is the cost of predicting class j when the
%                              true class is i, in class order [0 1].
%     isCategoricalPredictor - logical mask of categorical predictors.
%     dte_parameters         - struct with fields stratified_string,
%                              stratified_numeric, n_splits, method, numBins,
%                              n_learners and learn_rate.
%
%   Outputs
%     trainedClassifier - cross-validated ensemble returned by fitcensemble.
%     result            - struct with fields:
%        cm            - 2x2 confusion matrix [TN FP; FN TP] of the
%                        out-of-fold predictions (always 2x2).
%        prec, rec, f1 - metrics of the out-of-fold labels (NaN for 0/0).
%        prec_rec      - curve returned by prec_rec() plus trapezoidal
%                        auroc and auprc.
%        th_opt        - threshold with the highest F1 on that curve; NaN
%                        when no threshold gives F1 > 0.
%        f1_best       - that highest F1; 0 when none is positive.

    % Stratified partitions keep the class ratio similar in every fold.
    cvp = cvpartition(y, dte_parameters.stratified_string, dte_parameters.stratified_numeric, 'Stratify', true);

    template = templateTree('MaxNumSplits', dte_parameters.n_splits);

    trainedClassifier = fitcensemble( ...
        X, ...
        y, ...
        'CategoricalPredictors', isCategoricalPredictor, ...
        'Method', dte_parameters.method, ...
        'NumBins', dte_parameters.numBins, ...
        'NumLearningCycles', dte_parameters.n_learners, ...
        'Learners', template, ...
        'LearnRate', dte_parameters.learn_rate, ...
        'ClassNames', [0; 1], ...
        'Cost', cost, ...
        'CVPartition', cvp);

    % Out-of-fold predictions; score column 2 belongs to class 1 because
    % ClassNames is [0; 1].
    [pred_labels, pred_scores] = kfoldPredict(trainedClassifier);

    % Fixing the order guarantees a 2x2 matrix even when only one class
    % occurs in y and pred_labels (otherwise cm(2,2) would be out of range).
    result.cm = confusionmat(y, pred_labels, 'Order', [0; 1]);
    tp = result.cm(2,2);
    fp = result.cm(1,2);
    fn = result.cm(2,1);
    result.prec = tp / (tp + fp);
    result.rec = tp / (tp + fn);
    result.f1 = 2 * (result.prec * result.rec) / (result.prec + result.rec);

    [prec, tpr, fpr, thresh] = prec_rec(pred_scores(:,2), y);
    result.prec_rec = struct();
    result.prec_rec.prec = prec;
    result.prec_rec.tpr = tpr;
    result.prec_rec.fpr = fpr;
    result.prec_rec.thresh = thresh;
    result.prec_rec.auroc = trapz(result.prec_rec.fpr, result.prec_rec.tpr);
    result.prec_rec.auprc = trapz(result.prec_rec.tpr, result.prec_rec.prec);

    [th_opt, f1_best] = best_f1_threshold(prec, tpr, thresh);
    result.th_opt = th_opt;
    result.f1_best = f1_best;

    fprintf('Confusion Matrix:\n');
    disp(result.cm);
    fprintf('Precision: %f\n', result.prec);
    fprintf('Recall: %f\n', result.rec);
    fprintf('F1 Score: %f\n', result.f1);
    fprintf('F1 Score best: %f\n', result.f1_best);
    fprintf('AUROC: %f\n', result.prec_rec.auroc);
    fprintf('AUPRC: %f\n', result.prec_rec.auprc);
end


function [th, f1_best] = best_f1_threshold(prec, tpr, thresh)
% BEST_F1_THRESHOLD  Threshold that maximises F1 along a precision/recall curve.
%   Considers the first numel(thresh) points of prec and tpr. Points whose
%   F1 is 0/0 (NaN) are skipped; ties resolve to the earliest point.
%   Returns th = NaN and f1_best = 0 when no point has F1 > 0.
    n = numel(thresh);
    p = reshape(prec(1:n), [], 1);
    r = reshape(tpr(1:n), [], 1);
    f1 = 2 * (p .* r) ./ (p + r);

    % max omits NaN and returns the first index of the maximum.
    [f1_max, idx] = max(f1);
    if ~isempty(f1_max) && f1_max > 0
        th = thresh(idx);
        f1_best = f1_max;
    else
        th = NaN;
        f1_best = 0;
    end
end
