'''
Author: Carl Yang
Function: The neural network embedding model
Command: library
'''
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
from collections import OrderedDict


# An L2 normalization layer that projects arbitrary vectors
# onto the surface of a norm-one ball.
class L2Normalization(torch.nn.Module):
    """Scale every row of the input to unit Euclidean norm.

    Rows whose squared norm is not strictly positive (for example all-zero
    rows) are returned unchanged instead of being divided by zero.
    """

    def __init__(self):
        super(L2Normalization, self).__init__()

    def forward(self, x):
        """Return ``x`` with each vector along the last dimension normalized.

        Args:
            x: tensor whose last dimension holds the vector components
                (for a 2-D input, one vector per row).

        Returns:
            A tensor of the same shape as ``x``.
        """
        sq_norm = x.pow(2).sum(-1, keepdim=True)
        # Swap the denominator of degenerate rows for 1 *before* taking the
        # square root and dividing.  Masking only the result of ``x / norm``
        # keeps the forward output finite, but the backward pass of the
        # division and of ``sqrt`` at zero still yields NaN gradients.
        safe_sq_norm = torch.where(
            sq_norm > 0, sq_norm, torch.ones_like(sq_norm))
        return x / safe_sq_norm.sqrt()


class NN(object):
    """Supervised metric-learning embedding model built from feed-forward stacks.

    This is a plain Python object (not a ``torch.nn.Module``) that owns:

    * ``shared_nn``  -- stack applied to the raw input features;
    * ``key_nn`` / ``value_nn`` -- stacks applied on top of ``shared_nn``,
      each ending with an ``L2Normalization`` layer;
    * ``query_emb``  -- one embedding per entry of ``params.sources``;
    * ``center_emb`` -- ``params.n_center`` embeddings in the value space;
    * ``optimizer``  -- one Adam optimizer over all of the above.
    """

    # Dropout probability used in front of every fully connected layer.
    _DROPOUT_P = 0.2

    def __init__(self, params):
        """Store ``params`` and immediately build the networks from them."""
        self.params = params
        self.build_model()

    def _build_stack(self, prefix, in_dim, n_layer, dims):
        """Build one BatchNorm -> Dropout -> Linear -> activation stack.

        The stack always starts with ``<prefix>_norm0`` and ``<prefix>_drop0``;
        every one of the ``n_layer`` layers then adds ``<prefix>_fc<i>`` and
        ``<prefix>_act<i>`` (preceded by its own norm/dropout for ``i > 0``).

        Args:
            prefix: name prefix of the layers ('shared', 'key' or 'value').
            in_dim: size of the features entering the stack.
            n_layer: number of fully connected layers (may be 0).
            dims: output size of each fully connected layer.

        Returns:
            ``(layers, out_dim)``: the ordered layer dict and the size of the
            features leaving the stack (``in_dim`` when ``n_layer`` is 0).
        """
        layers = OrderedDict()
        layers['{0}_norm0'.format(prefix)] = torch.nn.BatchNorm1d(in_dim)
        layers['{0}_drop0'.format(prefix)] = \
            torch.nn.Dropout(p=self._DROPOUT_P)
        cur_dim = in_dim
        for i in range(n_layer):
            if i > 0:
                layers['{0}_norm{1}'.format(prefix, i)] = \
                    torch.nn.BatchNorm1d(cur_dim)
                layers['{0}_drop{1}'.format(prefix, i)] = \
                    torch.nn.Dropout(p=self._DROPOUT_P)
            layers['{0}_fc{1}'.format(prefix, i)] = \
                torch.nn.Linear(cur_dim, dims[i])
            layers['{0}_act{1}'.format(prefix, i)] = self.activation()
            cur_dim = dims[i]
        return layers, cur_dim

    def build_model(self):
        """Create the shared/key/value networks, embeddings and optimizer.

        Raises:
            ValueError: if ``params.model`` is not 'fnn'.
        """
        params = self.params
        if params.model != 'fnn':
            raise ValueError('Unknown model {}'.format(params.model))

        # Every stack is sized from its own ``*_layer`` / ``*_dim`` settings
        # so that the output size reported here always matches the layers
        # that were actually created.
        shared_layers, feat_dim_after_shared = self._build_stack(
            'shared', params.feat_dim, params.shared_layer, params.shared_dim)

        key_layers, key_dim = self._build_stack(
            'key', feat_dim_after_shared, params.key_layer, params.key_dim)
        key_layers['project'] = L2Normalization()

        value_layers, value_dim = self._build_stack(
            'value', feat_dim_after_shared,
            params.value_layer, params.value_dim)
        value_layers['project'] = L2Normalization()

        self.shared_nn = torch.nn.Sequential(shared_layers)
        self.key_nn = torch.nn.Sequential(key_layers)
        self.value_nn = torch.nn.Sequential(value_layers)
        # Queries are matched against the concatenated keys of a pair,
        # hence twice the key dimension.
        self.query_emb = torch.nn.Embedding(len(params.sources), 2 * key_dim)
        self.center_emb = torch.nn.Embedding(params.n_center, value_dim)
        self.optimizer = torch.optim.Adam([
            {'params': self.shared_nn.parameters()},
            {'params': self.key_nn.parameters()},
            {'params': self.value_nn.parameters()},
            {'params': self.query_emb.parameters()},
            {'params': self.center_emb.parameters()}
        ], lr=params.learning_rate)

    def forward(self, x):
        """Return the value embedding ``value_nn(shared_nn(x))`` of ``x``."""
        return self.value_nn(self.shared_nn(x))

    def _distance(self, a_emb, b_emb):
        """Row-wise distance between two batches of embeddings.

        Squared Euclidean distance when ``params.dis`` is 'eu', otherwise
        one minus the cosine similarity.
        """
        if self.params.dis == 'eu':
            return (a_emb - b_emb).pow(2).sum(1)
        return 1 - torch.nn.CosineSimilarity()(a_emb, b_emb)

    @staticmethod
    def _lookup_all(embedding, count):
        """Look up rows ``0 .. count - 1`` of ``embedding``.

        The index tensor is created on the device of the embedding weights,
        so the lookup works wherever the model has been placed.
        """
        index = torch.arange(
            count, dtype=torch.long, device=embedding.weight.device)
        return embedding(index)

    def _triplet_loss(self, x, y, z):
        """Triplet loss for anchors ``x``, positives ``y`` and negatives ``z``.

        Returns ``(loss, hard_idx, margin_idx, error_idx)`` where the three
        masks are boolean tensors with one entry per triplet.
        """
        params = self.params
        a_emb = self.forward(x)
        p_emb = self.forward(y)
        n_emb = self.forward(z)
        d1 = self._distance(a_emb, p_emb)
        d2 = self._distance(a_emb, n_emb)
        gap = d1 - d2
        hard_idx = gap > gap.mean().mul(params.beta)
        margin_idx = (d1 + params.alpha) > d2
        error_idx = d1 > d2
        if params.hard_sample:
            # Keep only the triplets selected by ``hard_idx``.
            d1 = d1[hard_idx]
            d2 = d2[hard_idx]
        loss = (d1 - d2 + params.alpha).clamp(min=0).sum()
        return loss, hard_idx, margin_idx, error_idx

    def _denoising_loss(self, a_emb, batch_size):
        """Clustering regularizer between ``a_emb`` and the center embeddings.

        Soft assignments of every sample to every center are computed from
        ``1 / (1 + squared distance)``, sharpened into a target distribution,
        and the sum of ``target * log(target / assignment)`` is returned.
        """
        n_center = self.params.n_center
        points = torch.unsqueeze(a_emb, 0).expand(n_center, -1, -1)
        centers = torch.unsqueeze(
            self._lookup_all(self.center_emb, n_center), 1
        ).expand(-1, batch_size, -1)
        inv_dist = ((points - centers).pow(2).sum(dim=2) + 1).pow(-1)
        # Normalize over centers: one distribution per sample.
        assign = inv_dist / inv_dist.sum(dim=0)
        center_mass = assign.sum(dim=1)
        sharpened = (assign.pow(2).t() / center_mass).t()
        target = sharpened / sharpened.sum(dim=0)
        return ((target / assign).log() * target).sum()

    def _pair_loss(self, x, y, z):
        """Contrastive loss for the pairs ``(x, y)`` described by ``z``.

        ``z`` is indexed as a 2-D tensor: its last column is used as an index
        into the per-source attention weights and its second-to-last column
        as the pair label (the label multiplies the distance term, one minus
        the label multiplies the margin term -- presumably 1 marks a similar
        pair; confirm against the caller).

        Returns ``(loss, hard_idx, margin_idx, error_idx)``.
        """
        params = self.params
        a_emb = self.forward(x)
        b_emb = self.forward(y)
        d = self._distance(a_emb, b_emb).unsqueeze(dim=1)
        source = z[:, -1]
        label = z[:, -2:-1]
        critic = d.mul(label) - d.mul(1 - label)
        loss = d.mul(label) + (params.alpha - d).clamp(min=0).mul(1 - label)
        hard_idx = critic > critic.mean().mul(params.beta)
        margin_idx = critic > 0.5 * params.alpha
        error_idx = critic > 0.5
        if params.attentive:
            # Weight every pair by the softmax attention of its own source.
            key = torch.cat((
                self.key_nn(self.shared_nn(x)),
                self.key_nn(self.shared_nn(y))
            ), dim=1)
            query = self._lookup_all(self.query_emb, len(params.sources))
            weight = torch.nn.Softmax(dim=1)(torch.mm(key, query.t()))
            loss = loss.mul(weight.gather(1, source.long().view(-1, 1)))
        if params.hard_sample:
            loss = loss[hard_idx]
        loss = loss.sum()
        if params.denoising:
            loss = loss + params.rho * self._denoising_loss(a_emb, x.shape[0])
        return loss, hard_idx, margin_idx, error_idx

    def get_loss(self, x, y, z):
        """Compute the loss selected by ``params.loss`` ('trip' or 'pair').

        Args:
            x, y, z: for 'trip', the anchor / positive / negative feature
                batches; for 'pair', the two feature batches and a tensor
                holding the pair label and source index in its last two
                columns.

        Returns:
            ``(loss, n_hard, n_margin, n_error)``: the scalar loss tensor and
            the number of samples flagged by the hard, margin and error
            criteria as Python numbers.

        Raises:
            ValueError: if ``params.loss`` is neither 'trip' nor 'pair'.
        """
        if self.params.loss == 'trip':
            parts = self._triplet_loss(x, y, z)
        elif self.params.loss == 'pair':
            parts = self._pair_loss(x, y, z)
        else:
            raise ValueError(
                'Unknown loss function {}'.format(self.params.loss))
        loss, hard_idx, margin_idx, error_idx = parts
        return (
            loss,
            hard_idx.sum().item(),
            margin_idx.sum().item(),
            error_idx.sum().item()
        )

    def activation(self):
        """Return a new activation module chosen by ``params.fnn_act``.

        Raises:
            ValueError: if ``params.fnn_act`` is not 'relu', 'sigmoid'
                or 'elu'.
        """
        name = self.params.fnn_act
        if name == 'relu':
            return torch.nn.ReLU()
        if name == 'sigmoid':
            return torch.nn.Sigmoid()
        if name == 'elu':
            return torch.nn.ELU()
        raise ValueError('Unknown activation function {}'.format(name))
