'''
Author: Carl Yang
Function: Run multiple experiments with a single thread
Command: buck run @mode/dev-nosan //experimental/carlyang/place_embedding:batch
'''
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
from datetime import datetime
import pickle
from experimental.carlyang.place_embedding.config import args, init_batch
from experimental.carlyang.place_embedding.embed import Embedding
from experimental.carlyang.place_embedding.util import Dataset
from experimental.carlyang.place_embedding.eval import Evaluation


# Run the experiments with the same configurations for n_rounds times
# and record the mean and std of evaluation metrics.
def run_exp(params, dataset, eval):
    '''
    Train and evaluate a fresh Embedding params.n_rounds times and
    aggregate the per-round metrics.

    Each round builds Embedding(params), times embed.train(dataset), calls
    embed.store(r) with the round index, and then runs
    eval.evaluate(dataset, embed). After each evaluation the attributes
    pre_avg, pre_list, rec_avg, rec_list and pair are read from eval.

    Args:
        params: configuration object; n_rounds is read here and the whole
            object is forwarded to Embedding.
        dataset: passed through unchanged to embed.train and eval.evaluate.
        eval: evaluator exposing evaluate(dataset, embed) and the metric
            attributes listed above. (The name shadows the builtin but is
            kept so keyword callers keep working.)

    Returns:
        dict with '<name>_mean' and '<name>_std' entries for name in
        pre, rec, pair, runtime (numpy scalars) and for pre_list, rec_list
        (plain lists, reduced along axis 0, i.e. across rounds; this
        presumes every round yields a list of the same length).
        Runtime is measured in whole seconds of wall-clock training time.
    '''
    scalar_names = ('pre', 'rec', 'pair', 'runtime')
    list_names = ('pre_list', 'rec_list')
    samples = {name: [] for name in scalar_names + list_names}

    for r in range(params.n_rounds):
        embed = Embedding(params)
        start = datetime.now()
        embed.train(dataset)
        # timedelta.seconds is only the seconds *component* (0..86399) and
        # drops whole days; total_seconds() covers the full duration.
        # int() keeps the whole-second granularity of earlier results.
        elapsed = int((datetime.now() - start).total_seconds())
        embed.store(r)
        eval.evaluate(dataset, embed)
        samples['pre'].append(eval.pre_avg)
        samples['pre_list'].append(eval.pre_list)
        samples['rec'].append(eval.rec_avg)
        samples['rec_list'].append(eval.rec_list)
        samples['pair'].append(eval.pair)
        samples['runtime'].append(elapsed)

    results = {}
    # Scalar metrics: one value per round, reduced over all rounds.
    for name in scalar_names:
        values = np.asarray(samples[name])
        results[name + '_mean'] = np.mean(values)
        results[name + '_std'] = np.std(values)
    # List metrics: one row per round, reduced element-wise across rounds
    # and converted back to plain lists.
    for name in list_names:
        values = np.asarray(samples[name])
        results[name + '_mean'] = np.mean(values, axis=0).tolist()
        results[name + '_std'] = np.std(values, axis=0).tolist()
    return results


