% Main file for activity classification from wrist/ankle data
% 
% [Andrea Mannini: a.mannini@sssup.it; Created: Apr 13 2012]




% Start from a clean session: close every open figure and clear the workspace.
close all
clear all

%% set path
% The toolbox folders are resolved relative to the current directory.
% fullfile inserts the separator of the running platform; the previous
% hard-coded '\' separators only resolved on Windows.
addpath('Toolboxes_and_functions')
addpath(fullfile('Toolboxes_and_functions', 'PRToolsMatlab', 'prtools'))
addpath(fullfile('Toolboxes_and_functions', 'libsvm-3.1', 'matlab'))


%% Load data file
disp('Loading data...')

% Select the dataset (.mat file) by leaving exactly one load call active.

% Uncorrected data in MSSE article (ankle / wrist):
% load('StanfordDataset2010_uncorrected_Ankle_MSSE'); toremove = [];
% load('StanfordDataset2010_uncorrected_Wrist_MSSE'); toremove = [];

% Corrected data in MSSE article (ankle / wrist):
load('StanfordDataset2010_corrected_Ankle_MSSE');
toremove = [];
% load('StanfordDataset2010_corrected_Wrist_MSSE'); toremove = [];

% Generic data file:
% load('data.mat'); toremove = [];


%% settings:
% Switches and constants read by the rest of this script. Edit the values
% here to change the validation scheme, the transition cropping and the
% feature set.
indiv = 1;  % indiv = 1 -> single subject validation; 
            % indiv = 0 -> multiple subjects nfold cv (all data are treated as belonging to a single subject and the algorithms are cross-validated)
l1o = 1;        % Only read when indiv = 1:
                % l1o = 1 -> leave-one-subject-out; l1o = 0 -> indiv nfold cv

nfold = 10; % number of folds, passed to the nfold cross-validation routines

% No trailing semicolon: the value is echoed to the command window.
% Presumably Data holds one entry per subject - confirm against the .mat file.
Np = length(Data)
% Np = 20 %% if you want to use fewer than 33 subjects, specify a different value here

crop_transitions = true;    % enables the cropping of windows around label transitions
% No trailing semicolon: the value is echoed to the command window.
wincrop = 0                 % number of windows discarded before and after activity transition windows, 
                            % if (crop_transitions = true) and wincrop = 0 only the two windows across 
                            % the transition are discarded
     
% Manual feature set selection 
% NOTE(review): these flags are not referenced again in this file;
% presumably they are read by reorganize_data - confirm.
use_tf = 1;     % time-frequency analysis
use_wvlt1 = 0;  % wavelet, feature 1
use_wvlt2 = 0;  % wavelet, feature 2
use_mean = 1;
use_std  = 1;
use_max  = 1;
use_min  = 1;
use_Dmaxmin  = 0; % max-min
use_var  = 0;     % variance


% if you are not interested in displaying data or if raw data are not
% available set this to false
% (when true, Data_m is trimmed together with the other per-window arrays below)
use_rawdata = true; %false;

%% Dataset organization in a single big matrix
% Runs reorganize_data (not shown here). The code below expects inst,
% labels, subj_label, Lab1w..Lab5w and, when use_rawdata is true, Data_m to
% exist afterwards. Whether they are created by this call or come from the
% loaded .mat file cannot be told from this file.
reorganize_data;


%% crop data across label transitions
% Discards the windows that straddle a change in Lab1w or Lab2w, plus
% wincrop extra windows on each side, from every per-window array.
if crop_transitions
    % A window is a transition when either label differs from that of the
    % previous window. The prepended 0 makes the first window count as a
    % transition whenever its label is non-zero.
    transit = find(  (diff([0; Lab1w]) ~=0) | (diff([0; Lab2w]) ~=0)  ); 
