'''
Author: Carl Yang
Function: All model configuration parameters.
Command: library
'''
from __future__ import absolute_import, division, print_function, unicode_literals
import argparse
import os
import torch


def _str2bool(value):
    """Convert a command-line string such as 'true', 'False', '1' or 'no' to a bool.

    ``type=bool`` cannot be used with argparse because ``bool('False')`` is
    ``True`` (every non-empty string is truthy), so passing ``--flag False``
    would silently enable the flag.  Real booleans pass through unchanged.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling;
            argparse turns this into a normal usage error.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in ('true', 't', 'yes', 'y', '1'):
        return True
    if lowered in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        'Boolean value expected, got {!r}'.format(value))


def parse_args(argv=None):
    """Build the experiment argument parser and parse the command line.

    Args:
        argv: sequence of argument strings to parse.  ``None`` (the default)
            parses ``sys.argv[1:]``, exactly as before.

    Returns:
        argparse.Namespace holding every option declared below.

    Notes:
        * Boolean options take an explicit value, e.g. ``--raw true`` or
          ``--write_model 0`` (see ``_str2bool``).
        * List-valued options take space-separated values, e.g.
          ``--shared_dim 400 200``; ``--dimension`` takes exactly two ints.
        * Defaults are not passed through ``type``, so list/tuple defaults
          are returned as written.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--device_type',
        type=str,
        default='gpu',
        help='Training device (cpu or gpu)'
    )
    parser.add_argument(
        '--device_id',
        type=int,
        default=1,
        help='Training device id'
    )
    parser.add_argument(
        '--n_threads',
        type=int,
        default=32,
        help='Number of threads'
    )
    parser.add_argument(
        '--table_namespace',
        type=str,
        default='entities',
        help='Table namespace'
    )
    parser.add_argument(
        '--table_partition',
        type=str,
        default='',
        help='Table partition list'
    )
    parser.add_argument(
        '--table_triplet',
        type=str,
        default='daiquery_1942228235816900',
        help='Triplet table name (11818456)'
    )
    # type=list would split a CLI string into single characters; take one
    # string per column instead.
    parser.add_argument(
        '--table_triplet_columns',
        type=str,
        nargs='+',
        default=['anchor_features', 'dupe_features', 'diff_features', 'set_name'],
        help='Triplet table columns'
    )
    parser.add_argument(
        '--table_pair',
        type=str,
        default='daiquery_1827190020675747',
        help='Random walk pair table name (4604623)'
    )
    parser.add_argument(
        '--table_pair_columns',
        type=str,
        nargs='+',
        default=['page1_features', 'page2_features', 'label', 'set_name'],
        help='Pair table columns'
    )
    parser.add_argument(
        '--table_train_feat',
        type=str,
        default='daiquery_323451304861616',
        help='Training data table name. (page_id, features)'
    )
    parser.add_argument(
        '--table_test_feat',
        type=str,
        default='daiquery_383770218812171',
        help='Testing data table name (47642)'
    )
    parser.add_argument(
        '--table_feat_columns',
        type=str,
        nargs='+',
        default=['page_id', 'features'],
        help='Evaluation table columns'
    )
    parser.add_argument(
        '--table_test_truth',
        type=str,
        default='daiquery_166454950895654',
        help='Truth on testing data table name (47642)'
    )
    parser.add_argument(
        '--table_truth_columns',
        type=str,
        nargs='+',
        default=['page_id', 'pos_set', 'neg_set'],
        help='Truth table columns'
    )
    parser.add_argument(
        '--table_test_baseline',
        type=str,
        default='daiquery_1775936799166806',
        help='Testing pairs for baselines'
    )
    parser.add_argument(
        '--table_test_baseline_columns',
        type=str,
        nargs='+',
        default=['page1_features', 'page2_features', 'label'],
        help='Testing pairs of baseline columns'
    )
    parser.add_argument(
        '--write_train',
        type=_str2bool,
        default=True,
        help='Write the embedding of training data to hive'
    )
    parser.add_argument(
        '--write_test',
        type=_str2bool,
        default=True,
        help='Write the testing results'
    )
    parser.add_argument(
        '--write_model',
        type=_str2bool,
        default=True,
        help='Write the embedding model files'
    )
    parser.add_argument(
        '--table_output',
        type=str,
        default='',
        help='Output table name'
    )
    parser.add_argument(
        '--random_seed',
        type=int,
        default=0,
        help='Random seed for train-test split'
    )
    parser.add_argument(
        '--split_ratio',
        type=float,
        default=0.99,
        help='Train-test split ratio'
    )
    parser.add_argument(
        '--model',
        type=str,
        default='fnn',
        help='Neural network model'
    )
    parser.add_argument(
        '--fnn_act',
        type=str,
        default='elu',
        help='Activation function of fully connected layers'
    )
    parser.add_argument(
        '--shared_layer',
        type=int,
        default=1,
        help='Number of fully connected nn layers. (for attentive training)'
    )
    # Layer widths: one int per layer, e.g. `--shared_dim 400 200`.
    parser.add_argument(
        '--shared_dim',
        type=int,
        nargs='+',
        default=[400],
        help='Dimension of fully connected nn layers'
    )
    parser.add_argument(
        '--key_layer',
        type=int,
        default=1,
        help='Number of fully connected nn layers'
    )
    parser.add_argument(
        '--key_dim',
        type=int,
        nargs='+',
        default=[300],
        help='Dimension of fully connected nn layers'
    )
    parser.add_argument(
        '--value_layer',
        type=int,
        default=1,
        help='Number of fully connected nn layers'
    )
    parser.add_argument(
        '--value_dim',
        type=int,
        nargs='+',
        default=[300],
        help='Dimension of fully connected nn layers'
    )
    parser.add_argument(
        '--raw',
        type=_str2bool,
        default=False,
        help='Use raw features without metric learning'
    )
    parser.add_argument(
        '--read_all',
        type=_str2bool,
        default=True,
        help='Read all training data into memory to facilitate fast batching'
    )
    parser.add_argument(
        '--loss',
        type=str,
        default='pair',
        help='Loss function'
    )
    parser.add_argument(
        '--dis',
        type=str,
        default='eu',
        help='Distance function'
    )
    parser.add_argument(
        '--sources',
        type=str,
        nargs='+',
        default=[
            'pqi_difftool_dev_austin:train',
            'pqiv3_dupe_measurement',
            'crowdsourcing:train',
            'curation:train',
            'lm:train',
            'pqiv4_dev_recall',
            'bad_redirects:valid',
            'checkin_cleanliness:train',
            'pqiv4_dev_precision',
            'dedup_bug_tool:train',
            'pqiv2_development:train',
            'bad_redirects:train',
            'disassociations:train',
            'dedup_bug_tool:valid',
            'checkin_cleanliness:valid',
            'metapage_tool:train',
            'pqiv2_development:valid',
            'disassociations:valid',
            'metapage_tool:valid'
        ],
        help='Sources of training data'
    )
    # Exactly two ints; init() derives feat_dim = dimension[1] - dimension[0].
    # (type=tuple would have produced a tuple of characters.)
    parser.add_argument(
        '--dimension',
        type=int,
        nargs=2,
        default=(29, 429),
        help='Dimension of input features to use'
    )
    parser.add_argument(
        '--eval',
        type=str,
        default='all',
        help='Evaluation metric'
    )
    parser.add_argument(
        '--k_max',
        type=int,
        default=100,
        help='Max K for KNN'
    )
    parser.add_argument(
        '--hard_sample',
        type=_str2bool,
        default=False,
        help='Enable hard sampling'
    )
    parser.add_argument(
        '--attentive',
        type=_str2bool,
        default=False,
        help='Enable attentive training (pair loss only)'
    )
    parser.add_argument(
        '--denoising',
        type=_str2bool,
        default=False,
        help='Enable label denoising (pair loss only)'
    )
    parser.add_argument(
        '--quantization',
        type=_str2bool,
        default=False,
        help='Compute the quantization of embeddings'
    )
    parser.add_argument(
        '--pca_dim',
        type=int,
        default=32,
        help='The dimension of PCA used for vector quantization'
    )
    parser.add_argument(
        '--n_center',
        type=int,
        default=100,
        help='Number of soft cluster centers (for denoising)'
    )
    parser.add_argument(
        '--learning_rate',
        type=float,
        default=1e-4,
        help='Learning rate'
    )
    parser.add_argument(
        '--alpha',
        type=float,
        default=0.8,
        help='Loss margin parameter'
    )
    parser.add_argument(
        '--beta',
        type=float,
        default=0.8,
        help='Hard sampling margin parameter'
    )
    # Float default so the value has the same type whether given or defaulted
    # (argparse does not apply `type` to non-string defaults).
    parser.add_argument(
        '--rho',
        type=float,
        default=1.0,
        help='Label Denoising parameter'
    )
    parser.add_argument(
        '--batch_size',
        type=int,
        default=65536,
        help='Training batch size'
    )
    parser.add_argument(
        '--n_epochs',
        type=int,
        default=1,
        help='Training epoch number'
    )
    parser.add_argument(
        '--n_rounds',
        type=int,
        default=5,
        help='Batch experiment round number'
    )
    parser.add_argument(
        '--verbose',
        type=_str2bool,
        default=True,
        help='Print out notifications'
    )
    parser.add_argument(
        '--baselines',
        type=str,
        nargs='+',
        default=['LR', 'SVM', 'GBDT', 'RF'],
        help='Baselines to evaluate'
    )
    return parser.parse_args(argv)


