'''
Author: Carl Yang
Function: Utility functions, mainly for loading data from Hive.
Command: library
'''
from __future__ import absolute_import, division, print_function, unicode_literals
from contextlib import contextmanager
import os
import sys
import random
import string
import numpy as np
import torch
import torch.hiveio as hiveio
from analytics.bamboo import Bamboo
from datetime import datetime


@contextmanager
def suppress_stdout():
    """Context manager that discards everything written to ``sys.stdout``.

    While the ``with`` body runs, ``sys.stdout`` points at a handle opened on
    ``os.devnull``.  The previous ``sys.stdout`` object is put back when the
    body exits, whether normally or through an exception, and the devnull
    handle is closed afterwards.
    """
    saved_stream = sys.stdout
    sink = open(os.devnull, "w")
    sys.stdout = sink
    try:
        yield
    finally:
        # Restore first so nothing can write to the sink once it is closed.
        sys.stdout = saved_stream
        sink.close()


class Dataset(object):
    """Load training, evaluation and testing data from Hive.

    Every table name, column list and option is taken from the ``params``
    object handed in by the caller.  Whatever is read is stored on the
    instance (``data``, ``train_data``, ``eval_data``, ``feat_train``,
    ``feat_test``, ``truth_knn``, ``truth_pair``, ``test_data``).
    """

    def __init__(self, params):
        """Store ``params`` and mark that no streaming Hive read is open."""
        self.params = params
        self.hive_open = False
        # Number of batches fetched from the currently open streaming read.
        self.counter = 0
        # Defined up front so that clean()/reset() are safe before any read.
        self.data = None
        self.train_data = None
        self.eval_data = None

    def reset(self, params):
        """Replace the parameters so that different data can be read again.

        The current train/eval split is dropped via :meth:`clean`.
        """
        # NOTE(review): a streaming read opened with the previous params is
        # not closed here (``hive_open`` is left untouched) -- confirm that
        # callers only reset between complete reads.
        self.params = params
        self.clean()

    def clean(self):
        """Drop the train/eval split so its memory can be reclaimed.

        Safe to call at any time, including before any data has been loaded
        and repeatedly.  ``self.data`` is deliberately left alone because
        :meth:`shuffle_train` re-splits from it.
        """
        # Assigning None (rather than ``del``) cannot raise AttributeError
        # when the attributes were never set.
        self.train_data = None
        self.eval_data = None

    def get_train(self):
        """Read training data from Hive and split it into train/eval parts.

        Does nothing when ``params.raw`` is truthy.  For ``params.loss ==
        'trip'`` the whole triplet table is read at once.  For ``'pair'`` the
        pair table is read either completely (``params.read_all``) or one
        batch per call; ``hive_open`` stays True while more batches may
        follow.

        If the pair reader hands back ``None`` there is nothing to split:
        ``data``, ``train_data`` and ``eval_data`` are left as ``None``.

        Raises:
            ValueError: if ``params.loss`` is neither ``'trip'`` nor ``'pair'``.
        """
        if self.params.raw:
            return
        self.train_data = None
        self.eval_data = None
        self.data = None

        loss = self.params.loss
        if loss == 'trip':
            feats = hiveio.read(
                self.params.table_namespace,
                self.params.table_triplet,
                self.params.table_partition,
                self.params.table_triplet_columns
            )
            self.data = torch.cat(feats, dim=1)
            if self.params.verbose:
                print('Got training data: {} triplets'.format(self.data.shape[0]))
        elif loss == 'pair':
            feats = self._fetch_pair_feats()
            if feats is None:
                # The reader returned nothing; indexing into it would only
                # raise a TypeError, so leave the data attributes as None.
                return
            self.data = self._assemble_pairs(feats)
            if self.params.verbose:
                print('Got training data: {} pairs'.format(self.data.shape[0]))
        else:
            # ValueError is still an Exception, so existing handlers keep working.
            raise ValueError('Unknown loss function {}'.format(loss))

        self.shuffle_train()

    def _fetch_pair_feats(self):
        """Return the next raw pair columns from Hive (may be ``None``)."""
        if not self.hive_open:
            self.hive_open = True
            self.counter = 0
            hiveio.start_reading(
                self.params.table_namespace,
                self.params.table_pair,
                self.params.table_partition,
                self.params.table_pair_columns,
                self.params.batch_size
            )
        if self.params.read_all:
            # NOTE(review): stop_reading() is not called on this path;
            # presumably get_all() finishes the read itself -- confirm.
            feats = hiveio.get_all()
            self.hive_open = False
            return feats

        self.counter += 1
        print("{}: Fetching training data batch {}"
              .format(datetime.now(), self.counter))
        feats = hiveio.get_batch()
        # A missing or short batch is treated as the end of the stream.
        if feats is None or feats[0].shape[0] != self.params.batch_size:
            hiveio.stop_reading()
            self.hive_open = False
        return feats

    def _assemble_pairs(self, feats):
        """Concatenate page-1 features, page-2 features, label and source."""
        label = self._label_column(feats[2])
        if self.params.attentive and len(self.params.sources) > 0:
            # Encode each source as its position in ``params.sources``.
            source = torch.unsqueeze(torch.FloatTensor(
                [self.params.sources.index(s) for s in feats[3]]), 1)
        else:
            source = torch.zeros([len(feats[3]), 1], dtype=torch.float)
        return torch.cat([
            self._slice_dims(feats[0]),
            self._slice_dims(feats[1]),
            label,
            source
        ], dim=1)

    def _slice_dims(self, feat):
        """Keep only columns ``params.dimension[0]:params.dimension[1]``."""
        lo, hi = self.params.dimension[0], self.params.dimension[1]
        return feat[:, lo:hi]

    @staticmethod
    def _label_column(raw_labels):
        """Turn an iterable of labels into an ``(n, 1)`` float tensor."""
        return torch.unsqueeze(
            torch.FloatTensor([float(value) for value in raw_labels]), 1)

    def shuffle_train(self):
        """Shuffle ``self.data`` and split it into ``train_data``/``eval_data``.

        ``floor(params.split_ratio * n)`` rows go to training and the rest to
        evaluation.  When ``params.random_seed`` is non-zero NumPy's global
        RNG is re-seeded on every call, so repeated calls on the same data
        yield the same split.
        """
        if self.params.random_seed != 0:
            np.random.seed(self.params.random_seed)
        n_rows = self.data.shape[0]
        n_train = int(np.floor(self.params.split_ratio * n_rows))
        ind = np.arange(n_rows)
        np.random.shuffle(ind)
        ind_train, ind_test = ind[:n_train], ind[n_train:]
        self.train_data = self.data[ind_train]
        self.eval_data = self.data[ind_test]
        if self.params.verbose:
            print('Split training data: {} for training and {} for evaluation'
                  .format(self.train_data.shape[0], self.eval_data.shape[0]))

    def _read_feat(self, table):
        """Read a feature table; return ``[first column, sliced second column]``."""
        feat = hiveio.read(
            self.params.table_namespace,
            table,
            self.params.table_partition,
            self.params.table_feat_columns
        )
        return [feat[0], self._slice_dims(feat[1])]

    def get_feat_train(self):
        """Read the features of training places into ``self.feat_train``."""
        # Cleared first so a failed read does not leave stale features behind.
        self.feat_train = None
        self.feat_train = self._read_feat(self.params.table_train_feat)
        if self.params.verbose:
            print('Got training feature: {} vectors'
                  .format(self.feat_train[1].shape[0]))

    def get_test(self):
        """Read both the features and the ground truth of the testing places."""
        self.get_feat_test()
        self.get_truth()

    def get_feat_test(self):
        """Read the features of testing places into ``self.feat_test``."""
        self.feat_test = None
        self.feat_test = self._read_feat(self.params.table_test_feat)
        if self.params.verbose:
            print('Got testing feature: {} vectors'.format(self.feat_test[1].shape[0]))

    def get_truth(self):
        """Read the ground truth required by ``params.eval``.

        ``'knn'`` loads the knn truth, ``'pairwise'`` the pairwise truth and
        ``'all'`` both; any other value loads nothing.
        """
        if self.params.eval in ('knn', 'all'):
            self.get_truth_knn()
        if self.params.eval in ('pairwise', 'all'):
            self.get_truth_pair()

    def _read_truth_table(self):
        """Read the test-truth table through Bamboo and return the frame."""
        return Bamboo().read_hive_table(
            namespace=self.params.table_namespace,
            table=self.params.table_test_truth,
            column_names=self.params.table_truth_columns
        )

    def get_truth_knn(self):
        """Build ``self.truth_knn``: ``page_id -> pos_set`` for non-empty sets."""
        self.truth_knn = None
        df = self._read_truth_table()
        self.truth_knn = {}
        for _, row in df.iterrows():
            if len(row.pos_set) > 0:
                self.truth_knn[row.page_id] = row.pos_set
        if self.params.verbose:
            print('Got knn truth: {} places with pos'
                  .format(len(self.truth_knn)))

    def get_truth_pair(self):
        """Build ``self.truth_pair``: ``page_id -> (pos_set, neg_set)``.

        Only rows where both sets are non-empty are kept.
        """
        self.truth_pair = None
        df = self._read_truth_table()
        self.truth_pair = {}
        for _, row in df.iterrows():
            if len(row.pos_set) > 0 and len(row.neg_set) > 0:
                self.truth_pair[row.page_id] = (row.pos_set, row.neg_set)
        if self.params.verbose:
            print('Got pairwise truth: {} places with pos and neg'
                  .format(len(self.truth_pair)))

    def get_test_baseline(self):
        """Read pairwise testing data (for baselines) into ``self.test_data``.

        The result concatenates page-1 features, page-2 features and label.
        """
        feats = hiveio.read(
            self.params.table_namespace,
            self.params.table_test_baseline,
            self.params.table_partition,
            self.params.table_test_baseline_columns
        )
        self.test_data = torch.cat([
            self._slice_dims(feats[0]),
            self._slice_dims(feats[1]),
            self._label_column(feats[2])
        ], dim=1)
        if self.params.verbose:
            print('Got testing data: {} pairs'.format(self.test_data.shape[0]))

    def write_embedding(self, data):
        """Write ``data`` to a Hive table and print the table name.

        When ``params.table_output`` is set, that table is emptied and reused.
        Otherwise (empty or ``None``) a table named ``tmp_carl_`` plus ten
        random alphanumeric characters is created.
        """
        # Truthiness also covers None, on which len() would raise TypeError.
        if self.params.table_output:
            table_name = self.params.table_output
            # NOTE(review): unlike the CREATE below, this call does not pass
            # get_results=False -- confirm that is intended for a DELETE.
            Bamboo().query_hive(
                self.params.table_namespace,
                '''
                    DELETE FROM {};
                '''.format(table_name)
            )
        else:
            alphabet = string.ascii_letters + string.digits
            table_name = 'tmp_carl_' + ''.join(
                random.choice(alphabet) for _ in range(10)
            )
            Bamboo().query_hive(
                self.params.table_namespace,
                '''
                    CREATE TABLE IF NOT EXISTS {} (
                        page_id BIGINT,
                        features ARRAY<FLOAT>
                    ) TBLPROPERTIES('RETENTION' = '1');
                '''.format(table_name),
                get_results=False,
            )
        hiveio.write(
            self.params.table_namespace,
            table_name,
            self.params.table_partition,
            data=data
        )
        print('Embeddings output to table {}'.format(table_name))
