function [W,P,Q] = tree_level_lasso_v02(XY,gamma,use_Q)
%{
This is the main code of the MREF algorithm based on ADMM for the paper:

Liang Zhao, Feng Chen, Chang-Tien Lu, and Naren Ramakrishnan. "Multi-resolution Spatial Event Forecasting in Social Media." in Proceedings of the IEEE International Conference on Data Mining (ICDM 2016), regular paper, (acceptance rate: 8.5%), pp. 689-698, Barcelona, Spain, Dec 2016.

Please cite the following paper in any work that uses this material:

@inproceedings{zhao2016multi,
  title={Multi-resolution spatial event forecasting in social media},
  author={Zhao, Liang and Chen, Feng and Lu, Chang-Tien and Ramakrishnan, Naren},
  booktitle={2016 IEEE 16th International Conference on Data Mining (ICDM)},
  pages={689--698},
  year={2016},
  organization={IEEE}
}

Main entry point of the ADMM-based parameter optimization for the MREF
model. The algorithm is described in detail in the paper.

In this code the three geographical levels are indexed from finest to
coarsest: level 1 has one location per city, level 2 one location per
state, and level 3 a single top-level location.

Input:
    - XY:                   cell array; only XY{1}, XY{2} and XY{3} are read.
                            Extra cells (e.g. the reverse mapping described
                            as XY{1,4} below) are accepted and ignored.
        > XY{1,1}           1*3 cell    input data, one cell per level:
            >XY{1,1}{1,i}   1*n_i cell  for each level i, each cell is the
            input matrix of one location, of size num_of_samples *
            num_of_features (level 3 holds a single location).

        > XY{1,2}           1*3 cell    output data for each location under each of the three levels:
            >XY{1,2}{1,i}   1*n_i cell  for each level i, each cell is the
            output vector of one location, of size num_of_samples * 1.
        > XY{1,3}           1*n_2 cell  (n_2 = number of states, e.g. 18): the
        i-th cell contains the indices of the cities (level-1 locations)
        that belong to the i-th state (level-2 location).
        > XY{1,4}           1*2 cell    each cell is a reverse mapping whose i-th
        element (e.g., for a city) is the index of the higher-level location
        (e.g., a state) it belongs to. Not used by this function.
    - gamma:    scalar      regularization parameter
    - use_Q:    (optional) logical, default false. When true, the Q block is
                updated every iteration; when false (the original behavior),
                Q stays all-zero.

Output: W, P and Q are each 1*3 cells holding a num_of_features * n_1,
num_of_features * n_2 and num_of_features * 1 matrix (one column per
location of the corresponding level). For their meaning, please refer to
the descriptions in the paper.

The solver prints the primal residual r, dual residual s and penalty rho
at every iteration, adapts rho by residual balancing, and stops when both
residuals are below 1e-3 or after 1000 iterations.

For any questions, welcome to email Liang Zhao by: lzhao9@gmu.edu

%}
if nargin < 3 || isempty(use_Q)
    use_Q = false;
end

% Unpack by index rather than with DEAL: DEAL requires exactly as many
% outputs as inputs, so a fixed output list errors whenever XY has a
% different number of cells. Only the first three cells are needed here.
if ~iscell(XY) || numel(XY) < 3
    error('tree_level_lasso_v02:badInput', ...
        'XY must be a cell array with at least 3 cells (X, Y, s2c_map).');
end
X       = XY{1};
Y       = XY{2};
s2c_map = XY{3};

numCities   = size(X{1,1},2);
numStates   = size(X{1,2},2);
numFeatures = size(X{1,1}{1,1},2);

% ADMM penalty parameter; adapted below by residual balancing.
rho = 1;

% Primal blocks (W, P, Q, U, V) and dual variables (A1, A2, A3), all
% stored per level with one column per location.
W  = zero_levels(numFeatures,numCities,numStates);
P  = W;
U  = W;
V  = W;
Q  = W;
A1 = W;
A2 = W;
A3 = W;

