function [best_model, CM]=SIMDA(NN_arch, X, Y, beta, latitude, longitude, lr, output, max_iter)
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% ADMM optimizer for model SIMDA in the paper "Incomplete Label Multi-task 
% Deep Learning for Spatio-temporal Event Subtype Forecasting" 
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% The input of the function is as follows:   
% NN_arch: network architecture: [input_dim, hidden_dim , ... , output_dim]
% X: the input data, a 3-D matrix where each dimension represents [task, 
%    sample, feature vector]
% Y: the class labels, a 2-D matrix where each dimension represents [task, 
%    sample]. Labels are used directly as column indices of the one-hot
%    targets, so they must be the integers 1..k (k = number of classes).
% beta: coefficient of the theta constraint term (selected using the
%       validation set)
% latitude & longitude: the physical location of the corresponding spatial 
%                       task (one entry per task)
% lr: the learning rate
% output: controls whether to output the training logs
% max_iter: the maximum training iterations
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% The return values of the function are as follows:
% best_model: the model (cell array, one network per task) that achieves 
%             the best macro F1 on the validation set during training. 
%             If no iteration reaches a macro F1 above 0, the initial 
%             networks are returned.
% CM: the confusion matrix of the best_model on test data
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
% Side effects:
% - writes 'ROC_SIMDA_<NN_arch>_<beta>.mat' (X_roc, Y_roc, AUC, CM of the
%   test set) into the current directory;
% - always prints the final validation/test summary; per-iteration logs
%   are printed only when output is true.
%
% External helpers that are not defined in this file: nnsetup, nnff, nnbp,
% nnpredict, macro_ROC, distance and confusionmat.
% NOTE(review): the nn* functions look like DeepLearnToolbox and distance
% like the Mapping Toolbox function - confirm against the project setup.
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%

% number of distinct class labels
values = unique(Y);
k=size(values,1);

% Macro-averaged metrics from a confusion matrix produced by confusionmat,
% whose rows are the true classes and whose columns are the predicted
% classes: precision normalises by the column sums (predicted counts) and
% recall by the row sums (true counts).
precision = @(confusionMat) diag(confusionMat)./sum(confusionMat,1)';
recall = @(confusionMat) diag(confusionMat)./sum(confusionMat,2);
f1Scores = @(confusionMat) 2*(precision(confusionMat).*recall(confusionMat))./(precision(confusionMat)+recall(confusionMat));
% classes with undefined (NaN) scores contribute 0, the divisor stays k
meanP = @(confusionMat) nansum(precision(confusionMat))/k;
meanR = @(confusionMat) nansum(recall(confusionMat))/k;
meanF1 = @(confusionMat) nansum(f1Scores(confusionMat))/k;

randseed=1;

%% split the data
% The first half of the samples (dimension 2) is the training pool, the
% following span of the same length is held out for validation and test.
half=fix(size(X,2)/2);
TRAIN_size=half;
TEST_size=half;

% missing_ratio is the fraction of the training pool that is actually
% used; with 1 the whole pool is kept (in a seeded random order)
missing_ratio=1;
re_size=fix(TRAIN_size*missing_ratio);
rng(randseed);
idxs=randsample(TRAIN_size,re_size);

X_train=X(:,idxs(1:re_size),:);
Y_train=Y(:,idxs(1:re_size));

% the held-out pool
X_test=X(:,half+1:half+TEST_size,:);
Y_test=Y(:,half+1:half+TEST_size);

TASK_num= size(Y_train, 1);

% precompute the task similarity matrix exp(-distance) from the lat&lon
% geo-information; the diagonal is zeroed so a task is not its own neighbor
distances=zeros(TASK_num,TASK_num);
for i=1:TASK_num
    for j=1:TASK_num
        distances(i,j) = exp(-distance(latitude(i),longitude(i),latitude(j),longitude(j)));
    end
    distances(i,i)=0;
end

% per-task normaliser of the similarity weights (constant during training)
distance_norm=zeros(TASK_num,1);
for l=1:TASK_num
    distance_norm(l)=sum(distances(l,:));
end

% shuffle the held-out pool (seeded) and use one half of it as validation
% set and the other half as test set
rng(randseed);
idxs=randperm(TEST_size);
v=fix(TEST_size/2);

X_vali=X_test(:,idxs(1:v),:);
Y_vali=Y_test(:,idxs(1:v));
X_test=X_test(:,idxs(v+1:end),:); 
Y_test=Y_test(:,idxs(v+1:end));

VALI_size=size(Y_vali,2);
TEST_size=size(Y_test,2);
TRAIN_size=size(Y_train,2);

