%% Apply classifier model to test set

function [score, label, classes] = team_testing_code(data, header_data, loaded_model)
% TEAM_TESTING_CODE  Apply the trained per-class classifiers to one recording.
%
% Inputs:
%   data         - signal matrix; it is cast to double before feature extraction.
%   header_data  - cell array of header lines. The second space-separated
%                  token of the first line is read as the number of leads.
%   loaded_model - struct with fields
%                    .model.model : cell array of classifiers indexed as
%                                   {class index, bagging iteration}
%                    .classes     : cell array of class identifiers
%
% Outputs:
%   score   - result of Thresholding() applied to the averaged classifier
%             scores (identical to label, see the note at the end).
%   label   - same thresholded prediction vector as score.
%   classes - loaded_model.classes, passed through unchanged.

classifiers = loaded_model.model.model;
classes     = loaded_model.classes;
num_classes = length(classes);

% Feature extraction runs in inference mode.
sw_train = false;

%% Parse the number of leads from the header
first_line_tokens = strsplit(header_data{1}, ' ');
% str2double is used instead of str2num: str2num evaluates its argument with
% eval, which must not be applied to text read from an input file.
num_leads = str2double(first_line_tokens{2});
if isnan(num_leads) || num_leads < 1 || num_leads ~= fix(num_leads)
    error('team_testing_code:badHeader', ...
          'Could not read a valid number of leads from header token "%s".', ...
          first_line_tokens{2});
end

% leads_idx is not needed by the feature pipeline used below.
[current_lead_names, ~] = get_leads(header_data, num_leads);

%% Extract features
[Ts_relevant_leads, T_relevant_leads_00, lead_names] = create_feature_table();
tmp_features = get_12ECG_features(double(data), header_data, sw_train, ...
                                  current_lead_names, Ts_relevant_leads, ...
                                  T_relevant_leads_00, lead_names);
T_relevant_leads = Ts_relevant_leads{num_leads};
features = flatten_12ECG_features(tmp_features, T_relevant_leads_00, ...
                                  T_relevant_leads, sw_train);

% For the 2-lead configuration only the real part of the features is kept.
if num_leads == 2
    features.X = real(features.X);
end

%% Average the bagged classifier scores for every class
scores = zeros(1, num_classes);
num_bagging_it = size(classifiers, 2);

for classifier_idx = 1:num_classes

    % The feature subset depends only on the class, not on the bagging
    % iteration, so it is computed once per class.
    relevant_feature_idx = features_for_class(T_relevant_leads, classes{classifier_idx});
    X_class = features.X(:, relevant_feature_idx);

    for bagging_it = 1:num_bagging_it

        disp(['classifier ' num2str(classifier_idx) ' it ' num2str(bagging_it)]);

        classifier = classifiers{classifier_idx, bagging_it};
        classifier.ScoreTransform = 'doublelogit';
        [~, class_scores] = predict(classifier, X_class);

        if size(class_scores, 2) >= 2
            % Second score column is accumulated when two columns exist.
            positive_col = 2;
        else
            % This only happens if the training data contained only one class.
            positive_col = 1;
        end
        scores(:, classifier_idx) = scores(:, classifier_idx) + class_scores(:, positive_col);

    end

    scores(:, classifier_idx) = scores(:, classifier_idx) ./ num_bagging_it;

end

%% Threshold the averaged scores with per-class cut-offs
thresholding_method.type = 'Scut';

% One threshold vector per supported lead configuration.
switch num_leads
    case 2
        thresholding_method.param = [0.276199415269194,0.256888251169295,0.0493204702261360,0.0518065484978944,0.207740789230699,0.129102932526532,0.220575027837850,0.313866079890883,0.196728996499269,0.0859948250758200,0.103231967143416,0.0133331733676337,0.135133921450287,0.000236931088875867,0.104510378173709,0.185220515213048,0.0716002757569612,0.131755263465223,0.186234820998711,0.229763950988502,0.304384851323112,0.292531710964416,0.235253824787101,7.28975475395383e-05,0.220870390892993,0.102996959016937,0.000256946588788318];
    case 3
        thresholding_method.param = [0.285858468845498,0.225014268431193,0.0511720799316023,0.0408829420160593,0.250193500180894,0.141621686386450,0.220575027837850,0.313866079890883,0.124231528813909,0.102566151826037,0.103138945348117,0.0137157992734053,0.127038664878463,0.000697521312974724,0.104510378173709,0.185220515213048,0.0780647422640141,0.131755263465223,0.172998058399663,0.227198627117107,0.302578945599381,0.268923414609357,0.217246376332846,0.000142458469041178,0.220870390892993,0.104273927577370,0.000316911507830359];
    case 4
        thresholding_method.param = [0.289407172862409,0.221787850253782,0.0465511116656417,0.0332579034334284,0.238007510407572,0.141951280693994,0.220575027837850,0.313866079890883,0.148031192985411,0.0937010341582766,0.116123047343566,0.0223188261827640,0.135425348746160,0.000389119304866879,0.104510378173709,0.185220515213048,0.0819692215407659,0.131755263465223,0.170984984386401,0.244898851077077,0.310781823357296,0.261410329530518,0.218196171976612,0.000195002342736799,0.220870390892993,0.108100497892967,0.000135346447039841];
    case 6
        thresholding_method.param = [0.281710553767104,0.222802602471845,0.0268672119263873,0.0263849516158324,0.214319639569465,0.140899807742139,0.220575027837850,0.313866079890883,0.164142672611650,0.0822488384276021,0.121435665394513,0.00476529366396215,0.101353296986424,0.000100187588421869,0.104510378173709,0.185220515213048,0.0744875771964261,0.131755263465223,0.178890042023556,0.225991249157861,0.308880524890369,0.263292733601110,0.188755547286423,2.62230090078498e-05,0.220870390892993,0.112123856880781,1.74331918085127e-05];
    case 12
        thresholding_method.param = [0.278047549857996,0.211562086242552,0.0313479043681213,0.0271404394369492,0.237501636974121,0.165963781076721,0.220575027837850,0.313866079890883,0.101467558509587,0.0970698764538945,0.125344792475024,0.00173347411739847,0.107905370355006,0.000137838710565832,0.104510378173709,0.185220515213048,0.0801901052047295,0.131755263465223,0.170716915568546,0.229002819053882,0.314247967736564,0.235964010898182,0.199722663286614,3.88479021681526e-05,0.220870390892993,0.104675362235502,2.64773127663693e-05];
    otherwise
        % No thresholds exist for this lead count. The struct is left
        % without a .param field (as before), but the situation is surfaced
        % instead of passing silently.
        warning('team_testing_code:noThresholds', ...
                'No thresholds are defined for %d leads; Thresholding is called without param.', ...
                num_leads);
end

pred = Thresholding(scores, thresholding_method);

% NOTE(review): both outputs carry the thresholded predictions; the averaged
% continuous values in `scores` are not returned. Kept as-is because callers
% may rely on it - confirm whether `score` was meant to be `scores`.
score = pred;
label = pred;

end