eps_pri  = 1e-3;   % absolute tolerance on the primal residual r
eps_dual = 1e-3;   % absolute tolerance on the dual residual s
MAX_ITER = 1000;
for iter = 1:MAX_ITER
    U_old = U; V_old = V; P_old = P; Q_old = Q;

    % Per-location W and P updates. Level 3 is a single location whose data
    % are stored directly in X{1,3}{1,1} / Y{1,3}{1,1}.
    for i = 1:3
        if i == 3
            W{1,i} = update_W(X{1,i}{1,1},Y{1,i}{1,1},P{1,i},Q{1,i},rho,A1{1,i});
            P{1,i} = update_P(Q{1,i},U{1,i},V{1,i},W{1,i},A1{1,i},A2{1,i},A3{1,i},rho);
        else
            for j = 1:size(X{1,i},2)
                W{1,i}(:,j) = update_W(X{1,i}{1,j},Y{1,i}{1,j},P{1,i}(:,j),Q{1,i}(:,j),rho,A1{1,i}(:,j));
                P{1,i}(:,j) = update_P(Q{1,i}(:,j),U{1,i}(:,j),V{1,i}(:,j),W{1,i}(:,j),A1{1,i}(:,j),A2{1,i}(:,j),A3{1,i}(:,j),rho);
            end
        end
    end

    % Group-sparse proximal steps.
    U = update_U(P,A2,rho,gamma);
    V = update_V(P,A3,rho,gamma,s2c_map);
    if use_Q
        Q = update_Q(W,P,A1,rho,gamma);
    end

    % Dual ascent on the three consensus constraints
    % W = P + Q, P = U, P = V.
    for i = 1:3
        A1{1,i} = update_A(A1{1,i},W{1,i},P{1,i}+Q{1,i},rho);
        A2{1,i} = update_A(A2{1,i},P{1,i},U{1,i},rho);
        A3{1,i} = update_A(A3{1,i},P{1,i},V{1,i},rho);
    end

    % r: violation of the consensus constraints; s: change of the iterates
    % between consecutive iterations, scaled by rho.
    s_mat = 0;
    r_mat = 0;
    for i = 1:3
        r_mat = r_mat + sum(sum((W{1,i}-P{1,i}-Q{1,i}).^2))+...
                sum(sum((P{1,i}-U{1,i}).^2))+...
                sum(sum((P{1,i}-V{1,i}).^2));
        s_mat = s_mat + sum(sum((P_old{1,i}-P{1,i}+Q_old{1,i}-Q{1,i}).^2))+...
                sum(sum((U_old{1,i}-U{1,i}+V_old{1,i}-V{1,i}).^2+(Q_old{1,i}-Q{1,i}).^2));
    end
    s = rho*sqrt(s_mat);
    r = sqrt(r_mat);

    % Residual balancing: keep r and s within a factor of 10 of each other.
    % The duals A1..A3 are unscaled, so they need no rescaling when rho changes.
    if r > 10*s
        rho = 2*rho;
    elseif 10*r < s
        rho = rho/2;
    end
    fprintf('r:%e\ts:%e\trho:%f\n',r,s,rho);
    if r < eps_pri && s < eps_dual
        break;
    end
end
end

function C = zero_levels(numFeatures,numCities,numStates)
% ZERO_LEVELS  1*3 cell of all-zero matrices, one per level, of sizes
% numFeatures*numCities, numFeatures*numStates and numFeatures*1.
C = {zeros(numFeatures,numCities), zeros(numFeatures,numStates), zeros(numFeatures,1)};
end

