'''
Author: Carl Yang
Date: June 2017
Location: Didichuxing@Beijing

This script trains the phine neural framework for joint training on rating and context
Possible tunable components: 
	feats, 
	data amount,
	neural architectures (number of layers, sizes of layers, activation functions, etc),
	network tricks (batch normalization, dropout, residual networks, autoencoder, attention, etc)

'''

from collections import OrderedDict

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.autograd import Variable

import datasplit

class Train(object):
	'''
	Jointly trains entity embeddings on rating prediction and context preservation.

	Three feed-forward embedders (driver, passenger, spot) map raw features into a
	shared embedding space. Every outer iteration runs
		T1 rating steps: the embeddings of (driver, passenger, spot1, spot2, spot3)
		are concatenated and fed to an interaction network that regresses column 5
		of the training table, followed by
		T2 context steps: sigmoid(<entity_emb, context_emb>) is pushed towards 1 for
		observed (entity, context) pairs and towards 0 for randomly corrupted ones.

	Constructing an instance loads the data, builds the models, runs train() and
	then plot().

	Written against the PyTorch >= 0.4 tensor API (plain tensors, loss.item()).
	'''

	def __init__(
		self,
		batch_size = 1024,
		iter_num = 1000,
		learning_rate = 1e-2,
		T1 = 100,
		T2 = 10,
		emb_layer = 1,
		emb_arch = [16],
		inter_layer = 3,
		inter_arch = [64, 32, 16],
		negative_sample = 5
		):
		'''
		batch_size: number of rows sampled (with replacement) per optimisation step
		iter_num: number of outer iterations (each = T1 rating steps + T2 context steps)
		learning_rate: Adam learning rate shared by both optimisers
		T1: rating steps per outer iteration
		T2: context steps per outer iteration
		emb_layer: number of Linear+ReLU layers in each embedder (<= len(emb_arch))
		emb_arch: output width of every embedder layer
		inter_layer: number of hidden Linear+ReLU layers in the interaction network
		inter_arch: output width of every hidden interaction layer
		negative_sample: corrupted contexts drawn per observed (entity, context) pair
		'''
		data = self._load_dataset()

		self.batch_size = batch_size
		self.iter_num = iter_num
		self.learning_rate = learning_rate
		self.T1 = T1
		self.T2 = T2
		self.emb_layer = emb_layer
		#copy the architectures so the shared mutable defaults are never aliased
		self.emb_arch = list(emb_arch)
		self.inter_layer = inter_layer
		self.inter_arch = list(inter_arch)
		self.negative_sample = negative_sample

		self._setup(data)

		self.train()
		self.plot()

	@staticmethod
	def _load_dataset(path = 'dataset'):
		'''Loads the serialized dataset, invoking datasplit.DataSplit() once if the file is missing.'''
		try:
			return torch.load(path)
		except FileNotFoundError:
			#presumably DataSplit() writes the 'dataset' file -- verify against datasplit
			datasplit.DataSplit()
			return torch.load(path)

	def _setup(self, data):
		'''Stores the dataset tensors and their sizes, then builds all trainable models.'''
		self.driver = data['driver']	#floattensor
		self.passenger = data['passenger']	#floattensor
		self.spot = data['spot']	#floattensor
		#kept as *_data: an attribute called `train` would shadow the train() method
		self.train_data = data['train']	#longtensor
		self.test_data = data['test']	#longtensor
		self.context = data['context']	#longtensor

		self.n_driver = self.driver.size(0)
		self.n_passenger = self.passenger.size(0)
		self.n_spot = self.spot.size(0)
		self.n = self.n_driver + self.n_passenger + self.n_spot
		self.n_train = self.train_data.size(0)
		self.n_context = self.context.size(0)

		self.losses = {'train': [], 'test': [], 'context': []}
		self._build_models()

	@staticmethod
	def _mlp_layers(d_in, arch, n_layers):
		'''
		Returns [('l0', Linear), ('n0', ReLU), ('l1', Linear), ...] for an MLP with
		n_layers Linear+ReLU pairs whose widths are arch[0], ..., arch[n_layers-1].
		Raises ValueError if n_layers is not within 1..len(arch).
		'''
		if not 1 <= n_layers <= len(arch):
			raise ValueError('number of layers must be between 1 and len(arch)')
		layers = []
		width_in = d_in
		for i in range(n_layers):
			layers.append(('l'+str(i), torch.nn.Linear(width_in, arch[i])))
			layers.append(('n'+str(i), torch.nn.ReLU()))
			width_in = arch[i]
		return layers

	def _build_models(self):
		'''Creates the three embedders, the interaction network and the context embedding table.'''
		emb_dim = self.emb_arch[self.emb_layer-1]

		#define embedding models
		self.model_emb_driver = torch.nn.Sequential(OrderedDict(
			self._mlp_layers(self.driver.size(1), self.emb_arch, self.emb_layer)))
		self.model_emb_passenger = torch.nn.Sequential(OrderedDict(
			self._mlp_layers(self.passenger.size(1), self.emb_arch, self.emb_layer)))
		self.model_emb_spot = torch.nn.Sequential(OrderedDict(
			self._mlp_layers(self.spot.size(1), self.emb_arch, self.emb_layer)))

		#interaction model: input is 5 concatenated embeddings (driver, passenger, 3 spots)
		inter = self._mlp_layers(emb_dim*5, self.inter_arch, self.inter_layer)
		#project to one value per row so the prediction matches the (batch,) target
		inter.append(('out', torch.nn.Linear(self.inter_arch[self.inter_layer-1], 1)))
		inter.append(('sigmoid', torch.nn.Sigmoid()))
		#scales the sigmoid output to (0, 5); presumably the rating range -- verify against the data
		inter.append(('multiply', MulLayer(5)))
		self.model_inter = torch.nn.Sequential(OrderedDict(inter))

		#one free context vector per entity, in the same order as the concatenated entity embeddings
		self.context_emb = torch.randn(self.n, emb_dim, requires_grad=True)

	def _embedding_param_groups(self):
		'''Returns fresh Adam parameter groups for the three embedders (never shared between optimisers).'''
		return [
			{'params': list(self.model_emb_driver.parameters())},
			{'params': list(self.model_emb_passenger.parameters())},
			{'params': list(self.model_emb_spot.parameters())}]

	def _embed_entities(self):
		'''Runs every embedder on its full feature table; returns (driver, passenger, spot) embeddings.'''
		return (
			self.model_emb_driver(self.driver),
			self.model_emb_passenger(self.passenger),
			self.model_emb_spot(self.spot))

	def _sample_rows(self, high, count):
		'''Draws `count` indices uniformly from [0, high) with numpy's global RNG, as an int64 tensor.'''
		return torch.from_numpy(np.random.randint(high, size=count)).long()

	def _rating_step(self, loss_fn, optimizer):
		'''Performs one optimisation step on the rating objective; returns the per-sample loss.'''
		driver_emb, passenger_emb, spot_emb = self._embed_entities()

		batch = self.train_data[self._sample_rows(self.n_train, self.batch_size)]
		#columns: 0 driver, 1 passenger, 2-4 spots (indices into spot table), 5 target
		features = torch.cat((
			driver_emb[batch[:, 0]],
			passenger_emb[batch[:, 1]],
			spot_emb[batch[:, 2]],
			spot_emb[batch[:, 3]],
			spot_emb[batch[:, 4]]), 1)
		#the table is integer typed; MSELoss needs a float target
		target = batch[:, 5].float()
		pred = self.model_inter(features).squeeze(1)
		loss = loss_fn(pred, target)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		return loss.item()/target.numel()

	def _context_step(self, loss_fn, optimizer):
		'''Performs one negative-sampling step on the context objective; returns the per-sample loss.'''
		entity_emb = torch.cat(self._embed_entities(), 0)

		pairs = self.context[self._sample_rows(self.n_context, self.batch_size)]
		n_negative = self.batch_size*self.negative_sample
		corrupted_context_id = self._sample_rows(self.n, n_negative)

		#first batch_size entries are the observed pairs, the rest pair the same
		#entities (tiled negative_sample times) with random contexts
		entity_id = pairs[:, 0].repeat(1 + self.negative_sample)
		context_id = torch.cat((pairs[:, 1], corrupted_context_id), 0)
		y = torch.cat((torch.ones(self.batch_size), torch.zeros(n_negative)), 0)

		#row-wise dot product; avoids materialising the full pairwise score matrix
		scores = (entity_emb[entity_id]*self.context_emb[context_id]).sum(1)
		y_pred = torch.sigmoid(scores)
		loss = loss_fn(y_pred, y)

		optimizer.zero_grad()
		loss.backward()
		optimizer.step()
		return loss.item()/y.numel()

	def train(self):
		'''
		Alternates T1 rating steps and T2 context steps for iter_num iterations.

		Per-sample losses are printed at every step and stored in self.losses under
		'train' (rating steps) and 'context' (context steps); 'test' stays empty
		because no test evaluation is implemented yet.
		'''
		#sum reduction: the per-sample value is obtained by dividing by the batch length
		loss_fn_inter = torch.nn.MSELoss(reduction='sum')
		loss_fn_context = torch.nn.MSELoss(reduction='sum')

		#the rating optimiser must also own the interaction network, otherwise it is never trained
		optimizer_inter = torch.optim.Adam(
			self._embedding_param_groups() + [{'params': list(self.model_inter.parameters())}],
			lr=self.learning_rate)
		optimizer_context = torch.optim.Adam(
			self._embedding_param_groups() + [{'params': [self.context_emb]}],
			lr=self.learning_rate)

		losses_train_total = []
		losses_test_total = []
		losses_context_total = []

		for t in range(self.iter_num):
			for t1 in range(self.T1):
				loss_value = self._rating_step(loss_fn_inter, optimizer_inter)
				losses_train_total.append(loss_value)
				print('training:', t, t1, loss_value)
				#TODO: evaluate on self.test_data the same way and append to losses_test_total

			for t2 in range(self.T2):
				loss_value = self._context_step(loss_fn_context, optimizer_context)
				losses_context_total.append(loss_value)
				print('contex preserving:', t, t2, loss_value)

		#plot() reads the curves from here
		self.losses = {
			'train': losses_train_total,
			'test': losses_test_total,
			'context': losses_context_total}

	def plot(self):
		'''Shows the per-step rating training loss recorded by train().'''
		#type 42 (TrueType) fonts keep text editable in exported pdf/ps figures
		matplotlib.rcParams['pdf.fonttype'] = 42
		matplotlib.rcParams['ps.fonttype'] = 42
		length = len(self.losses['train'])
		plt.plot(np.arange(length), self.losses['train'], label='train')
		#plt.plot(np.arange(length), self.losses['test'], label='test')
		plt.xlabel('Epoch',fontsize=15)
		plt.ylabel('Loss',fontsize=15)
		plt.grid()
		plt.xticks(fontsize=15)
		plt.yticks(fontsize=15)
		plt.legend(fontsize=12)
		#plt.savefig('losses.eps', format='eps', dpi=1000, bbox_inches='tight')
		plt.show()

class MulLayer(torch.nn.Module):
	'''Parameter-free module that multiplies its input by a fixed factor.'''

	def __init__(self, scaler):
		'''scaler: the constant every input is multiplied by.'''
		super().__init__()
		self.scaler = scaler

	def forward(self, x):
		'''Returns x scaled by the stored factor.'''
		factor = self.scaler
		scaled = x*factor
		return scaled