# Initialize the directories
def init_dir(args,
             base_dir='/data/users/carlyang/fbsource/fbcode/experimental/carlyang/'):
    """Set ``args.model_dir``, ``args.log_dir`` and ``args.tmp_dir`` and create them.

    Each directory is ``<base_dir>/<name>/`` and keeps a trailing separator,
    identical to the previously hard-coded strings, in case callers build
    file names by plain string concatenation.

    Args:
        args: namespace that receives the three directory attributes.
        base_dir: root under which ``model/``, ``log/`` and ``tmp/`` live.
            Defaults to the original hard-coded location.

    Raises:
        OSError: if a directory cannot be created, or if a non-directory
            already exists at one of the paths.
    """
    # Trailing '' component keeps the trailing path separator.
    args.model_dir = os.path.join(base_dir, 'model', '')
    args.log_dir = os.path.join(base_dir, 'log', '')
    args.tmp_dir = os.path.join(base_dir, 'tmp', '')
    for path in (args.model_dir, args.log_dir, args.tmp_dir):
        try:
            os.makedirs(path)
        except OSError:
            # Already present (possibly created concurrently by another run):
            # fine as long as it really is a directory.
            if not os.path.isdir(path):
                raise


# Initialize cuda settings
def init_dev(args):
    """Pick the torch device for this run and store it as ``args.device``.

    CUDA is chosen only when ``args.device_type == 'gpu'`` and
    ``torch.cuda.is_available()``; ``args.device_id`` then becomes the current
    CUDA device.  In every other case the run falls back to CPU silently.
    """
    use_cuda = args.device_type == 'gpu' and torch.cuda.is_available()
    if not use_cuda:
        args.device = torch.device('cpu')
        return
    # A single device: 'cuda' with no index means "the current CUDA device",
    # which the next call switches to args.device_id.
    args.device = torch.device('cuda')
    torch.cuda.set_device(args.device_id)


