clearvars

% Lookup tables, read as string arrays:
%   dx_codes         - Dx_map.csv: plain-text diagnosis name (col 1) per code (col 2)
%   dx_codes_scores  - dx_mapping_scored.csv: the scored diagnoses (name col 1, code col 2)
%   dx_codes_weights - weights.csv: weights for the scored diagnoses
dx_codes         = readmatrix('Dx_map.csv',            'OutputType', 'string');
dx_codes_scores  = readmatrix('dx_mapping_scored.csv', 'OutputType', 'string');
dx_codes_weights = readmatrix('weights.csv',           'OutputType', 'string');

% The records are many small separate files; unzip them into a local
% directory first so that reading them into memory is fast.
source_dir = fullfile('D:', 'tmp', 'preprocessed');

%% read the preprocessed files into memory
% Every subfolder of source_dir is treated as one database; every file in it
% is loaded with load() into features.<database>.<file name up to first '.'>.
if ~isfolder(source_dir)
    error('Source directory "%s" does not exist.', source_dir);
end
db_folders = dir(source_dir);
% keep real subfolders only: '.' and '..' are not guaranteed to be the first
% two entries of dir(), and loose files in source_dir are not databases
db_folders = db_folders([db_folders.isdir] & ~ismember({db_folders.name}, {'.', '..'}));
features = struct();
for i=1:length(db_folders)
    data_files = dir(fullfile(db_folders(i).folder, db_folders(i).name));
    data_files = data_files(~[data_files.isdir]);
    % names become struct field names, so map them to valid identifiers
    % ('-' -> '_', names starting with a digit get a prefix)
    dbn = matlab.lang.makeValidName(db_folders(i).name);
    for j=1:length(data_files)
        fn = split(data_files(j).name, '.');
        rec_field = matlab.lang.makeValidName(fn{1});
        features.(dbn).(rec_field) = load(fullfile(data_files(j).folder, data_files(j).name));
    end
end

%% misc
% Per-record statistics. Column k of labels/leads/lengths/demographics always
% refers to the same record, so they can be indexed with the same masks.
leads = {};          % per record: lead field names
labels = {};         % per record: cellstr of diagnosis codes
lengths = {};        % per record: TotalTime (NaN if no lead provides it)
demographics = {};   % row 1: Age, row 2: Sex (NaN if unavailable)
erroneous = {};      % row 1: record name, row 2: caught MException

%% browse through the records and calculate the desired statistics
dbs = fieldnames(features);
for i=1:length(dbs) % iterate over databases
    dbn = dbs{i};
    recs = fieldnames(features.(dbn));
    for j=1:length(recs) % iterate over records in the database
        recn = recs{j};
        try
            output = features.(dbn).(recn).output;
            % labels: comma separated diagnosis codes
            rec_labels = cellstr(split(output.labels, ','));
            % leads: every field of output except 'labels'
            rec_leads = fieldnames(output);
            rec_leads = rec_leads(~strcmp(rec_leads, 'labels'));
            % record length and demographics come from the first lead that
            % provides all three values
            rec_length = NaN;
            rec_age = NaN;
            rec_sex = NaN;
            for k=1:length(rec_leads)
                lead = output.(rec_leads{k});
                if isstruct(lead) && isscalar(lead) ...
                        && all(isfield(lead, {'TotalTime', 'Age', 'Sex'}))
                    rec_length = lead.TotalTime;
                    rec_age = lead.Age;
                    rec_sex = lead.Sex;
                    break;
                end
            end
            % empty entries would silently shrink later cell2mat() results
            % and break the alignment between the containers
            if isempty(rec_length), rec_length = NaN; end
            if isempty(rec_age), rec_age = NaN; end
            if isempty(rec_sex), rec_sex = NaN; end
        catch ex
            erroneous(1, end+1) = {recn};
            erroneous(2, end) = {ex};
            continue;
        end
        % append only once everything was read so all containers stay aligned
        labels(end+1) = {rec_labels};
        leads(end+1) = {rec_leads};
        lengths(end+1) = {rec_length};
        demographics(1, end+1) = {rec_age};
        demographics(2, end) = {rec_sex};
    end
end

%% evaluate the collected stats
% unique_labels columns: 1 code, 2 plain-text name from Dx_map.csv (the code
% itself if it is not listed), 3 count in female records (Sex == 1),
% 4 count in male records (Sex == 0), 5 count in all records.
all_labels = vertcat(labels{:});
unique_labels = unique(all_labels);
sex = cell2mat(demographics(2,:));
f = sex == 1;
m = sex == 0;
% loop-invariant: gather the per-sex label lists once
female_labels = vertcat(labels{f});
male_labels = vertcat(labels{m});
for i=1:size(unique_labels, 1)
    lbl = unique_labels{i, 1};
    idx = find(strcmp(dx_codes(:,2), lbl), 1);
    if isempty(idx)
        unique_labels{i, 2} = lbl;
    else
        unique_labels{i, 2} = dx_codes{idx, 1};
    end
    unique_labels{i, 3} = sum(strcmp(lbl, female_labels));
    unique_labels{i, 4} = sum(strcmp(lbl, male_labels));
    unique_labels{i, 5} = sum(strcmp(lbl, all_labels));
