'''
Author: Carl Yang
Date: June 2017
Location: Didichuxing@Beijing

This script processes the CSV files produced by Hive queries on the DiDi data platform
and stores the processed data as Torch tensors.

Configure the column positions and feature counts of the order, driver and passenger
fields so that each transaction is parsed correctly.

'''

import os
import time
import torch
import numpy as np

class DataProc(object):
	'''
	Convert DiDi transaction CSV files into Torch tensors.

	The pipeline has three stages, each writing one file with torch.save:
	  process(): CSV rows -> de-duplicated driver / passenger / spot tables and
	             an order table of integer ids (processed_file)
	  split():   random train / test split of the order table (splited_file)
	  sample():  (node, context) global-id pairs drawn from five co-occurrence
	             patterns (sampled_file)
	'''

	def __init__(
		self,
		flag_process = True,
		flag_split = True,
		path = '../data/',
		dataset = 'didi',
		version = '1day1city',
		processed_file = 'processed',
		splited_file = 'splited',
		sampled_file = 'sampled',
		score_pos = 0,
		spot_pos = (1, 4, 7),
		driver_pos = 10,
		driver_feat = 52,
		pas_pos = 72,
		pas_feat = 51,
		num_fold = 5,
		test_fold = 0,
		n_sample = 1000,
		n_size = 8,
		n_pair = 10,
		n_pattern = (100, 100, 100, 100, 100),
		flag_sample = True):
		'''
		Store the configuration and run the pipeline stages.

		A stage runs when its flag is set or when its output file does not exist yet.

		path, dataset, version: concatenated verbatim (path + dataset + version + '/')
			to form the directory holding the input CSV files.
		processed_file, splited_file, sampled_file: output paths of the three stages.
		score_pos: column of the transaction score.
		spot_pos: start column of each of the three (lat, lon, timestamp) triplets.
		driver_pos / pas_pos: column of the driver / passenger id. The following
			driver_feat-1 / pas_feat-1 columns hold its features.
		num_fold, test_fold: each order is put into one of num_fold folds uniformly
			at random. Fold test_fold becomes the test set.
		n_sample: number of outer sampling iterations.
		n_size: neighbour draws per sampled node set.
		n_pair: pairs drawn from each node set.
		n_pattern: node sets drawn per outer iteration for patterns a-e.
		flag_sample: force re-running sample().
		'''
		self.datapath = path + dataset + version + '/'
		self.processed_file = processed_file
		self.splited_file = splited_file
		self.sampled_file = sampled_file
		self.score_pos = score_pos
		self.spot_pos = spot_pos
		self.driver_pos = driver_pos
		self.pas_pos = pas_pos
		self.driver_feat = driver_feat
		self.pas_feat = pas_feat

		self.num_fold = num_fold
		self.test_fold = test_fold
		self.n_sample = n_sample
		self.n_size = n_size
		self.n_pair = n_pair
		self.n_pattern = n_pattern

		# True once order/driver/pas/spot are available in memory
		self.processed = False
		if flag_process or not os.path.exists(processed_file):
			self.process()
		if flag_split or not os.path.exists(splited_file):
			self.split()
		if flag_sample or not os.path.exists(sampled_file):
			self.sample()

	@staticmethod
	def _save(obj, filename):
		'''Serialize obj to filename with torch.save, closing the file even on error.'''
		with open(filename, 'wb') as f:
			torch.save(obj, f)

	def _ensure_processed(self):
		'''
		Make self.order/driver/pas/spot available.

		In-memory data is kept as is. Otherwise self.processed_file is loaded,
		or, when that file is missing, process() is run.
		'''
		if self.processed:
			return
		if os.path.exists(self.processed_file):
			processed = torch.load(self.processed_file)
			self.order = processed['order']
			self.driver = processed['driver']
			self.pas = processed['passenger']
			self.spot = processed['spot']
			self.processed = True
		else:
			self.process()

	@staticmethod
	def _parse_feats(feats, id_pos, n_cols, flag_value=None):
		'''
		Return the numeric features in columns id_pos+1 .. id_pos+n_cols-1.

		A non-numeric cell becomes 1 when it equals flag_value and 0 otherwise.
		Raises IndexError when the row is too short.
		'''
		values = []
		for i in range(id_pos + 1, id_pos + n_cols):
			raw = feats[i]
			try:
				values.append(float(raw))
			except ValueError:
				values.append(1 if flag_value is not None and raw.strip() == flag_value else 0)
		return values

	@staticmethod
	def _spot_key(feats, base):
		'''
		Return (int(lat*1000), int(lon*1000), slot) for the spot starting at column base.

		slot is the 10-minute slot of the day, in the range 0..143. It is 0 when the
		timestamp is not in '%Y-%m-%d %H:%M:%S' format.
		'''
		lat = float(feats[base])
		lon = float(feats[base + 1])
		try:
			t = time.strptime(feats[base + 2].strip(), "%Y-%m-%d %H:%M:%S")
			slot = t.tm_hour * 6 + t.tm_min // 10
		except ValueError:
			slot = 0
		return (int(lat * 1000), int(lon * 1000), slot)

	def process(self):
		'''
		Parse every non-hidden file in self.datapath and save the index tables.

		Drivers, passengers and spots are de-duplicated and numbered 0, 1, ...
		in order of first appearance. Each CSV line becomes the order row
		[driver, passenger, spot1, spot2, spot3, score].

		Sets self.order, self.driver, self.pas and self.spot (Python lists) and
		self.processed. Writes a dict to self.processed_file with keys:
		  'order': LongTensor; the score column is truncated to an integer.
		  'driver', 'passenger', 'spot': FloatTensors.
		  'score': FloatTensor of the full-precision scores.
		'''
		self.order = []
		self.driver = []
		self.pas = []
		self.spot = []

		pas_map = {}
		driver_map = {}
		spot_map = {}

		# sorted() makes id assignment independent of the directory listing order
		for datafile in sorted(os.listdir(self.datapath)):
			filepath = self.datapath + datafile
			if datafile.startswith('.') or not os.path.isfile(filepath):
				continue
			with open(filepath, 'r') as f:
				print('dataproc: processing data file '+datafile+'...')
				for line in f:
					line = line.rstrip('\r\n')
					if not line:
						continue
					feats = line.split(',')
					score = float(feats[self.score_pos])

					driver_id = feats[self.driver_pos]
					if driver_id not in driver_map:
						# non-numeric driver cells: 'A' -> 1, anything else -> 0
						self.driver.append(self._parse_feats(feats, self.driver_pos, self.driver_feat, flag_value='A'))
						driver_map[driver_id] = len(self.driver) - 1
					driver = driver_map[driver_id]

					pas_id = feats[self.pas_pos]
					if pas_id not in pas_map:
						self.pas.append(self._parse_feats(feats, self.pas_pos, self.pas_feat))
						pas_map[pas_id] = len(self.pas) - 1
					pas = pas_map[pas_id]

					spots = []
					for k in range(3):
						# tuple key: concatenating the parts as strings could make distinct spots collide
						key = self._spot_key(feats, self.spot_pos[k])
						if key not in spot_map:
							self.spot.append(list(key))
							spot_map[key] = len(self.spot) - 1
						spots.append(spot_map[key])

					self.order.append([driver, pas] + spots + [score])

		print('dataproc: processed '+str(len(self.order))+' transactions of '
			+str(len(self.driver))+' drivers, '+str(len(self.pas))+' passengers and '+str(len(self.spot))+' spots.')

		processed = {}
		processed['order'] = torch.LongTensor(self.order)
		processed['driver'] = torch.FloatTensor(self.driver)
		processed['passenger'] = torch.FloatTensor(self.pas)
		processed['spot'] = torch.FloatTensor(self.spot)
		# 'order' truncates scores to integers; keep the exact values alongside
		processed['score'] = torch.FloatTensor([row[-1] for row in self.order])
		self._save(processed, self.processed_file)
		self.processed = True
		print('dataproc: processed data saved.')

	def split(self):
		'''
		Put each order into one of num_fold random folds and save the split.

		Writes {'train': orders outside test_fold, 'test': orders in test_fold}
		as LongTensors to self.splited_file.
		'''
		self._ensure_processed()
		# process() leaves a Python list; a loaded file gives a LongTensor
		order = self.order if torch.is_tensor(self.order) else torch.LongTensor(self.order)

		print('dataproc: spliting data.')
		# uniform random fold label for each order
		labels = torch.randint(0, self.num_fold, (len(order),))
		# view(-1) keeps a 1-D index even when exactly one row matches
		train_idx = (labels != self.test_fold).nonzero().view(-1)
		test_idx = (labels == self.test_fold).nonzero().view(-1)
		splited = {
			'train': order.index_select(0, train_idx),
			'test': order.index_select(0, test_idx),
		}
		self._save(splited, self.splited_file)
		print('dataproc: splited data saved.')

	def sample(self):
		'''
		Sample context pairs and save them to self.sampled_file.

		The output is an (N, 2) LongTensor, or an empty 1-D tensor when nothing
		could be sampled. Node ids are made global with self.offsets: drivers
		come first, then passengers, then spots (types 0 / 1 / 2).
		'''
		self._ensure_processed()

		self.context = []
		n_driver = len(self.driver)
		self.offsets = [0, n_driver, n_driver + len(self.pas)]

		order_rows = self.order.tolist() if torch.is_tensor(self.order) else list(self.order)
		print('dataproc: sampling contexts.')
		if order_rows:
			self._sample_contexts(order_rows)

		self._save(torch.LongTensor(self.context), self.sampled_file)
		print('dataproc: sampled contexts saved.')

	def _sample_contexts(self, order_rows):
		'''
		Fill self.context by drawing node sets from five patterns, n_sample times.

		  a: random walk over spots within distance 10 in (lat*1000, lon*1000)
		  b: random walk over spots whose 10-minute slots differ by less than 2
		  c: spots of orders made by one random passenger
		  d: drivers and passengers of orders touching one random spot
		  e: passengers served by one random driver
		'''
		# columns 0-4 are ids; column 5 (score) is not needed here
		orders = np.asarray(order_rows, dtype=np.float64)[:, :5].astype(np.int64)

		# build adjacency lists in one pass instead of scanning all orders per queried node
		p_s_edge, s_d_edge, s_p_edge, d_p_edge = {}, {}, {}, {}
		for drv, pas, s1, s2, s3 in orders.tolist():
			p_s_edge.setdefault(pas, []).extend((s1, s2, s3))
			d_p_edge.setdefault(drv, []).append(pas)
			# an order is listed once for each distinct spot it touches
			for touched in {s1, s2, s3}:
				s_d_edge.setdefault(touched, []).append(drv)
				s_p_edge.setdefault(touched, []).append(pas)

		spot_rows = self.spot.tolist() if torch.is_tensor(self.spot) else self.spot
		spot_arr = np.asarray(spot_rows, dtype=np.float64).reshape(-1, 3)
		coords = spot_arr[:, :2]
		slots = spot_arr[:, 2]
		n_spot = len(spot_arr)
		n_driver = len(self.driver)
		n_pas = len(self.pas)

		s_s_s_edge = {}
		s_s_t_edge = {}

		def near_in_space(t):
			# spatial neighbours of spot t, excluding t; computed lazily and cached
			if t not in s_s_s_edge:
				dist = np.linalg.norm(coords - coords[t], axis=1)
				s_s_s_edge[t] = [int(c) for c in np.nonzero(dist < 10)[0] if c != t]
			return s_s_s_edge[t]

		def near_in_time(t):
			# temporal neighbours of spot t, excluding t; computed lazily and cached
			if t not in s_s_t_edge:
				gap = np.abs(slots - slots[t])
				s_s_t_edge[t] = [int(c) for c in np.nonzero(gap < 2)[0] if c != t]
			return s_s_t_edge[t]

		for i in range(self.n_sample):
			#pattern a
			for j in range(self.n_pattern[0]):
				node_set = self._random_walk(np.random.randint(n_spot), near_in_space)
				self.sample_pair(node_set, [2] * len(node_set))
			#pattern b
			for j in range(self.n_pattern[1]):
				node_set = self._random_walk(np.random.randint(n_spot), near_in_time)
				self.sample_pair(node_set, [2] * len(node_set))
			#pattern c
			for j in range(self.n_pattern[2]):
				node_set = self._draw(p_s_edge.get(np.random.randint(n_pas), []), unique=True)
				self.sample_pair(node_set, [2] * len(node_set))
			#pattern d
			for j in range(self.n_pattern[3]):
				spot = np.random.randint(n_spot)
				pool = [(d, 0) for d in s_d_edge.get(spot, [])] + [(p, 1) for p in s_p_edge.get(spot, [])]
				picked = self._draw(pool, unique=False)
				self.sample_pair([node for node, _ in picked], [kind for _, kind in picked])
			#pattern e (previously used n_pattern[2] by mistake)
			for j in range(self.n_pattern[4]):
				node_set = self._draw(d_p_edge.get(np.random.randint(n_driver), []), unique=True)
				self.sample_pair(node_set, [1] * len(node_set))

	def _random_walk(self, seed, neighbors):
		'''
		Grow a list of distinct spot ids from seed.

		Performs n_size hops; each hop starts from a random member of the list.
		neighbors(t) returns the candidate ids for spot t, excluding t itself.
		'''
		node_set = [seed]
		for k in range(self.n_size):
			t = node_set[np.random.randint(len(node_set))]
			cands = neighbors(t)
			# an isolated spot has no candidates; the old `while c == t` never terminated here
			if cands:
				c = cands[np.random.randint(len(cands))]
				if c not in node_set:
					node_set.append(c)
		return node_set

	def _draw(self, pool, unique):
		'''
		Draw n_size items uniformly with replacement from pool.

		With unique=True, repeated items are dropped. An empty pool yields an
		empty list.
		'''
		picked = []
		if not pool:
			return picked
		for k in range(self.n_size):
			item = pool[np.random.randint(len(pool))]
			if not unique or item not in picked:
				picked.append(item)
		return picked

	def sample_pair(self, node_set, type_set):
		'''
		Append n_pair random (node, context) pairs from node_set to self.context.

		Each pair uses two distinct positions of node_set. Local ids are turned
		into global ids with self.offsets[type]. Node sets with fewer than two
		entries yield nothing.
		'''
		if len(node_set) < 2:
			return
		for i in range(self.n_pair):
			t, c = np.random.permutation(len(node_set))[:2]
			self.context.append([
				self.offsets[type_set[t]] + int(node_set[t]),
				self.offsets[type_set[c]] + int(node_set[c]),
			])

						
