import numpy as np
from sklearn.cluster import KMeans

def find(theset, key):
    """Return the positions in *theset* whose element equals *key*.

    Positions are returned in ascending order as a list; an empty list means
    *key* does not occur.
    """
    return [position for position, item in enumerate(theset) if item == key]


def jaccard(s1, s2):
    """Return the Jaccard similarity |s1 & s2| / |s1 | s2| of two collections.

    Both arguments are converted to sets, so order and duplicates are ignored.
    When both are empty the union is empty and the ratio is undefined; 0.0 is
    returned in that case (consistent with F1, which scores empty inputs as 0)
    instead of raising ZeroDivisionError.
    """
    a = set(s1)
    b = set(s2)
    union = a | b
    if not union:
        return 0.0
    # `* 1.0` keeps the division floating-point under Python 2 as well.
    return len(a & b) * 1.0 / len(union)


def F1(s1, s2):
    """Return the F1 score (harmonic mean of precision and recall) of s1 vs s2.

    Both inputs are treated as sets. With ``overlap = |s1 & s2|``:
    precision = overlap / |s2| and recall = overlap / |s1|.
    Returns 0 when the two share no elements, which includes the case where
    either one is empty.
    """
    reference = set(s1)
    candidate = set(s2)
    overlap = len(reference & candidate)
    if overlap == 0:
        # No shared elements means precision and recall are both zero.
        return 0
    precision = overlap * 1.0 / len(candidate)
    recall = overlap * 1.0 / len(reference)
    return 2 * precision * recall / (precision + recall)


def _group_indices_by_label(labels):
    """Return one list of positions per distinct label in *labels*.

    Groups are ordered by ascending label value. Each group lists, in ascending
    order, the positions holding that label. Label values that never occur
    produce no group, so no group is ever empty.
    """
    groups = {}
    for idx, label in enumerate(labels):
        groups.setdefault(label, []).append(idx)
    return [groups[label] for label in sorted(groups)]


def _best_match_scores(sources, targets, score):
    """Return, for each group in *sources*, its highest *score* against *targets*.

    The best score starts at 0, so a group whose scores are all <= 0 contributes 0.
    """
    return [max([0] + [score(src, dst) for dst in targets]) for src in sources]


def cal_error(set1, set2, type):
    """Return the average best-match similarity between two labelings.

    *set1* and *set2* assign a label to each position (e.g. predicted clusters
    and ground-truth communities). Each labeling is split into groups of
    positions sharing a label; every distinct label forms a group.

    - Each group of *set1* is scored against its best-matching group of *set2*.
      These best scores are averaged over the groups of *set1*.
    - The same is done from *set2* to *set1*.
    - The two averages are then averaged.

    Each direction is averaged over its own group count, so the labelings may
    have different numbers of groups.

    Args:
        set1: sequence of hashable, sortable labels.
        set2: sequence of hashable, sortable labels.
        type: 'Jaccard' scores groups with ``jaccard``; any other value uses
            ``F1``.

    Returns:
        The averaged score as a float.

    Raises:
        ValueError: if either labeling is empty.
    """
    score = jaccard if type == 'Jaccard' else F1
    groups1 = _group_indices_by_label(set1)
    groups2 = _group_indices_by_label(set2)
    if not groups1 or not groups2:
        raise ValueError("cal_error requires two non-empty labelings")

    best1 = _best_match_scores(groups1, groups2, score)
    best2 = _best_match_scores(groups2, groups1, score)
    # float() keeps the division true division under Python 2 as well.
    return (float(sum(best1)) / len(best1) + float(sum(best2)) / len(best2)) / 2

# Evaluation script: cluster node embeddings with k-means and score the
# clustering against ground-truth communities.
test_number = 4   # number of k-means clusters
testsample = 25   # number of nodes (rows of the embedding matrix)
feat_dim = 24     # embedding dimensionality
pred = np.zeros([testsample, feat_dim])

# Embedding file: the first line is skipped (presumably a "<count> <dim>"
# header - TODO confirm). Every other line is "<node id> <v1> ... <v_feat_dim>"
# with 1-based node ids.
with open('testdata/3980.embeddings', 'r') as emb_file:
    emb_file.readline()  # skip the first (header) line
    for nextline in emb_file:
        fields = nextline.split()
        if not fields:
            continue  # tolerate blank lines, e.g. a trailing empty line
        pred[int(fields[0]) - 1, :] = np.array([float(v) for v in fields[1:]])

# Ground-truth file: each non-blank line lists the 1-based node ids of one
# community, and nodes on the k-th such line get label k. Nodes not listed keep
# label 0; a node listed on several lines keeps the label of its last line.
test_output = np.zeros(pred.shape[0], dtype=np.int64)
with open('testdata/3980_truth.txt', 'r') as truth_file:
    community = 0
    for nextline in truth_file:
        nodes = nextline.split()
        if not nodes:
            continue  # blank lines must not create empty community labels
        for node in nodes:
            test_output[int(node) - 1] = community
        community += 1

# No random_state is passed, so the clustering (and scores) can differ
# between runs.
clustering = KMeans(n_clusters=test_number, init='k-means++')
clustering.fit(pred)

print(cal_error(clustering.labels_, test_output, 'F1'))
print(cal_error(clustering.labels_, test_output, 'Jaccard'))