end

% unique_leads columns: 1 lead name, 2 number of records containing it
all_leads = vertcat(leads{:});
unique_leads = unique(all_leads);
for i=1:size(unique_leads, 1)
    unique_leads{i, 2} = sum(strcmp(unique_leads{i, 1}, all_leads));
end

%% barplot for the different labels split by gender
% Stacked bar chart of how often each diagnosis occurs in female and male
% records; tick labels of scored diagnoses are drawn in red.
figure

minimum_number_of_dx = 0; % only show diagnoses occurring more often than this

% most frequent diagnosis first; sort the numeric totals explicitly because
% sortrows is only documented for cell arrays of character vectors
[~, order] = sort(cell2mat(unique_labels(:, 5)), 'descend');
tmp = unique_labels(order, :);
tmp = tmp(cell2mat(tmp(:, 5)) > minimum_number_of_dx, :);

x = categorical(tmp(:, 2));
x = reordercats(x, tmp(:, 2));
y = [cell2mat(tmp(:, 3)), cell2mat(tmp(:, 4))];
b = bar(x, y, 'stacked');
% print each segment's count on top of that segment
for s = 1:numel(b)
    text(b(s).XEndPoints, b(s).YEndPoints, string(b(s).YData), ...
        'HorizontalAlignment', 'center', 'VerticalAlignment', 'bottom')
end
title('Gender specific number of labels')
legend('women', 'men');
grid on

% change label color of scored diagnoses (TeX markup in the tick labels)
x_tick_labels = get(gca, 'XTickLabel');
x_tick_labels_new = cell(size(x_tick_labels)); % one entry per tick, not N-by-N
for i=1:numel(x_tick_labels)
    if any(strcmp(x_tick_labels{i}, dx_codes_scores(:, 1)))
        x_tick_labels_new{i} = ['\color{red} ' x_tick_labels{i}];
    else
        x_tick_labels_new{i} = x_tick_labels{i};
    end
end
set(gca, 'XTickLabel', x_tick_labels_new);

%% histogram of the recording lengths
% record durations (TotalTime) in 2000 bins on a logarithmic x axis
figure
record_lengths = cell2mat(lengths);
histogram(record_lengths, 2000);
title('Number of records with specified length in seconds')
set(gca, 'XScale', 'log')
grid on

%% histogram of the age distribution
% overlaid age histograms of female (Sex == 1) and male (Sex == 0) records
figure
ages = demographics(1, :);
histogram(cell2mat(ages(f)));
hold on;
histogram(cell2mat(ages(m)));
title('Gender specific age distribution')
legend('women', 'men');
grid on

%% histogram of the number of labels per recording
% one value per collected record: how many diagnosis codes it carries
figure
number_of_labels_per_rec = cellfun(@length, labels, 'UniformOutput', false);
histogram(cell2mat(number_of_labels_per_rec));
title('Number of labels per record')

%% plot weighting matrix
% Heatmap of the scoring weights: the first row/column of weights.csv hold the
% diagnosis codes, the remaining cells the weights. Axis labels are
% "<code> <name>" from dx_mapping_scored.csv, or the bare code if not listed.
weights = str2double(cellstr(dx_codes_weights(2:end, 2:end)));
x_values = cell(size(dx_codes_weights, 2) - 1, 1);
for i=1:length(x_values)
    code = dx_codes_weights{1, i+1};
    idx = find(code == dx_codes_scores(:, 2), 1);
    if isempty(idx)
        x_values{i} = code;
    else
        x_values{i} = [dx_codes_scores{idx, 2} ' ' dx_codes_scores{idx, 1}];
    end
end
y_values = cell(size(dx_codes_weights, 1) - 1, 1);
for i=1:length(y_values)
    code = dx_codes_weights{i+1, 1};
    idx = find(code == dx_codes_scores(:, 2), 1);
    if isempty(idx)
        y_values{i} = code;
    else
        y_values{i} = [dx_codes_scores{idx, 2} ' ' dx_codes_scores{idx, 1}];
    end
end

% new figure: heatmap would otherwise replace the previous plot
figure
h = heatmap(x_values, y_values, weights);
h.XLabel = 'Target Class';
h.YLabel = 'Predicted Class';