%     transit = find(  diff(labels)~=0 )+1;% | (diff([0; Lab2w]) ~=0)  ); 

    % One row per transition, holding the window indices
    % transit-wincrop-1 ... transit+wincrop. Building the matrix in a
    % single call replaces the loop that grew tmp row by row; that loop
    % left tmp undefined when no transition was found, so unique(tmp)
    % raised an error. Here the result is simply empty in that case.
    crop_offsets = -(wincrop+1) : wincrop;
    tmp = bsxfun(@plus, transit(:), crop_offsets);

    % Unique indices, restricted to windows that actually exist.
    to_remove = unique(tmp);
    to_remove = to_remove(to_remove > 0 & to_remove <= length(labels));

    % Remove the same rows from the features, subject IDs and labels ...
    inst(to_remove,:) = [];
    subj_label(to_remove,:) = [];
    labels(to_remove,:) = [];

    % ... from the five per-window label vectors ...
    Lab1w(to_remove,:) = [];
    Lab2w(to_remove,:) = [];
    Lab3w(to_remove,:) = [];
    Lab4w(to_remove,:) = [];
    Lab5w(to_remove,:) = [];

    % ... and from the raw data, when it is being carried along.
    if use_rawdata
        Data_m(to_remove,:) = [];
    end
end


%% remove activities with few samples available
% % remove standing still while carrying load 
% %  and remove jumping jacks (few data windows)
% %  and remove lifting box   (few data windows)
% Selected windows: Lab2w in 11..14, or Lab1w == 18 together with Lab2w == 1.
% NOTE(review): the mapping of these codes to the activities named above is
% not visible in this file - confirm against the label tables.
ind_st = find(  (ismember(Lab2w,[11:14])) | (ismember(Lab1w,[18]) & ismember(Lab2w,[1]) ));

% Drop the selected windows from the five label vectors ...
Lab1w(ind_st) = [];
Lab2w(ind_st) = [];
Lab3w(ind_st) = [];
Lab4w(ind_st) = [];
Lab5w(ind_st) = [];

% ... from the class labels, feature rows and subject IDs ...
labels(ind_st) = [];
inst(ind_st,:) = [];
subj_label(ind_st) = [];

% ... and from the raw data. Whole rows are deleted, as in the transition
% cropping: Data_m(ind_st) = [] used linear indexing, which removes single
% elements and collapses a matrix into a row vector, so Data_m no longer
% lined up with inst and labels.
if use_rawdata
   Data_m(ind_st,:) = [];
end

%% classification
disp('Classification:')

% Zero-filled arrays; nothing later in this script reads them.
cm  = zeros(1, 2, 2);
cm4 = zeros(1, max(Lab2w), 2);
cm5 = cm4;
cmg = cm4;

% --- parameters, handed unchanged to the validation routines below ---
maxlabelnum    = max(labels);   % largest class label present
classifierType = 1;
sc_low = -1;                    % presumably the bounds of the feature
sc_up  = 1;                     %   scaling range - confirm in the callee
Pb     = 0;
termination = 1e-7;
kernel = 2;                     % radial basis function kernel
bestc  = 100;
bestg  = 0.1;
% Alternative after Zhang et al.: Support Vector Machine with the linear
% kernel, convergence tolerance 1.0e-7, upper bound of complexity c = 1.0:
%     kernel = 0;    % linear kernel (Zhang et al.)
%     bestg = [];
%     bestc = 1;

[N, F] = size(inst);

% Choose the validation approach from the indiv / l1o switches.
if indiv && l1o
    % leave-one-subject-out
    [CM_aggr, CM, Acc_aggr, Acc, predict, modelliSVM, prob] = ...
        l1o_valid_svm2(inst, labels, subj_label, sc_low, sc_up, ...
                       kernel, bestc, bestg, Pb, termination, maxlabelnum);
elseif indiv
    % nfold cross-validation using the subject labels
    [CM_aggr, CM, Acc, predict, modelliSVM, prob] = ...
        indiv_cross_valid_svm2(inst, labels, subj_label, nfold, sc_low, sc_up, ...
                               kernel, bestc, bestg, Pb, termination, maxlabelnum);
else
    % nfold cross-validation on the pooled data
    [CM_aggr, Acc, predict, modelliSVM, index, prob] = ...
        cross_valid_svm(inst, labels, nfold, sc_low, sc_up, ...
                        kernel, bestc, bestg, Pb, termination, maxlabelnum);
end

% Per-class statistics derived from the aggregated confusion matrix.
[SE, SP, PR, se, sp, pr, tp, fp, tn, fn] = confmat2SeSpPr(CM_aggr, 0);

% Stack the prediction cells into one column (pred) and bring it into the
% order of labels (pred_i).
if indiv
    % cells 1..Np are concatenated as they come; no re-ordering is applied
    pred = vertcat(predict{1:Np});
    pred_i = pred;
