'''
Author: Carl Yang
Function: Train the supervised embedding model.
Command: library
'''
from __future__ import absolute_import, division, print_function, unicode_literals
import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from tqdm import tqdm
from datetime import datetime
from torch.autograd import Variable
from experimental.carlyang.place_embedding.model import NN


class Embedding(object):
    # Create a model and put it onto the current device.
    def __init__(self, params):
        """Build the NN model from ``params`` and call ``.to(params.device)`` on its parts.

        Args:
            params: configuration object; it is stored on the instance, passed
                to ``NN`` and must expose a ``device`` attribute.
        """
        self.params = params
        self.nn = NN(params)
        # NOTE(review): `.to()` is in-place only for modules; if query_emb or
        # center_emb are plain tensors the moved copy is discarded - confirm
        # against the NN definition.
        for part_name in ('shared_nn', 'key_nn', 'value_nn',
                          'query_emb', 'center_emb'):
            getattr(self.nn, part_name).to(self.params.device)

    # Split the input into triplets or pairs.
    def convert_input(self, data):
        """Cut ``data`` column-wise into three detached pieces.

        With ``d = params.feat_dim`` the pieces are columns ``[0, d)``,
        ``[d, 2d)`` and ``[2d, data.shape[1])``.

        Args:
            data: 2-D tensor whose second dimension holds the concatenated
                features.

        Returns:
            Tuple ``(x, y, z)`` holding the ``.data`` of each column group.
        """
        d = self.params.feat_dim
        edges = (0, d, 2 * d, data.shape[1])
        first, second, rest = (
            data[:, range(start, stop)].data
            for start, stop in zip(edges[:-1], edges[1:])
        )
        return first, second, rest

    # Construct training data on the current device.
    def build_train(self, data):
        """Split ``data`` with ``convert_input`` and keep the pieces as the training batch.

        The three pieces are moved to ``params.device``, wrapped so that they
        do not require gradients, and stored as ``x_train``, ``y_train`` and
        ``z_train``.
        """
        device = self.params.device
        pieces = self.convert_input(data)
        self.x_train, self.y_train, self.z_train = (
            Variable(piece.to(device), requires_grad=False) for piece in pieces
        )

    # Construct testing data on the current device.
    def build_test(self, data):
        """Split ``data`` with ``convert_input`` and keep the pieces as the test batch.

        The three pieces are moved to ``params.device``, wrapped so that they
        do not require gradients, and stored as ``x_test``, ``y_test`` and
        ``z_test``.
        """
        device = self.params.device
        pieces = self.convert_input(data)
        self.x_test, self.y_test, self.z_test = (
            Variable(piece.to(device), requires_grad=False) for piece in pieces
        )

    # Update the model with the current batch of training data
    # and record the losses and errors.
    def update_model(self):
        """Record train/test statistics for the current batch, then take one optimizer step.

        Reads ``x_train/y_train/z_train`` and ``x_test/y_test/z_test`` and
        appends to the history lists created by ``train()``
        (``losses_*``, ``hard_*``, ``margin_*``, ``error_*``).

        The recorded loss is the value returned by ``nn.get_loss`` divided by
        the number of hard samples when ``params.hard_sample`` is set, and by
        the number of rows in the batch otherwise.
        """
        loss_train, n_hard_train, n_margin_train, n_error_train = \
            self.nn.get_loss(self.x_train, self.y_train, self.z_train)
        # The held-out loss is only logged, never back-propagated, so there is
        # no reason to build an autograd graph for it.
        with torch.no_grad():
            loss_test, n_hard_test, n_margin_test, n_error_test = \
                self.nn.get_loss(self.x_test, self.y_test, self.z_test)

        if self.params.hard_sample:
            n_batch_train, n_batch_test = n_hard_train, n_hard_test
        else:
            n_batch_train = self.x_train.shape[0]
            n_batch_test = self.x_test.shape[0]

        # A batch can contain zero hard samples (or zero rows); clamp the
        # denominator so logging never aborts training with a division by zero.
        # `.item()` already yields a host-side Python number.
        self.losses_train.append(loss_train.item() / max(n_batch_train, 1))
        self.losses_test.append(loss_test.item() / max(n_batch_test, 1))

        self.hard_train.append(n_hard_train)
        self.margin_train.append(n_margin_train)
        self.error_train.append(n_error_train)
        self.hard_test.append(n_hard_test)
        self.margin_test.append(n_margin_test)
        self.error_test.append(n_error_test)

        self.nn.optimizer.zero_grad()
        loss_train.backward()
        self.nn.optimizer.step()

    # Train the model for one complete epoch by calling update_model.
    def one_pass(self, dataset):
        """Run one epoch of training, calling ``update_model`` once per batch.

        Two modes, selected by ``params.read_all``:

        * read_all: ``dataset.train_data`` is shuffled and cut into chunks of
          ``params.batch_size`` rows. Every full-size chunk is trained on; a
          shorter trailing chunk is skipped.
        * otherwise: batches are pulled through ``dataset.get_train()`` for as
          long as ``dataset.hive_open`` stays true.

        Args:
            dataset: object exposing ``train_data`` and, in streaming mode,
                ``hive_open``, ``counter`` and ``get_train()``.
        """
        if self.params.read_all:
            batch_size = self.params.batch_size
            shuffled = torch.randperm(dataset.train_data.shape[0])
            # torch.split leaves a short chunk at the end only when the row
            # count is not a multiple of batch_size. Drop that one, but keep
            # the last chunk when it is complete (it used to be discarded
            # unconditionally).
            full_batches = [
                rows for rows in torch.split(shuffled, batch_size)
                if rows.shape[0] == batch_size
            ]
            for rows in tqdm(full_batches):
                batch_data = torch.index_select(dataset.train_data, 0, rows)
                self.build_train(batch_data)
                self.update_model()
        else:
            # presumably get_train() loads the next chunk into
            # dataset.train_data and clears hive_open once the source is
            # exhausted - verify against the dataset class.
            if not dataset.hive_open:
                dataset.get_train()
            while dataset.hive_open:
                print("{}: Training with batch {}"
                    .format(datetime.now(), dataset.counter))
                self.build_train(dataset.train_data)
                self.update_model()
                dataset.get_train()

    # Train the model for n_epoch epochs by calling one_pass.
    def train(self, dataset):
        """Train for ``params.n_epochs`` epochs, one ``one_pass`` call per epoch.

        Does nothing when ``params.raw`` is set. Otherwise the history lists
        (``losses_*``, ``hard_*``, ``margin_*``, ``error_*`` for both train
        and test) are reset to empty and the test batch is built from
        ``dataset.eval_data`` before the first epoch.
        """
        if self.params.raw:
            return
        print('{}: Training the model...'.format(datetime.now()))
        # Fresh, empty history for every statistic / split combination.
        for stat in ('losses', 'hard', 'margin', 'error'):
            for split in ('train', 'test'):
                setattr(self, '{}_{}'.format(stat, split), [])
        self.build_test(dataset.eval_data)
        epochs = tqdm(range(self.params.n_epochs), ncols=80)
        for _ in epochs:
            self.one_pass(dataset)

    # Clear up the current batch of data.
    # Usually called before evaluation to free memory.
    def clean(self):
        """Drop the references to the current train and test batches.

        Safe to call when some or all of the batch attributes were never
        created (for example ``train()`` returns early when ``params.raw`` is
        set) and safe to call more than once.
        """
        for name in ('x_train', 'y_train', 'z_train',
                     'x_test', 'y_test', 'z_test'):
            # Only delete what exists; a plain `del` raises AttributeError on
            # a missing attribute and would leave the rest in place.
            if hasattr(self, name):
                delattr(self, name)

    # Compute the embedding on training or testing set.
    # Since the input and output can be very large
    # and the computation only requires a forward pass,
    # this process is always done on CPU.
    def compute(self, data):
        """Compute embeddings for ``data`` with a gradient-free forward pass on CPU.

        Args:
            data: indexable pair; ``data[0]`` is a sequence of ids convertible
                with ``int()`` and ``data[1]`` is the input handed to
                ``nn.forward``.

        Returns:
            ``None`` when ``params.raw`` is set. Otherwise ``[ids, emb]`` where
            ``ids`` is a ``LongTensor`` built from ``data[0]`` and ``emb`` is
            the forward output, passed through ``quantize`` when
            ``params.quantization`` is set.
        """
        if self.params.raw:
            return None
        print("{}: Computing the embedding of {} pages"
            .format(datetime.now(), len(data[0])))
        cpu = torch.device('cpu')
        # NOTE(review): the modules are left on the CPU afterwards; confirm
        # nothing trains on another device after this call.
        self.nn.value_nn.to(cpu)
        self.nn.shared_nn.to(cpu)
        # Inference only: skip autograd bookkeeping so no graph is retained
        # for what may be a very large input.
        with torch.no_grad():
            emb = self.nn.forward(data[1])
        ids = torch.LongTensor([int(x) for x in data[0]])
        print('{}: Finished the embedding computation'.format(datetime.now()))
        if self.params.quantization:
            emb = self.quantize(emb)
        return [ids, emb]

    # Take the embedding produced by the forward pass in compute(),
    # project it onto its leading pca_dim principal directions (PCA via SVD)
    # and quantize every coordinate to its sign (-1 or +1).
    def quantize(self, emb):
        """Project ``emb`` onto its top principal directions and keep only the signs.

        The columns of ``emb`` are mean-centred, the left singular vectors of
        the centred matrix's transpose are computed, and the centred rows are
        projected onto the first ``params.pca_dim`` of them.

        Args:
            emb: 2-D tensor with one embedding per row. It is not modified.

        Returns:
            Tensor with one row per input row and up to ``params.pca_dim``
            columns, holding -1 or +1 (0 where a projection is exactly zero).
        """
        # Centre out-of-place so the caller's tensor is left untouched.
        centered = emb - torch.mean(emb, 0)
        U, _, _ = torch.svd(torch.t(centered))
        reduced_emb = torch.mm(centered, U[:, :self.params.pca_dim])
        # x / |x| turns into NaN at x == 0; sign() gives the same +/-1
        # elsewhere and a clean 0 there.
        return torch.sign(reduced_emb)

    # Store the losses, errors and training curves for one complete train of the model.
    # This is mainly for debugging purposes and not necessary for production.
    def store(self, r=-1):
        """Write the training curves and raw numbers of the last run to ``params.tmp_dir``.

        Produces three files whose names embed a model tag built from
        ``params.dis``, ``params.loss``, the ``H``/``A``/``D`` flags
        (hard_sample / attentive / denoising) and, if ``r >= 0``, ``_<r>``:

        * ``losses_<tag>.png``  - train and test loss per batch
        * ``errors_<tag>.png``  - hard / margin / error counts per batch
        * ``numbers_<tag>.txt`` - the same series as text

        Paths are formed by plain concatenation, so ``params.tmp_dir`` must
        already end with a path separator. Does nothing when ``params.raw``
        is set.

        Args:
            r: optional run index appended to the tag when non-negative.
        """
        if self.params.raw:
            return
        # Font type 42 (TrueType) for the pdf/ps backends.
        matplotlib.rcParams['pdf.fonttype'] = 42
        matplotlib.rcParams['ps.fonttype'] = 42
        length = len(self.losses_train)

        model = self.params.dis + self.params.loss
        if self.params.hard_sample:
            model += 'H'
        if self.params.attentive:
            model += 'A'
        if self.params.denoising:
            model += 'D'
        if r >= 0:
            model += '_{}'.format(r)

        self._save_curves(
            length,
            [
                ('train', self.losses_train),
                ('test', self.losses_test),
            ],
            'Loss',
            '{}losses_{}.png'.format(self.params.tmp_dir, model),
        )
        self._save_curves(
            length,
            [
                ('hard_train', self.hard_train),
                ('margin_train', self.margin_train),
                ('error_train', self.error_train),
                ('hard_test', self.hard_test),
                ('margin_test', self.margin_test),
                ('error_test', self.error_test),
            ],
            'Error',
            '{}errors_{}.png'.format(self.params.tmp_dir, model),
        )

        sections = [
            ('Losses Train', self.losses_train),
            ('Losses Test', self.losses_test),
            ('Hard Train', self.hard_train),
            ('Hard Test', self.hard_test),
            ('Margin Train', self.margin_train),
            ('Margin Test', self.margin_test),
            ('Error Train', self.error_train),
            ('Error Test', self.error_test),
        ]
        with open('{}numbers_{}.txt'.format(self.params.tmp_dir, model), 'w') as f:
            for title, values in sections:
                f.write('\n{}:\n'.format(title))
                f.write(str(values))

    def _save_curves(self, length, curves, ylabel, path):
        """Plot each ``(label, values)`` pair against the batch index and save a PNG at ``path``."""
        steps = np.array(range(length))
        plt.clf()
        for label, values in curves:
            plt.plot(steps, np.array(values), label=label)
        plt.xlabel('Batch', fontsize=15)
        plt.ylabel(ylabel, fontsize=15)
        plt.grid()
        plt.xticks(fontsize=15)
        plt.yticks(fontsize=15)
        plt.legend(fontsize=12)
        # `dpi` is the resolution keyword; the former `dps` was a typo that is
        # either ignored or rejected depending on the matplotlib version.
        plt.savefig(path, format='png', dpi=200, bbox_inches='tight')