function w = update_W(x,y,p,q,rho,alpha,lambda)
% UPDATE_W  ADMM W-step for one location: ridge-regularized least squares
% pulled towards P + Q.
%   Solves  (x'*x + (rho+lambda)*I) * w = x'*y + rho*(p+q) - alpha
%
%   x       num_of_samples * num_of_features input matrix
%   y       num_of_samples * 1 output vector
%   p, q    num_of_features * 1 current P and Q columns
%   rho     ADMM penalty parameter
%   alpha   num_of_features * 1 dual variable for the constraint w = p + q
%   lambda  (optional) ridge weight, default 1 (the previously hard-coded value)
%   w       num_of_features * 1 solution
if nargin < 7 || isempty(lambda)
    lambda = 1;
end
% No try/catch here: the former handler swallowed every error and left
% LEFT undefined, so dimension mismatches surfaced later as a misleading
% "undefined variable" error instead of the real cause.
left  = x'*x + (rho+lambda)*eye(size(x,2));
right = x'*y + rho*(p+q) - alpha;
w = left \ right;
end
function p = update_P(q,u,v,w,alpha1,alpha2,alpha3,rho)
% UPDATE_P  Closed-form ADMM P-step.
%   P appears in three penalized constraints (w = p + q, p = u, p = v), so
%   setting the gradient of the augmented Lagrangian to zero gives
%   p = (rho*(v + u + w - q) + alpha1 - alpha2 - alpha3) / (3*rho).
%   The terms are accumulated below in the same order as the formula
%   was originally written.
numer = rho*v + alpha1 + rho*u + rho*w - rho*q - alpha2 - alpha3;
scale = 1/(3*rho);
p = scale*numer;
end
function u = update_U(p,alpha2,rho,gamma0)
% UPDATE_U  U-step: one joint prox_l21 step over the level-2 and level-3
% columns of P + A2/rho. Level 1 of U is copied from P unchanged.
%   The group weight gamma0 is divided by sqrt(number of states + 1),
%   i.e. by the square root of the number of stacked columns.
%   prox_l21 is called through evalc so that anything it prints is
%   captured and discarded. The evaluated string refers to the local
%   variables x, gamma and rho by name.
nStacked = size(p{1,2},2) + 1;
gamma = gamma0/sqrt(nStacked);
lvl2 = p{1,2}+alpha2{1,2}/rho;
lvl3 = p{1,3}+alpha2{1,3}/rho;
x = [lvl2,lvl3]';
[~,x_shrunk] = evalc('prox_l21(x,2*gamma/rho)');
x_shrunk = x_shrunk';
u = p;
u{1,3} = x_shrunk(:,end);
u{1,2} = x_shrunk(:,1:end-1);
end
function v = update_V(p,alpha3,rho,gamma0,s2c_map)
% UPDATE_V  V-step: for each state, one joint prox_l21 step over the
% columns of P + A3/rho belonging to that state's cities (level 1) and the
% state itself (level 2). Level 3 of V is copied from P unchanged.
%   s2c_map{1,i} holds the level-1 column indices of the cities of state i.
%   The group weight gamma0 is divided by sqrt(number of stacked columns).
%   prox_l21 is called through evalc so that anything it prints is
%   captured and discarded. The evaluated string refers to the local
%   variables x, gamma and rho by name.
numStates = size(p{1,2},2);
v = p;
for i = 1:numStates
    cities = s2c_map{1,i};
    x = [p{1,1}(:,cities)+alpha3{1,1}(:,cities)/rho, p{1,2}(:,i)+alpha3{1,2}(:,i)/rho];
    x = x';
    % Count the stacked group members directly (cities + the state).
    % size(cities,2) under-counted when an index list is stored as a
    % column vector; for row-vector lists this value is identical.
    gamma = gamma0/sqrt(size(x,1));
    [~,x1] = evalc('prox_l21(x,2*gamma/rho)');
    x1 = x1';
    v{1,1}(:,cities) = x1(:,1:end-1);
    v{1,2}(:,i) = x1(:,end);
end
end
function q = update_Q(w,p,alpha1,rho,gamma0)
% UPDATE_Q  Q-step: for every level, a prox_l21 step on W - P + A1/rho.
%   The group weight gamma0 is divided by sqrt(number of columns, i.e.
%   locations, at that level). Works for any number of levels in w. The
%   previous version pre-filled q by indexing exactly three levels, which
%   errored for fewer levels and was fully overwritten anyway.
%   prox_l21 is called through evalc so that anything it prints is
%   captured and discarded. The evaluated string refers to the local
%   variables x, gamma and rho by name.
numLevels = size(w,2);
q = cell(1,numLevels);
for i = 1:numLevels
    x = w{1,i}-p{1,i}+alpha1{1,i}/rho;
    gamma = gamma0/sqrt(size(x,2));
    x = x';
    [~,x1] = evalc('prox_l21(x,2*gamma/rho)');
    q{1,i} = x1';
end
end

function a = update_A(a,left,right,rho)
% UPDATE_A  Unscaled dual-ascent step for the constraint left = right:
%   a <- a + rho*(left - right)
residual = left - right;
a = a + rho*residual;
end



