% This script trains neural networks on the features produced by the
% generateFeatures.m script.
% If the feature file all_features.mat does not exist, the script
% generateFeatures is called to generate and save the features.
% Dependencies:
%
%       1) This script requires the Neural Network Toolbox from MATLAB
% 
% Written by  Haidar Almubarak April 06, 2015
%             h.almubarak@ieee.org
% Last Modified:
%

% Start from a clean state: empty workspace, no open figures, cleared console
clear all;
close all;
clc;

% Reuse the cached feature file when it is present in the current folder;
% otherwise run the generateFeatures script to build the features
allfeatures = dir('all_features.mat');
if isempty(allfeatures)
    generateFeatures
else
    load('all_features.mat');
end

numHidden = 20; % hidden-layer size used by every classifier below

% Seed the global random number generator so every run gives the same results
rng(1);

% Flag: set to 1 to print each network's accuracy after training
calcPerf = 0;
% Train the Asystole classifier.
% Rows of Asystole_features are samples (they are transposed before being
% passed to the toolbox); the last column is the label, the rest are features.
features = double(Asystole_features(:,1:end-1));
target = double(Asystole_features(:,end));
A_net = patternnet(numHidden);
A_net = configure(A_net,features',target');

% Random division into training / validation / testing with ratios
% 0.66 : 0 : 0.33 (no validation set)
A_net.divideParam.trainRatio = 0.66;
A_net.divideParam.valRatio = 0;
A_net.divideParam.testRatio = 0.33;

[A_net,tr] = train(A_net,features',target');

if calcPerf
    % Score only the held-out samples that train() assigned to the test set
    % (tr.testInd); scoring every sample would include training data and
    % overstate the accuracy.
    testIdx = tr.testInd;
    y = A_net(features(testIdx,:)');
    % Columns: network output, true label, thresholded prediction
    compA = [y',target(testIdx),(y>0.30)'];
    acc = mean(compA(:,2) == compA(:,3));
    fprintf('Asystole accuracy = %0.2f\n',acc)
end

% Train the Bradycardia classifier.
% Rows of Bradycardia_features are samples (they are transposed before being
% passed to the toolbox); the last column is the label, the rest are features.
features = double(Bradycardia_features(:,1:end-1));
target = double(Bradycardia_features(:,end));
B_net = patternnet(numHidden);
B_net = configure(B_net,features',target');

% Random division into training / validation / testing with ratios
% 0.66 : 0 : 0.33 (no validation set)
B_net.divideParam.trainRatio = 0.66;
B_net.divideParam.valRatio = 0;
B_net.divideParam.testRatio = 0.33;

[B_net,tr] = train(B_net,features',target');

if calcPerf
    % Score only the held-out samples that train() assigned to the test set
    % (tr.testInd); scoring every sample would include training data and
    % overstate the accuracy.
    testIdx = tr.testInd;
    y = B_net(features(testIdx,:)');
    % Columns: network output, true label, thresholded prediction
    compB = [y',target(testIdx),(y>0.30)'];
    acc = mean(compB(:,2) == compB(:,3));
    fprintf('Bradycardia accuracy = %0.2f\n',acc)
end

% Train the Tachycardia classifier.
% Rows of Tachycardia_features are samples (they are transposed before being
% passed to the toolbox); the last column is the label, the rest are features.
features = double(Tachycardia_features(:,1:end-1));
target = double(Tachycardia_features(:,end));
T_net = patternnet(numHidden);
T_net = configure(T_net,features',target');

% Random division into training / validation / testing with ratios
% 0.66 : 0 : 0.33 (no validation set)
T_net.divideParam.trainRatio = 0.66;
T_net.divideParam.valRatio = 0;
T_net.divideParam.testRatio = 0.33;

[T_net,tr] = train(T_net,features',target');

if calcPerf
    % Score only the held-out samples that train() assigned to the test set
    % (tr.testInd); scoring every sample would include training data and
    % overstate the accuracy.
    testIdx = tr.testInd;
    y = T_net(features(testIdx,:)');
    % Columns: network output, true label, thresholded prediction
    compT = [y',target(testIdx),(y>0.30)'];
    acc = mean(compT(:,2) == compT(:,3));
    fprintf('Tachycardia accuracy = %0.2f\n',acc)
end


% Train the Ventricular_Flutter_Fib classifier.
% Rows of Ventricular_Flutter_Fib_features are samples (they are transposed
% before being passed to the toolbox); the last column is the label, the rest
% are features.
features = double(Ventricular_Flutter_Fib_features(:,1:end-1));
target = double(Ventricular_Flutter_Fib_features(:,end));
VF_net = patternnet(numHidden);
VF_net = configure(VF_net,features',target');

% Random division into training / validation / testing with ratios
% 0.66 : 0 : 0.33 (no validation set)
VF_net.divideParam.trainRatio = 0.66;
VF_net.divideParam.valRatio = 0;
VF_net.divideParam.testRatio = 0.33;

[VF_net,tr] = train(VF_net,features',target');

if calcPerf
    % Score only the held-out samples that train() assigned to the test set
    % (tr.testInd); scoring every sample would include training data and
    % overstate the accuracy.
    testIdx = tr.testInd;
    y = VF_net(features(testIdx,:)');
    % Columns: network output, true label, thresholded prediction
    compVF = [y',target(testIdx),(y>0.30)'];
    acc = mean(compVF(:,2) == compVF(:,3));
    fprintf('Ventricular Flutter Fib accuracy = %0.2f\n',acc)
end

% Train the Ventricular_Tachycardia classifier.
% Rows of Ventricular_Tachycardia_features are samples (they are transposed
% before being passed to the toolbox); the last column is the label, the rest
% are features.
features = double(Ventricular_Tachycardia_features(:,1:end-1));
target = double(Ventricular_Tachycardia_features(:,end));
VT_net = patternnet(numHidden);
VT_net = configure(VT_net,features',target');

% Random division into training / validation / testing with ratios
% 0.66 : 0 : 0.33 (no validation set)
VT_net.divideParam.trainRatio = 0.66;
VT_net.divideParam.valRatio = 0;
VT_net.divideParam.testRatio = 0.33;

[VT_net,tr] = train(VT_net,features',target');

if calcPerf
    % Score only the held-out samples that train() assigned to the test set
    % (tr.testInd); scoring every sample would include training data and
    % overstate the accuracy.
    testIdx = tr.testInd;
    y = VT_net(features(testIdx,:)');
    % Columns: network output, true label, thresholded prediction
    compVT = [y',target(testIdx),(y>0.30)'];
    acc = mean(compVT(:,2) == compVT(:,3));
    fprintf('Ventricular Tachycardia accuracy = %0.2f\n',acc)
end

% Persist the five trained classifiers together in one MAT-file
save trainednet.mat A_net B_net T_net VF_net VT_net


