%% Reduce size of featureset
% load("D:\GitLab\cinc2021a\dev\featureSet2020\trainingData2020.mat");
% features = cell2mat(features);
% save("D:\GitLab\cinc2021a\dev\featureSet2020\trainingData2020_small.mat");

% load("D:\GitLab\cinc2021a\dev\featureSet2021\trainingData2021.mat");
% features = cell2mat(features);
% save("D:\GitLab\cinc2021a\dev\featureSet2021\trainingData2021_small.mat");

%% Load the featureset
% Select the .mat file for the current machine / challenge year; the
% alternatives are kept commented out for quick switching.
% path = "D:\GitLab\cinc2021a\dev\featureSet2020\trainingData2020_small.mat";
% path = "/Users/matthieuscherpf/Desktop/cinc2021a/dev/featureSet2020/trainingData2020_small.mat";
path = "D:\GitLab\cinc2021a\dev\featureSet2021\trainingData2021_small.mat";
% path = "/Users/matthieuscherpf/Desktop/cinc2021a/dev/featureSet2021/trainingData2021_small.mat";
% Loads every variable stored in the file into the workspace. The script below
% reads features, label, classes and feat_names without defining them, so they
% presumably come from this file -- verify against the file contents.
load(path);


%% Preprocess the featureset
% A sample is dropped when every entry of its feature vector is NaN; the
% corresponding row of label is removed as well so both stay aligned.
all_nan_rows = all(isnan(features), 2);
features(all_nan_rows, :) = [];
label(all_nan_rows, :) = [];


%% Set vars
% Names of the features that are passed to the trainer as categorical.
categorical_var_names = {'Sex'};


%% Perform grid search
% Placeholder: no grid search is run yet -- a single configuration is trained below.


%% Set hyperparameters
% Flag forwarded to fit_boosted_decision_tree_ensemble: apply a custom
% misclassification cost or not.
use_cost = false;
% Configuration of the boosted decision tree ensemble.
% Holdout alternative: stratified_string = 'Holdout', stratified_numeric = 0.3
dte_parameters = struct( ...
    'stratified_string',  'KFold', ...      'Holdout' or 'KFold'
    'stratified_numeric', 3, ...            e.g.: 0.3 or 4
    'method',             'LogitBoost', ...
    'n_splits',           2, ...
    'n_learners',         49, ...
    'learn_rate',         0.1, ...
    'numBins',            50);


%% Train classifier
% Trains the per-class ensemble(s); the result is a struct array with one
% element per entry of classes (see the prediction section below).
classifier = fit_boosted_decision_tree_ensemble(features, label, classes, ...
    use_cost, categorical_var_names, feat_names, dte_parameters);


%% Predict classifier
% Collect the cross-validated predictions of every per-class model.
%   preds     : N x C predicted labels (one column per class)
%   preds_opt : N x C labels obtained by thresholding score column 2 with
%               classifier(i).result.th_opt
%   scores    : N x C x 2 scores returned by kfoldPredict
n_samples = size(features, 1);
n_classes = length(classes);
preds = zeros(n_samples, n_classes);
preds_opt = zeros(n_samples, n_classes);
scores = zeros(n_samples, n_classes, 2);

for i = 1:n_classes
    % Sanity check: classifier(i) must belong to classes{i}. Compare as whole
    % strings -- an element-wise `~=` on char vectors only triggers the `if`
    % when *every* character differs (so same-length mismatches slipped
    % through) and throws a size error for codes of different length.
    if ~isequal(string(classes{i}), string(classifier(i).dx))
        error('Error in trained classifier(i).dx and related scored_labels(i)');
    end

    % Out-of-fold labels and scores of the cross-validated model.
    [y, sc] = kfoldPredict(classifier(i).branch);

    preds(:, i) = y;
    scores(:, i, :) = sc;
    preds_opt(:, i) = sc(:, 2) > classifier(i).result.th_opt;

    % Store only this class's results, matching targets = label(:, i).
    % NOTE(review): previously the full (partially filled) N x C matrices were
    % copied into every classifier(i); the complete matrices are still saved
    % separately as preds / preds_opt / scores.
    classifier(i).kfoldpredict = struct();
    classifier(i).kfoldpredict.preds = preds(:, i);
    classifier(i).kfoldpredict.preds_opt = preds_opt(:, i);
    classifier(i).kfoldpredict.scores = sc;
    classifier(i).kfoldpredict.targets = label(:, i);
end


%% Save classifier results and parameters
% Output directory for the search results (pick the line matching the machine).
base = "D:\GitLab\cinc2021a\dev\cls_2021\search\";
% base = "/Users/matthieuscherpf/Desktop/cinc2021a/dev/cls_2021/search/";
% The hyperparameters are encoded in the file name, e.g.
%   dte_ns_2_nl_49_lr_0.1_nb_50_m_LogitBoost_ss_KFold_sn_3.mat
% base is passed as a %s argument (never as part of the format) so its
% backslashes are not interpreted as escape sequences.
path = sprintf("%sdte_ns_%s_nl_%s_lr_%s_nb_%s_m_%s_ss_%s_sn_%s.mat", ...
    base, ...
    num2str(dte_parameters.n_splits), ...
    num2str(dte_parameters.n_learners), ...
    num2str(dte_parameters.learn_rate), ...
    num2str(dte_parameters.numBins), ...
    dte_parameters.method, ...
    dte_parameters.stratified_string, ...
    num2str(dte_parameters.stratified_numeric));

% Remove the cross-validated model objects (field 'branch') before saving;
% only the results and predictions are written to disk.
classifier = rmfield(classifier, 'branch');
save(path, 'classifier', 'preds', 'preds_opt', 'scores', 'label', 'classes');


%% Plot AUROC and AUPRC
% One row per classifier, columns: AUROC, AUPRC, F1, best F1 (from .result).
% Each value is cast to double so vals is a double matrix as before.
collect = @(getter) arrayfun(@(c) double(getter(c.result)), classifier(:));
vals = [collect(@(r) r.prec_rec.auroc), ...
        collect(@(r) r.prec_rec.auprc), ...
        collect(@(r) r.f1), ...
        collect(@(r) r.f1_best)];
figure();
bar(vals);
legend('auroc', 'auprc', 'f1', 'f1\_best');


%% compute challenge metric
% Evaluate the results file whose location is stored in path.
evaluate_model(path)