%% format Y: one-hot targets (TRAIN_size x k) for every task
train_y = cell(TASK_num,1);
for i = 1 : TASK_num
    groundTruth = zeros(TRAIN_size,k);
    for j=1:TRAIN_size
        groundTruth(j,Y_train(i,j)) =1;
    end
    train_y{i}=groundTruth;
end

%% Model

% M has one column per unordered pair of classes, with +1 for the first
% class of the pair and -1 for the second one
combos = nchoosek(1:k,2);
Msize=size(combos,1);
M=zeros(k,Msize);
for i=1:Msize
    M(combos(i,1),i)=1;
    M(combos(i,2),i)=-1;
end

% one network per task
nnall=cell(TASK_num,1);
nnzfx=cell(TASK_num,1);

for i=1: TASK_num
    nn = nnsetup(NN_arch);
    nn.activation_function = 'sigm';       %  Use sigmoid fun
    nn.learningRate = 0.5;                 %  stored in the network struct; the updates below use lr
    nn.dropoutFraction=0;
    nn.output              = 'softmax';    %  use softmax output
    nnall{i}=nn;
end

% share the bottom weights over all tasks (copy task 1 into every task)
for l=1:TASK_num
    for j=1:nnall{l}.n-2
        nnall{l}.W{j}=nnall{1}.W{j};
    end
end

% width of the last hidden layer plus one bias column
d=nn.size(end-1)+1;

% theta(l,:,:) mirrors the task-specific output layer (k x d) of task l
theta=zeros(TASK_num, k, d);
for i=1: TASK_num
    theta(i,:,:)=nnall{i}.W{end};     % the last layer is task specific
end

% ADMM copies of theta (V, W) and their dual variables (Y2, Y3)
V=theta;
Y2=zeros(size(theta));
W=theta;
Y3=zeros(size(theta));

% admm param
rho=1;

% identity used by the linear systems of the V and W updates
I=eye(d*k);

% per-task 2-D views, kept in cells for better speed
tempW=cell(TASK_num,1);
tempZ=cell(TASK_num,1);
tempV=cell(TASK_num,1);

% initialize Z with the last hidden activations f(x) (bias column dropped)
% and its dual variable Y1 with zeros
Y1=cell(TASK_num,1);
for l=1:TASK_num
    nnall{l} = nnff(nnall{l}, squeeze(X_train(l,:,:)), train_y{l});
    tempZ{l}=nnall{l}.a{end-1}(:, 2:end);
    Y1{l}=zeros(size(tempZ{l}));
end

best_MF1=0;
% fall back to the initial networks so that best_model is always defined,
% even when no iteration improves on a macro F1 of 0 or max_iter is 0
best_model=nnall;

