import json
import random  # only needed if the shuffle noted below is re-enabled

import numpy as np

# Load every record once at import time. Records are accessed below through
# the keys "feature", "name" and "id". JSON text is UTF-8, so decode it as
# such rather than with the platform's locale encoding.
with open("features_3k.json", encoding="utf-8") as file_json:
    raw_data = json.load(file_json)
print(len(raw_data))

# Shuffling is disabled, so calculate_ratio() splits test/train in file order.
# To randomise the split, call random.shuffle(data) right after this copy.
data = raw_data.copy()


def get_dangerous_pairs(thresh):
    """Print every pair of records whose feature vectors are closer than *thresh*.

    Distances are Euclidean. Each unordered pair (i < j) is considered once,
    in file order. Nothing is returned; matches are only printed.

    NOTE(review): assumes every record's "feature" has the same length so the
    vectors can be stacked into one 2-D array (calculate_ratio relies on the
    same thing).
    """
    # Convert all feature vectors once instead of re-converting record j
    # inside the O(n^2) pair loop.
    features = np.array([record["feature"] for record in raw_data])
    for i, first in enumerate(raw_data[:-1]):
        # Distances from record i to every later record, in one vector op.
        dists = np.sqrt(np.sum(np.square(features[i + 1:] - features[i]), axis=1))
        for offset in np.flatnonzero(dists < thresh):
            second = raw_data[i + 1 + int(offset)]
            print("Dangerous pairs:",
                  first["name"], '(ID: ', first["id"], ')', '-',
                  second["name"], '(ID: ', second["id"], ')',
                  '(Dist: ', dists[offset], ')')


def _nearest_train_distances(split_ratio):
    """Return, for each test record, the distance to its nearest train record.

    The first ``int(len(data) * split_ratio)`` records of ``data`` form the
    test set and the remaining records form the train set.

    Raises:
        ValueError: if the split leaves the test or the train set empty.
    """
    n_test = int(len(data) * split_ratio)
    n_train = len(data) - n_test
    if n_test <= 0 or n_train <= 0:
        raise ValueError(
            f"split_ratio={split_ratio} leaves an empty test or train set "
            f"({n_test} test / {n_train} train records)"
        )
    test_data = np.array([x["feature"] for x in data[:n_test]])
    train_data = np.array([x["feature"] for x in data[n_test:]])
    # Broadcasting subtracts the test row from every train row, so the row does
    # not need to be tiled. Working one test row at a time keeps memory at
    # n_train x dim instead of n_test x n_train x dim.
    return np.array([
        np.sqrt(np.sum(np.square(train_data - row), axis=1)).min()
        for row in test_data
    ])


def _ratio_above(min_dists, thresh_hold):
    """Print and return the fraction of *min_dists* strictly above *thresh_hold*."""
    t1 = int(np.count_nonzero(min_dists > thresh_hold))
    t2 = len(min_dists)
    print('\tthresh:', thresh_hold, '\tratio:', t1/t2, '(', t1, '/', t2 ,')')
    return t1 / t2


def calculate_ratio(split_ratio, thresh_hold):
    """Return the fraction of test records whose nearest train record is farther than *thresh_hold*.

    Args:
        split_ratio: fraction of ``data`` (taken from the front) used as the
            test set; the rest is the train set.
        thresh_hold: Euclidean distance threshold (strict ``>`` comparison).

    Returns:
        The ratio as a float. The same line is also printed.

    Raises:
        ValueError: if the split leaves the test or the train set empty.
    """
    return _ratio_above(_nearest_train_distances(split_ratio), thresh_hold)


def find_best_threshold(split_percents=range(10, 60, 5),
                        thresh_percents=range(50, 110, 2)):
    """For each test split, scan thresholds and print the one with the best ratio.

    Args:
        split_percents: test-split sizes to try, in percent of ``data``.
        thresh_percents: distance thresholds to try, in hundredths.

    "Best" is the highest ratio from calculate_ratio's metric. Because the
    comparison is ``>=``, ties go to the later (larger) threshold.
    """
    for j in split_percents:
        split = j / 100
        print('\nSplit test/train:', split*10, '/', 10-split*10, '====================================')
        # Nearest-neighbour distances do not depend on the threshold, so
        # compute them once per split instead of once per (split, threshold).
        min_dists = _nearest_train_distances(split)
        best_thresh = 0
        best_ratio = 0
        for i in thresh_percents:
            thresh = i / 100
            ratio = _ratio_above(min_dists, thresh)
            if ratio >= best_ratio:
                best_ratio = ratio
                best_thresh = thresh
        print('\tBEST THRESH: ', best_thresh, 'with ratio ', best_ratio)


# Runs on import as well as when executed as a script (unchanged behaviour).
get_dangerous_pairs(0.8)
find_best_threshold()