'''
Author: Carl Yang
Function: Run multiple experiments with multiple threads on multiple GPUs
Command: buck run @mode/dev-nosan //experimental/carlyang/place_embedding:multithread
Note: Training with multiple threads on multiple GPUs should speed up
experiments, but it is unstable in practice.
The program is prone to memory overflows and invalid memory accesses.
The code is broken and no longer maintained.
'''
from __future__ import absolute_import, division, print_function, unicode_literals
import numpy as np
from datetime import datetime
import threading
import torch
from experimental.carlyang.place_embedding.config import args, init_batch
from experimental.carlyang.place_embedding.embed import Embedding
from experimental.carlyang.place_embedding.util import Dataset
from experimental.carlyang.place_embedding.eval import Evaluation


class run_exp(threading.Thread):
    """Worker thread that trains and evaluates one embedding on one device.

    The worker does the following:

    - Trains an ``Embedding`` on ``dataset`` inside ``torch.cuda.device(device)``.
    - Evaluates it with ``Evaluation``.
    - Appends the evaluation scores and the elapsed training time to the
      caller-supplied result lists while holding ``write_lock``.

    Args:
        params: experiment configuration passed to ``Embedding`` and
            ``Evaluation``.
        device: device passed to ``torch.cuda.device`` (the driver script
            passes an integer GPU index).
        dataset: dataset passed to ``Embedding.train`` and
            ``Evaluation.evaluate``.
        write_lock: lock guarding the three shared result lists below.
        knn_scores: list that receives ``Evaluation.knn_avg`` for this run.
        pairwise_scores: list that receives ``Evaluation.pairwise_test``.
        times: list that receives the training time in seconds (float).
    """

    def __init__(
        self,
        params,
        device,
        dataset,
        write_lock,
        knn_scores,
        pairwise_scores,
        times
    ):
        super().__init__()
        self.params = params
        self.device = device
        self.dataset = dataset
        self.write_lock = write_lock
        self.knn_scores = knn_scores
        self.pairwise_scores = pairwise_scores
        self.times = times

    def run(self):
        """Train, evaluate and record the results of a single run."""
        with torch.cuda.device(self.device):
            embed = Embedding(self.params)
            start = datetime.now()
            embed.train(self.dataset)
            # total_seconds() keeps sub-second precision and does not wrap
            # around after 24 hours the way timedelta.seconds does.
            elapsed = (datetime.now() - start).total_seconds()
            evaluator = Evaluation(self.params)
            evaluator.evaluate(self.dataset, embed)
            self._record(evaluator.knn_avg, evaluator.pairwise_test, elapsed)

    def _record(self, knn, pairwise, elapsed):
        """Append one run's results to this worker's shared lists."""
        # Write to the lists handed to this worker, not module-level
        # globals. The context manager releases the lock even if an
        # append raises.
        with self.write_lock:
            self.knn_scores.append(knn)
            self.pairwise_scores.append(pairwise)
            self.times.append(elapsed)


if __name__ == '__main__':
    # A single config object is shared by every configuration and every
    # worker thread.
    # NOTE(review): the workers also share `dataset`; this sharing may relate
    # to the instability described in the module docstring.
    params = args
    init_batch(params)

    rounds = 2  # number of worker threads per configuration; worker i uses device i
    loss = ['triplet', 'pair']
    hard_sample = [False, True]

    # Summary statistics, one entry per (loss, hard_sample) configuration,
    # in iteration order.
    knn_mean, knn_std = [], []
    pairwise_mean, pairwise_std = [], []
    runtime = []

    dataset = Dataset(params)
    dataset.get_test()

    for loss_name in loss:
        # Re-prepare the training data after changing the loss; reset()
        # receives the updated params.
        params.loss = loss_name
        dataset.reset(params)
        dataset.get_train()

        for use_hard in hard_sample:
            params.hard_sample = use_hard

            # Fresh result lists for this configuration. Each worker appends
            # its scores and timing to them while holding write_lock.
            knn_scores, pairwise_scores, times = [], [], []
            write_lock = threading.Lock()
            workers = [
                run_exp(params, gpu, dataset, write_lock,
                        knn_scores, pairwise_scores, times)
                for gpu in range(rounds)
            ]
            for worker in workers:
                worker.start()
            for worker in workers:
                worker.join()

            # np.mean / np.std accept the plain lists directly.
            knn_mean.append(np.mean(knn_scores))
            knn_std.append(np.std(knn_scores))
            pairwise_mean.append(np.mean(pairwise_scores))
            pairwise_std.append(np.std(pairwise_scores))
            runtime.append(np.mean(times))

    print(knn_mean, knn_std)
    print(pairwise_mean, pairwise_std)
    print(runtime)