for adi= 1: max_iter

    % precompute kron(ZZ,MM), the quadratic part of the V and W systems
    kronZM=cell(TASK_num,1);
    for l=1:TASK_num
        Zb=[ones(TRAIN_size,1) tempZ{l}];
        kronZM{l}=kron(Zb'*Zb,M*M');
    end

    %% Update theta (network weights) via gradient descent
    % at most 50 steps; stops early once the objective changes by < 1e-3

    L_old=inf;

    for i=1:50
        LZ = 0;
        LN = 0;

        for l = 1 : TASK_num
            % neural network parameter update over the label loss
            nnall{l} = nnff(nnall{l}, squeeze(X_train(l,:,:)), train_y{l});
            LN=LN+nnall{l}.L;
            %BP
            nnall{l} = nnbp(nnall{l});

            % apply gradient to task specific layer
            j = nn.n - 1;
            nnall{l}.W{j} = nnall{l}.W{j} - lr * nnall{l}.dW{j};

            % apply gradient to shared bottom layer(s)
            for j = 1 : (nn.n - 2)
                nnall{l}.W{j} = nnall{l}.W{j} - lr * nnall{l}.dW{j};
            end

            %update theta
            theta(l,:,:)=nnall{l}.W{end};

            % share the bottom weight over all tasks
            for ll=1:TASK_num
                for j=1:nn.n-2
                    nnall{ll}.W{j}=nnall{l}.W{j};
                end
            end

        end

        % gradients of the augmented Lagrangian terms that tie theta to V
        % and to W
        Vgard=Y2+ rho * (theta-V);
        Wgard=Y3+ rho * (theta-W);

        % task specific (output layer) weight update from the ADMM terms
        % NOTE(review): theta is only refreshed from W{end} inside the loop
        % above, so after the last inner step it does not include this
        % update yet - confirm this lag is intended.
        for l = 1 : TASK_num
            nnall{l}.W{end} = nnall{l}.W{end} - lr/d* (squeeze(Vgard(l,:,:)) + squeeze(Wgard(l,:,:)));
        end

        if numel(NN_arch)>2
            for l = 1 : TASK_num
                % shared bottom layers update: nnzfx{l} is the network of
                % task l without its output layer, so its output is the
                % last hidden activation f(x) that is matched against Z
                nnzfx{l}=nnall{l};
                nnzfx{l}.size=nnall{l}.size(1:end-1);
                nnzfx{l}.W=nnall{l}.W(1:end-1);
                nnzfx{l}.n=nnall{l}.n-1;
                nnzfx{l} = rmfield(nnzfx{l},'a');
                nnzfx{l} = rmfield(nnzfx{l},'dW');
                nnzfx{l}.output='sigm';

                nnzfx{l} = nnff(nnzfx{l}, squeeze(X_train(l,:,:)), tempZ{l});
                LZ=LZ+1/TRAIN_size*(sum(sum(Y1{l}.*(tempZ{l}-nnzfx{l}.a{end})))+(rho / 2.) * sum(sum((tempZ{l}-nnzfx{l}.a{end}).^2)));
                % replace the error computed by nnff with the one of the
                % augmented Lagrangian term before back-propagating
                nnzfx{l}.e = rho*(tempZ{l}-nnzfx{l}.a{end})+Y1{l};
                nnzfx{l} = nnbp(nnzfx{l});

                % apply gradient to shared bottom layer(s)
                for j = 1 : (nnzfx{l}.n - 1)
                    nnall{l}.W{j} = nnall{l}.W{j} - lr * nnzfx{l}.dW{j};
                end

                % share the bottom weight over all tasks
                for ll=1:TASK_num
                    for j=1:(nnzfx{l}.n - 1)
                        nnall{ll}.W{j}=nnall{l}.W{j};
                    end
                end
            end
        end

        % objective = hidden-layer term + label loss + theta coupling terms
        LT=1/d*(sum(sum(sum(Y2.*(theta-V))))+(rho / 2.) * sum(sum(sum((theta-V).^2)))+sum(sum(sum(Y3.*(theta-W))))+(rho / 2.) * sum(sum(sum((theta-W).^2))));
        L=LZ+LN+LT;
        %disp(L);
        if abs(L_old-L)<1e-3
            break;
        end
        L_old=L;
    end

    %% Update V

    for i=1:TASK_num
        tempW{i}=squeeze(W(i,:,:));
    end

    % regularization term 2 (TASK_num x TRAIN_size x Msize), based on the
    % W of the previous iteration
    term2sum=neighbor_margins(distances,distance_norm,tempZ,tempW,M);

    % closed-form update: one (d*k) x (d*k) linear system per task
    V_vec=zeros(TASK_num,d*k);
    for l=1:TASK_num
        a=(beta/(TRAIN_size*Msize)*kronZM{l}+rho*I);
        b=( squeeze(Y2(l,:,:))+rho*squeeze(theta(l,:,:))+(beta/(TRAIN_size*Msize))*M*squeeze(term2sum(l,:,:))'*[ones(TRAIN_size,1) tempZ{l}] );
        V_vec(l,:)=a\b(:);
    end

    V=reshape(V_vec,TASK_num,k,d);

    %% Update W

    for i=1:TASK_num
        tempV{i}=squeeze(V(i,:,:));
    end

    %precompute term2 (uses the new V and the W of the previous iteration)
    term2=zeros(TASK_num,k,d);
    for c=1:TASK_num
        tempsum2=zeros(k,d);
        for l=1:TASK_num
            tempsum=zeros(TRAIN_size,Msize);
            for i=1:TASK_num
                if i~=l&&i~=c
                    tempsum=tempsum+distances(l,i)*[ones(TRAIN_size,1) tempZ{i}]*(tempW{i}'*M);
                end
            end
            tempsum2=tempsum2+distances(l,c)/distance_norm(l)*M*(tempsum/distance_norm(l)-[ones(TRAIN_size,1) tempZ{l}]*(tempV{l}'*M))'*[ones(TRAIN_size,1) tempZ{c}];
        end
        term2(c,:,:)=tempsum2;
    end

    %precompute disnormsum
    disnormsum=zeros(TASK_num,1);
    for c=1:TASK_num
        for l=1:TASK_num
            if l~=c
                disnormsum(c)=disnormsum(c)+(distances(l,c)/distance_norm(l))^2;
            end
        end
    end

    % closed-form update: one (d*k) x (d*k) linear system per task
    W_vec=zeros(TASK_num,d*k);
    for c=1:TASK_num
        a=( (beta/(TRAIN_size*Msize)) * disnormsum(c) *kronZM{c}+rho*I);
        b=( squeeze(Y3(c,:,:))+rho*squeeze(theta(c,:,:))-(beta/(TRAIN_size*Msize))*squeeze(term2(c,:,:)) );
        W_vec(c,:)=a\b(:);
    end

    W=reshape(W_vec,TASK_num,k,d);

    %% Update Z

    for i=1:TASK_num
        tempW{i}=squeeze(W(i,:,:));
    end

    % regularization term 2 again, now with the updated W
    term2sum=neighbor_margins(distances,distance_norm,tempZ,tempW,M);

    for l=1:TASK_num
        nnall{l} = nnff(nnall{l}, squeeze(X_train(l,:,:)), train_y{l});
        % solve Z*b = a; the bias column of V is excluded (2:end)
        a=-Y1{l}+rho*nnall{l}.a{end-1}(:,2:end)+beta/(TRAIN_size*Msize)*squeeze(term2sum(l,:,:))*M'*tempV{l}(:,2:end);
        b=beta/(TRAIN_size*Msize)*tempV{l}(:,2:end)'*(M*M')*tempV{l}(:,2:end)+rho*eye(d-1);
        tempZ{l}=a/b;
    end

    %% Update Y (dual variables)
    for l=1:TASK_num
        Y1{l}=Y1{l}+rho*(tempZ{l}-nnall{l}.a{end-1}(:,2:end));
    end
    Y2=Y2+rho*(theta-V);
    Y3=Y3+rho*(theta-W);

    %% test on validation set for finding hyper-parameter and stop criteria
    label=zeros(TASK_num,VALI_size);
    clear p;
    for l=1:TASK_num
        [label(l,:),p(l,:,:)] = nnpredict(nnall{l}, squeeze(X_vali(l,:,:)));
    end

    [X_roc,Y_roc,AUC]=macro_ROC(Y_vali(:),reshape(p,[],size(p,3),1));

    % number of misclassified validation samples
    misses=nnz(label(:)~=Y_vali(:));

    CM=confusionmat(Y_vali(:),label(:));

    MP=meanP(CM);
    MR=meanR(CM);
    MF1=meanF1(CM);
    AUC=mean(AUC);

    if output
        fprintf('iter %d: precision: %f\trecall: %f\tF1: %f\tMZE: %f\tAUC: %f\n',adi,MP,MR,MF1,misses/(VALI_size*TASK_num),mean(AUC));
    end

    %% update best model & setup stop criteria (if any)
    if MF1>best_MF1
        best_MF1=MF1;
        best_model=nnall;
    end

end

%% finally, test learned model on test set
label=zeros(TASK_num,TEST_size);
clear p;
for l=1:TASK_num
    [label(l,:),p(l,:,:)] = nnpredict(best_model{l}, squeeze(X_test(l,:,:)));
end

[X_roc,Y_roc,AUC]=macro_ROC(Y_test(:),reshape(p,[],size(p,3),1));

misses=nnz(label(:)~=Y_test(:));

CM=confusionmat(Y_test(:),label(:));

% save only after CM has been computed on the test set, so that the file
% holds the test confusion matrix and not the one of the last validation
% round
save(['ROC_SIMDA_' mat2str(NN_arch) '_' mat2str(beta) '.mat'],'X_roc','Y_roc','AUC','CM');

MP=meanP(CM);
MR=meanR(CM);
MF1=meanF1(CM);

fprintf('on valiset: best F1: %f\n',best_MF1);
fprintf('on testset: precision: %f\trecall: %f\tF1: %f\tMZE: %f\tAUC: %f\n',MP,MR,MF1,misses/(TEST_size*TASK_num),mean(AUC));

end

function term2sum = neighbor_margins(distances, distance_norm, Zc, Wc, M)
% For every task l, accumulate [1 Z_c] * W_c' * M over all other tasks c,
% weighted by distances(l,c), and divide by distance_norm(l).
% Zc, Wc: cell arrays with one matrix per task (Z without bias column,
%         W as k x d); M: k x Msize pair matrix.
% Returns a TASK_num x TRAIN_size x Msize array.
TASK_num=numel(Zc);
TRAIN_size=size(Zc{1},1);
Msize=size(M,2);
term2sum=zeros(TASK_num,TRAIN_size,Msize);
for l=1:TASK_num
    tempsum=zeros(TRAIN_size,Msize);
    for c=1:TASK_num
        if c~=l
            tempsum=tempsum+distances(l,c)*([ones(TRAIN_size,1) Zc{c}]*(Wc{c}'*M));
        end
    end
    term2sum(l,:,:)=tempsum/distance_norm(l);
end
end