# 75 lines · 2.5 KiB · Python (file-viewer header residue)
import json
import random

import numpy as np

# Load every feature record once at module import. The functions below read
# the module-level `raw_data` / `data` globals by default.
# JSON text is UTF-8 by specification, so decode it explicitly rather than
# relying on the platform's locale-dependent default encoding.
with open("features_3k.json", encoding="utf-8") as file_json:
    raw_data = json.load(file_json)

print(len(raw_data))

# Working copy used for the test/train split in calculate_ratio(), which takes
# the leading records as the test set. To randomize that split, shuffle this
# copy after it is created, e.g. `random.shuffle(data)`.
data = raw_data.copy()


def get_dangerous_pairs(thresh, records=None):
    """Print every pair of records whose feature vectors lie closer than *thresh*.

    Each unordered pair (i, j) with i < j is compared once, using the Euclidean
    distance between the two records' ``"feature"`` vectors. Pairs whose
    distance is strictly below ``thresh`` are printed to stdout.

    Args:
        thresh: Distance below which a pair is reported.
        records: Sequence of dicts with ``"name"``, ``"id"`` and ``"feature"``
            keys. Defaults to the module-level ``raw_data``.

    Returns:
        A list of ``(i, j, dist)`` tuples, one per reported pair, in the order
        they were printed. ``i`` and ``j`` are indices into ``records``.
    """
    if records is None:
        records = raw_data

    n = len(records)
    if n < 2:
        return []

    # Convert every feature vector to an array once, instead of rebuilding the
    # inner vector on each of the O(n^2) comparisons. Flattening each record to
    # one row keeps the distance equal to the sqrt of the sum of all squared
    # element differences.
    features = np.array([record["feature"] for record in records]).reshape(n, -1)

    pairs = []
    for i in range(n - 1):
        # Distances from record i to every later record, in one vectorized step.
        dists = np.sqrt(np.sum(np.square(features[i] - features[i + 1:]), axis=1))
        for offset in np.flatnonzero(dists < thresh):
            j = i + 1 + int(offset)
            dist = dists[offset]
            print("Dangerous pairs:",
                  records[i]["name"], '(ID: ', records[i]["id"], ')',
                  '-',
                  records[j]["name"], '(ID: ', records[j]["id"], ')',
                  '(Dist: ', dist, ')')
            pairs.append((i, j, dist))
    return pairs


def calculate_ratio(split_ratio, thresh_hold, records=None):
    """Return the fraction of test records with no training record within *thresh_hold*.

    The first ``int(len(records) * split_ratio)`` records form the test set and
    the remaining records form the training set. For each test record, the
    Euclidean distance to its nearest training record is computed. The result
    is the fraction of test records whose nearest distance is strictly greater
    than ``thresh_hold``. The threshold, ratio and raw counts are also printed.

    Args:
        split_ratio: Fraction of the records used as the test set.
        thresh_hold: Nearest-neighbour distance a test record must exceed to be
            counted.
        records: Sequence of dicts with a ``"feature"`` key. Defaults to the
            module-level ``data``.

    Returns:
        ``t1 / t2`` as a float, where ``t1`` is the count of test records above
        the threshold and ``t2`` is the size of the test set.

    Raises:
        ValueError: If the split leaves the test set or the training set empty.
            The ratio, or a nearest neighbour, is undefined in that case.
    """
    if records is None:
        records = data

    n_test = int(len(records) * split_ratio)
    n_train = len(records) - n_test
    if n_test <= 0 or n_train <= 0:
        raise ValueError(
            f"split_ratio={split_ratio} leaves {n_test} test and {n_train} "
            f"train records; both sets must be non-empty"
        )

    test_data = np.array([x["feature"] for x in records[:n_test]])
    train_data = np.array([x["feature"] for x in records[n_test:]])

    # Broadcasting subtracts each test row from every training row without
    # building a tiled copy. Looping over test rows keeps peak memory at
    # O(n_train * dim) instead of O(n_test * n_train * dim).
    min_dists = np.array([
        np.sqrt(np.sum(np.square(row - train_data), axis=1)).min()
        for row in test_data
    ])

    t1 = int(np.count_nonzero(min_dists > thresh_hold))
    t2 = min_dists.shape[0]
    print('\tthresh:', thresh_hold, '\tratio:', t1/t2, '(', t1, '/', t2 ,')')
    return t1 / t2


def find_best_threshold():
    """Sweep test/train splits and distance thresholds; print the best threshold per split.

    Each split fraction 0.10, 0.15, ..., 0.55 is scored against the thresholds
    0.50, 0.52, ..., 1.08 via calculate_ratio(). For each split, the threshold
    with the highest ratio is printed. Ties go to the largest such threshold.
    """
    thresholds = [step / 100 for step in range(50, 110, 2)]
    for split_pct in range(10, 60, 5):
        split = split_pct / 100
        print('\nSplit test/train:', split*10, '/', 10-split*10, '====================================')
        scores = [(thresh, calculate_ratio(split, thresh)) for thresh in thresholds]
        # max() returns the first maximum it meets, so scanning the scores in
        # reverse resolves ties in favour of the largest threshold.
        best_thresh, best_ratio = max(reversed(scores), key=lambda pair: pair[1])
        print('\tBEST THRESH: ', best_thresh, 'with ratio ', best_ratio)


if __name__ == "__main__":
    # Report feature pairs closer than 0.8, then sweep the split/threshold grid.
    # Guarded so that importing this module does not trigger the analysis.
    get_dangerous_pairs(0.8)
    find_best_threshold()