# Sweep the hyper-parameter grid and call run_exp once per configuration.
# The aggregated metrics are echoed to the console, written as text to
# '<tmp_dir>results.txt', and the list-valued kNN statistics are pickled
# to '<tmp_dir>knn_list.pkl'.
if __name__ == '__main__':
    params = args
    init_batch(args)

    # Candidate values for every swept option; the live sweep below runs
    # the full cross product of these lists.
    loss = ['pair']
    dis = ['eu']
    hard_sample = [False, True]
    attentive = [False, True]
    denoising = [False]
    # Only consumed by the disabled "different sources" sweep further down.
    sources = [
        'pqi_difftool_dev_austin:train',
        'pqiv3_dupe_measurement',
        'crowdsourcing:train',
        'curation:train',
        'lm:train',
        'pqiv4_dev_recall',
        'bad_redirects:valid',
        'checkin_cleanliness:train',
        'pqiv4_dev_precision',
        'dedup_bug_tool:train',
        'pqiv2_development:train',
        'bad_redirects:train',
        'disassociations:train',
        'dedup_bug_tool:valid',
        'checkin_cleanliness:valid',
        'metapage_tool:train',
        'pqiv2_development:valid',
        'disassociations:valid',
        'metapage_tool:valid',
    ]

    # One accumulator per reported statistic; every experiment appends
    # exactly one entry to each of them.
    pre_mean, pre_std = [], []
    pre_list_mean, pre_list_std = [], []
    rec_mean, rec_std = [], []
    rec_list_mean, rec_list_std = [], []
    pair_mean, pair_std = [], []
    runtime_mean, runtime_std = [], []

    dataset = Dataset(params)
    dataset.get_test()
    eval = Evaluation(params)

    '''
    # evaluating raw features
    #dims = [(29, 79), (29, 129), (0, 129)]
    dims = [(0, 129), (129, 429), (0, 429)]
    for dim in dims:
        eval.evaluate(dataset, dim=dim)
        pre_mean.append(eval.pre_avg)
        pre_std.append(0)
        pre_list_mean.append(eval.pre_list)
        pre_list_std.append(np.zeros_like(eval.pre_list))
        rec_mean.append(eval.rec_avg)
        rec_std.append(0)
        rec_list_mean.append(eval.rec_list)
        rec_list_std.append(np.zeros_like(eval.rec_list))
        pair_mean.append(eval.pair_test)
        pair_std.append(0)
        runtime_mean.append(0)
        runtime_std.append(0)
    '''
    '''
    # evaluating different sources
    for s in sources:
        params.source = s
        dataset.reset(params)
        dataset.get_train()
        results = run_exp(params, dataset, eval)
        pre_mean.append(results['pre_mean'])
        pre_std.append(results['pre_std'])
        pre_list_mean.append(results['pre_list_mean'])
        pre_list_std.append(results['pre_list_std'])
        rec_mean.append(results['rec_mean'])
        rec_std.append(results['rec_std'])
        rec_list_mean.append(results['rec_list_mean'])
        rec_list_std.append(results['rec_list_std'])
        pair_mean.append(results['pair_mean'])
        pair_std.append(results['pair_std'])
        runtime_mean.append(results['runtime_mean'])
        runtime_std.append(results['runtime_std'])

    '''

    # Pair every key returned by run_exp with the accumulator receiving it.
    collectors = (
        ('pre_mean', pre_mean),
        ('pre_std', pre_std),
        ('pre_list_mean', pre_list_mean),
        ('pre_list_std', pre_list_std),
        ('rec_mean', rec_mean),
        ('rec_std', rec_std),
        ('rec_list_mean', rec_list_mean),
        ('rec_list_std', rec_list_std),
        ('pair_mean', pair_mean),
        ('pair_std', pair_std),
        ('runtime_mean', runtime_mean),
        ('runtime_std', runtime_std),
    )

    for loss_name in loss:
        # The training data is reset and reloaded once per loss setting.
        params.loss = loss_name
        dataset.reset(params)
        dataset.get_train()
        for dis_name in dis:
            for use_hard in hard_sample:
                for use_attention in attentive:
                    for use_denoising in denoising:
                        params.dis = dis_name
                        params.hard_sample = use_hard
                        params.attentive = use_attention
                        params.denoising = use_denoising
                        results = run_exp(params, dataset, eval)
                        for key, bucket in collectors:
                            bucket.append(results[key])

    def _rounded(values, ndigits):
        # Round every entry of values to ndigits decimal places.
        return [round(value, ndigits) for value in values]

    # Precision keeps 6 decimals, recall and pairwise scores keep 4, and
    # runtimes are rounded to whole seconds (round(x, 0) stays a float).
    pre_mean, pre_std = _rounded(pre_mean, 6), _rounded(pre_std, 6)
    rec_mean, rec_std = _rounded(rec_mean, 4), _rounded(rec_std, 4)
    pair_mean, pair_std = _rounded(pair_mean, 4), _rounded(pair_std, 4)
    runtime_mean = _rounded(runtime_mean, 0)
    runtime_std = _rounded(runtime_std, 0)

    # Console report: one line per metric, means first, then stds.
    console_rows = (
        (pre_list_mean, pre_list_std),
        (rec_list_mean, rec_list_std),
        (pre_mean, pre_std),
        (rec_mean, rec_std),
        (pair_mean, pair_std),
        (runtime_mean, runtime_std),
    )
    for mean_values, std_values in console_rows:
        print(mean_values, std_values)

    # Text report. The file name is appended directly to params.tmp_dir, so
    # tmp_dir presumably already ends with a path separator (TODO confirm).
    report_sections = (
        ('\nknn results:\n', (pre_mean, pre_std, rec_mean, rec_std)),
        ('\npairwise results:\n', (pair_mean, pair_std)),
        ('\nruntime results:\n', (runtime_mean, runtime_std)),
    )
    with open('{}results.txt'.format(params.tmp_dir), 'w') as f:
        for header, rows in report_sections:
            f.write(header)
            for row in rows:
                f.write(str(row) + '\n')

    # The list-valued kNN statistics are kept unrounded and pickled.
    with open('{}knn_list.pkl'.format(params.tmp_dir), 'wb') as f:
        pickle.dump([pre_list_mean, pre_list_std, rec_list_mean, rec_list_std], f)