else
    % folds are stacked from the last to the first, then re-ordered
    % through the sort permutation of index
    pred = vertcat(predict{nfold:-1:1});
    [sorted_dataind, sorting_ind] = sort(index);
    pred_i = pred(sorting_ind);
end


% Confusion matrix of actual vs. predicted labels and overall accuracy.
l1 = unique(labels);
l2 = unique(pred_i);
[CM3(l1,l2), ne3, lablist] = confmat(labels, pred_i);

acc = trace(CM3) / sum(CM3(:));
disp('confusion matrix:')
CM_aggr
disp(['Accuracy: ' num2str(acc)])

% Per-class F1 score and Matthews correlation coefficient for classes
% 1..maxlabelnum, evaluated element-wise and kept as row vectors.
as_row = @(v) reshape(v(1:maxlabelnum), 1, []);
prec = as_row(pr);
rec  = as_row(se);
tp_c = as_row(tp);
fp_c = as_row(fp);
tn_c = as_row(tn);
fn_c = as_row(fn);

F1_s = 2 ./ (1./prec + 1./rec);
MCC  = (tp_c.*tn_c - fp_c.*fn_c) ./ ...
       sqrt( (tp_c+fp_c).*(tp_c+fn_c).*(tn_c+fp_c).*(tn_c+fn_c) );

disp(['F1 score:             ' num2str(F1_s)])
disp(['Mattews Corr. Coeff.: ' num2str(MCC)])

lab2 = unique(pred_i);
lab3 = unique(labels);


%% output