# Uncomment the init_small(args) call in init() to run with toy data and a
# small model, mainly for debugging purposes.
def init_small(args):
    """Override ``args`` with a toy configuration for quick debugging runs.

    Switches to a small pair table and a plain pair-loss setup (hard sampling,
    attentive training and denoising all off), using a small batch and a
    single epoch.
    """
    overrides = {
        'table_pair': 'daiquery_422695608216055',
        'loss': 'pair',
        'dis': 'eu',
        'hard_sample': False,
        'attentive': False,
        'denoising': False,
        'batch_size': 1024,
        'n_epochs': 1,
    }
    for attr_name, attr_value in overrides.items():
        setattr(args, attr_name, attr_value)


# Set parameters for batch experiments
def init_batch(args):
    """Adjust ``args`` for unattended batch experiments.

    Turns off the three write-out flags (training embeddings, test results,
    model files) and verbose printing, reads all training data into memory,
    keeps metric learning on (``raw = False``) and evaluates with
    ``eval = 'all'``.
    """
    for write_flag in ('write_test', 'write_train', 'write_model'):
        setattr(args, write_flag, False)
    args.read_all = True
    args.raw = False
    args.eval = 'all'
    args.verbose = False
    # To run the batch experiments on toy data, also call init_small(args) here.


# Initialize parameters that depend on other parameters
def init(args):
    """Derive dependent parameters, then set up output directories and the device.

    Sets ``args.feat_dim = args.dimension[1] - args.dimension[0]``, then calls
    ``init_dir(args)`` and ``init_dev(args)``.

    Raises:
        ValueError: if ``args.dimension`` gives a non-positive feature width.
            This is checked before any directory is created, so a bad
            ``--dimension`` fails here with a clear message instead of
            deeper inside model construction.
    """
    feat_dim = args.dimension[1] - args.dimension[0]
    if feat_dim <= 0:
        raise ValueError(
            'dimension must satisfy dimension[0] < dimension[1], got {!r}'.format(
                args.dimension))
    args.feat_dim = feat_dim
    init_dir(args)
    init_dev(args)
    # Uncomment to run with toy data and a small model (debugging only).
    # init_small(args)


# Parse the command line and finish derived setup once, when this module is
# first imported; other modules read the resulting `args` namespace.
# NOTE(review): this is an import-time side effect. The first import parses
# sys.argv[1:] (argparse exits with a usage error on unrecognised flags),
# creates the model/log/tmp directories and selects the torch device.
args = parse_args()
init(args)