%% plot weighting matrix multiplied by the number of label occurrences
% Entry (r, c) is weights(r, c) * (count(r) + count(c)) off the diagonal and
% weights(c, c) * count(c) on it, where count is the total number of
% occurrences of the class in the data (0 if it never occurs). Assumes the
% header row and header column of weights.csv list the same classes in the
% same order.
weights = str2double(cellstr(dx_codes_weights(2:end, 2:end)));
n_classes = size(weights, 2);
known_labels = string(unique_labels(:, 1));
counts = zeros(n_classes, 1);
x_values = cell(n_classes, 1);
for i=1:n_classes
    code = dx_codes_weights{1, i+1};
    idx = find(code == dx_codes_scores(:, 2), 1);
    if isempty(idx)
        x_values{i} = code;
    else
        x_values{i} = [dx_codes_scores{idx, 2} ' ' dx_codes_scores{idx, 1}];
    end
    idx = find(code == known_labels, 1);
    if ~isempty(idx)
        counts(i) = unique_labels{idx, 5};
    end
end
y_values = cell(size(dx_codes_weights, 1) - 1, 1);
for i=1:length(y_values)
    code = dx_codes_weights{i+1, 1};
    idx = find(code == dx_codes_scores(:, 2), 1);
    if isempty(idx)
        y_values{i} = code;
    else
        y_values{i} = [dx_codes_scores{idx, 2} ' ' dx_codes_scores{idx, 1}];
    end
end

% closed form of the former per-class accumulation (no N^3 scratch array)
new_weights = weights .* (counts + counts');
new_weights(logical(eye(n_classes))) = diag(weights) .* counts;

% new figure: heatmap would otherwise replace the previous plot
figure
h = heatmap(x_values, y_values, new_weights);
h.XLabel = 'Target Class';
h.YLabel = 'Predicted Class';

%% plot normalized weighting matrix multiplied by the number of label occurrences
% Same matrix as the un-normalized plot: weights(r, c) * (count(r) + count(c))
% off the diagonal, weights(c, c) * count(c) on it. Only the lower triangle is
% kept and it is scaled so its largest entry is 1. Assumes the header row and
% header column of weights.csv list the same classes in the same order.
weights = str2double(cellstr(dx_codes_weights(2:end, 2:end)));
n_classes = size(weights, 2);
known_labels = string(unique_labels(:, 1));
counts = zeros(n_classes, 1);
x_values = cell(n_classes, 1);
for i=1:n_classes
    code = dx_codes_weights{1, i+1};
    idx = find(code == dx_codes_scores(:, 2), 1);
    if isempty(idx)
        x_values{i} = code;
    else
        x_values{i} = [dx_codes_scores{idx, 2} ' ' dx_codes_scores{idx, 1}];
    end
    idx = find(code == known_labels, 1);
    if ~isempty(idx)
        counts(i) = unique_labels{idx, 5};
    end
end
y_values = cell(size(dx_codes_weights, 1) - 1, 1);
for i=1:length(y_values)
    code = dx_codes_weights{i+1, 1};
    idx = find(code == dx_codes_scores(:, 2), 1);
    if isempty(idx)
        y_values{i} = code;
    else
        y_values{i} = [dx_codes_scores{idx, 2} ' ' dx_codes_scores{idx, 1}];
    end
end

new_weights = weights .* (counts + counts');
new_weights(logical(eye(n_classes))) = diag(weights) .* counts;
new_weights = tril(new_weights);
% normalize by the maximum of the matrix actually shown (previously the max
% of the un-summed per-class contributions was used, so values could exceed 1)
peak = max(new_weights, [], 'all');
if peak > 0
    new_weights = new_weights / peak;
end

% new figure: heatmap would otherwise replace the previous plot
figure
h = heatmap(x_values, y_values, new_weights);
h.XLabel = 'Target Class';
h.YLabel = 'Predicted Class';

%% check for missing values
% Stacked bar chart of the fraction of NaN and Inf values per feature, over
% every column of featureset except the last one.
% NOTE(review): features_numerical, features_numerical_global,
% features_categorical and features_categorical_global are not defined in
% this script (clearvars at the top removes any earlier values); they must be
% provided before this section runs — TODO confirm where they come from.
if ~exist('featureset', 'var') % the section needs featureset, not classifier
    [featureset, categoricals] = build_featureset(source_dir, ...
                                                  leads, ...
                                                  features_numerical, ...
                                                  features_numerical_global, ...
                                                  features_categorical, ...
                                                  features_categorical_global);
end

X = table2array(featureset(:, 1:end-1));
n_nans = sum(isnan(X), 1) / size(X, 1);
n_inf = sum(isinf(X), 1) / size(X, 1);
feature_names = featureset.Properties.VariableNames(1:end-1);
x = categorical(feature_names);
x = reordercats(x, feature_names);
% new figure so the previous heatmap is not replaced; one row per feature
figure
bar(x, [n_nans; n_inf]', 'stacked');
legend('NaN', 'Inf');
title('Fraction of NaN and Inf values per feature')
grid on