% Figure: actual label, classified label and subject ID, one strip each.
figure('units','normalized','position',[0.1 0.1 0.8 0.8]);
subplot(3,1,1)
imagesc(labels')
axis tight
title('Actual label')
%         subplot(411),imagesc(Lab4w'+1), axis tight, title('Actual (multiclass)')
hold on
subplot(3,1,2)
imagesc(pred_i')
axis tight
title('Classified')
subplot(3,1,3)
imagesc(subj_label')
axis tight
title('Subject ID')


% Figure: aggregated confusion matrix as counts (left) and as a fraction
% of each row total (right).
classif_name = {'svm'};
class_name = {'walk','cycle','non-locomotion','rest'};
figure('unit','normalized','position', [0.2 0.3 0.6 0.4]);

subplot(1,2,1)
imagesc(CM_aggr);
colorbar;
axis square
somme = repmat(sum(CM_aggr,2), 1, maxlabelnum);   % row totals, one copy per column
cm_perc = CM_aggr ./ somme;
title ('Confusion matrix, # windows')

subplot(1,2,2)
imagesc(cm_perc);
colorbar;
axis square
cm_perc
title ('Confusion matrix, % windows')

% One figure per class: bar chart of the class accuracy for every page of
% CM (the x axis is labelled 'Participant').
Y = 1;
if indiv
    for class = 1:maxlabelnum
        figure('units','normalized','position',[0.05 0.1 0.9 0.8]);
        set(gcf, 'name', class_name{class})
        acc_single = zeros(length(CM), size(CM,3));
        for y = 1:Y
            for pp = 1:size(CM,3)
                % share of this class' row that lies on the diagonal
                acc_single(y,pp) = CM(class,class,pp) / sum(CM(class,:,pp));

                % F1 and MCC of this class from the statistics of page pp
                [SEs, SPs, PRs, ses, sps, prs, tps, fps, tns, fns] = confmat2SeSpPr(CM(:,:,pp), 0);
                prec = prs(class);
                rec  = ses(class);
                sF1_s(y,pp) = 2 / (1/prec + 1/rec);
                mcc_num = tps(class)*tns(class) - fps(class)*fns(class);
                mcc_den = sqrt( (tps(class)+fps(class))*(tps(class)+fns(class))*(tns(class)+fps(class))*(tns(class)+fns(class)) );
                sMCC(y,pp) = mcc_num / mcc_den;
            end

            subplot(Y,1,y)
            bar(100*acc_single(y,:), 'FaceColor', 'r')
            xlabel('Participant')
            ylabel('Accuracy');
            axis([0  size(CM,3)+1 0 105])
            title( classif_name(1,y) )

           % legend('Ankle', 'Wrist', 'location', 'SouthEast')
        end
    end
end


% Split the windows into two groups by posture / activity code, build one
% confusion table (Lab2w vs. predicted class) per group and show them in a
% single figure together with a per-group accuracy.
others = ismember(Lab1w, [3:11 15 18]) | ismember(Lab2w, [2:6 11:12 17:21]);

% windows outside the "others" group
lab1 = unique(Lab2w(~others));
lab2 = unique(pred_i(~others));
[CM7(1:length(lab1),lab2), ne7, lablist7] = confmat(Lab2w(~others), pred_i(~others));
% windows inside the "others" group
lab1 = unique(Lab2w(others));
lab2 = unique(pred_i(others));
[CM8(1:length(lab1),lab2), ne8, lablist8] = confmat(Lab2w(others), pred_i(others));

% Row captions "<code>: <activity name>" for the codes present in each group.
availclass  = unique(Lab2w(~others));
availclass2 = unique(Lab2w(others));
for aa = 1:length(availclass)
    code = availclass(aa);
    availact_g{aa} = [ num2str(code) ': ' Info(1).activity_labels{code} ];
end
for aa = 1:length(availclass2)
    code = availclass2(aa);
    availact_g2{aa} = [ num2str(code) ': ' Info(1).activity_labels{code} ];
end


% Rows of CM7 grouped by activity code: class1, class2, everything else.
class1 = [1 22 23 25:33 35:36 45 65:73];
class2 = [7:10 43 44 47 48 ];
c1 = ismember(availclass, class1);
c2 = ismember(availclass, class2);
c3 = ~(c1 | c2);

f1 = figure('units','normalized','position',[0.1 0.1 0.8 0.8]);
cnames = {'Walking', 'Cycling', 'Others', 'Resting'};
tbl_opts = {'ColumnName', cnames, 'units', 'normalized'};   % shared by all tables
t = uitable('Data', CM7(c1,:), 'RowName', availact_g(c1), tbl_opts{:}, ...
            'Position', [0.05 0.6 0.4 0.35]);
t = uitable('Data', CM7(c2,:), 'RowName', availact_g(c2), tbl_opts{:}, ...
            'Position', [0.50 0.6 0.45 0.35]);
t = uitable('Data', CM7(c3,:), 'RowName', availact_g(c3), tbl_opts{:}, ...
            'Position', [0.05 0.02 0.4 0.53]);
t = uitable('Data', CM8, 'RowName', availact_g2, tbl_opts{:}, ...
            'Position', [0.50 0.20 0.45 0.35]);

% Accuracy of each group: count in the group's column over the group total.
rows1 = CM7(c1,:);
rows2 = CM7(c2,:);
rows3 = CM7(c3,:);
acc1 = sum(rows1(:,1)) / sum(rows1(:));
acc2 = sum(rows2(:,2)) / sum(rows2(:));
acc3 = sum(rows3(:,3)) / sum(rows3(:));
acc4 = sum(CM8(:,4))   / sum(CM8(:));

annotation(f1, 'textbox', [0.05 0.9578 0.2236 0.02969], ...
           'String', {['Walking, accuracy = ' num2str(acc1)]}, 'FitBoxToText', 'off');
annotation(f1, 'textbox', [0.50 0.9566 0.2236 0.02969], ...
           'String', {['Cycling, accuracy = ' num2str(acc2)]}, 'FitBoxToText', 'off');
annotation(f1, 'textbox', [0.05 0.555 0.2236 0.02969], ...
           'String', {['Other activities, accuracy = ' num2str(acc3)]}, 'FitBoxToText', 'off');
annotation(f1, 'textbox', [0.50 0.555 0.2236 0.02969], ...
           'String', {['Resting, accuracy = ' num2str(acc4)]}, 'FitBoxToText', 'off');

% Posture (Lab1w) vs. predicted class for the windows with Lab2w = 37 or 38.
index3 = ismember(Lab2w, [38 37]);
availpost = unique(Lab1w(index3));
for aa = 1:length(availpost)
    code = availpost(aa);
    availpost_g{aa} = [ num2str(code) ': ' Info(1).posture_labels{code} ];
end

lab1 = unique(Lab1w(index3));
lab2 = unique(pred_i(index3));
[CM9(1:length(lab1),lab2), ne9, lablist9] = confmat(Lab1w(index3), pred_i(index3));
t = uitable('Data', CM9, 'RowName', availpost_g, tbl_opts{:}, ...
            'Position', [0.50 0.02 0.45 0.15]);

