'''
Author: Carl Yang
Function: Train and evaluate non-embedding baselines
Command: buck run @mode/dev-nosan //experimental/carlyang/place_embedding:baseline
'''
from __future__ import absolute_import, division, print_function, unicode_literals
from experimental.carlyang.place_embedding.config import args
from experimental.carlyang.place_embedding.util import Dataset
from sklearn.linear_model import LogisticRegression
from sklearn.svm import SVC
from sklearn.ensemble import GradientBoostingClassifier
from sklearn.ensemble import RandomForestClassifier
import numpy as np
from datetime import datetime


# Maps every supported baseline name to a zero-argument factory. Factories
# (rather than shared instances) are used so that each call hands back a fresh,
# unfitted estimator configured exactly as this script expects.
BASELINE_FACTORIES = {
    'LR': lambda: LogisticRegression(max_iter=100),
    'SVM': lambda: SVC(max_iter=1000),
    'GBDT': lambda: GradientBoostingClassifier(max_features=20, max_depth=5),
    'RF': lambda: RandomForestClassifier(),
}


def unknown_baselines(baselines):
    '''
    Return the names in `baselines` that have no entry in BASELINE_FACTORIES.

    The input order is preserved; an empty list means every name is supported.
    '''
    return [name for name in baselines if name not in BASELINE_FACTORIES]


def build_model(baseline):
    '''
    Create a fresh, unfitted estimator for the baseline called `baseline`.

    Raises ValueError when `baseline` is not a key of BASELINE_FACTORIES.
    Previously an unrecognised name fell through the if/elif chain and either
    crashed with a NameError or silently reused the previous baseline's model.
    '''
    if baseline not in BASELINE_FACTORIES:
        raise ValueError(
            'Unknown baseline {!r}; supported baselines are {}.'.format(
                baseline, sorted(BASELINE_FACTORIES)))
    return BASELINE_FACTORIES[baseline]()


def accuracy(pred_y, test_y):
    '''
    Return the fraction of positions at which `pred_y` equals `test_y`.

    Both arguments are converted with np.asarray and compared element-wise.
    Raises ZeroDivisionError when `test_y` is empty.
    '''
    matches = np.count_nonzero(np.asarray(pred_y) == np.asarray(test_y))
    # float() keeps this a true division even without `from __future__ import division`.
    return matches / float(len(test_y))


def main():
    '''
    Train every baseline listed in `args.baselines` for `args.n_rounds` rounds
    and print the per-baseline mean and standard deviation of test accuracy.

    Raises ValueError before any data is loaded if an unsupported baseline
    name is requested.
    '''
    params = args

    # Fail fast: reject bad names before paying for data loading and training.
    unknown = unknown_baselines(params.baselines)
    if unknown:
        raise ValueError(
            'Unknown baseline(s) {}; supported baselines are {}.'.format(
                unknown, sorted(BASELINE_FACTORIES)))

    dataset = Dataset(params)
    # The first 2 * feat_dim columns are used as features and the single
    # column right after them as the label.
    n_feat = 2 * params.feat_dim
    idx = range(n_feat)
    idy = range(n_feat, n_feat + 1)

    dataset.get_train()
    dataset.get_test_baseline()
    # NOTE(review): `.data ... .numpy()` suggests test_data / train_data are
    # torch tensors -- confirm against util.Dataset.
    test_x = dataset.test_data[:, idx].data.numpy()
    test_y = dataset.test_data[:, idy].data.squeeze().numpy()

    mean = {}
    std = {}
    for baseline in params.baselines:
        model = build_model(baseline)

        accs = []
        for _ in range(params.n_rounds):
            print("{}: shuffling training data.".format(datetime.now()))
            dataset.shuffle_train()
            # Re-extract after every shuffle so the model sees the new order.
            train_x = dataset.train_data[:, idx].data.numpy()
            train_y = dataset.train_data[:, idy].data.squeeze().numpy()
            print("{}: fitting the {} mode.".format(datetime.now(), baseline))
            model.fit(train_x, train_y)
            print("{}: predicting with the {} mode.".format(datetime.now(), baseline))
            pred_y = model.predict(test_x)

            accs.append(accuracy(pred_y, test_y))

        scores = np.array(accs)
        mean[baseline] = scores.mean()
        std[baseline] = scores.std()

    print(mean)
    print(std)


if __name__ == '__main__':
    main()
