# test-feature-tool/test.py
#
# (source listing metadata: 75 lines, 2.5 KiB, Python)

import json
import random
import numpy as np
# Load the feature records once; the functions below read the module-level
# ``raw_data`` / ``data`` names by default.
# JSON text is UTF-8 by spec, so decode explicitly instead of relying on the
# platform's locale encoding.
with open("features_3k.json", encoding="utf-8") as file_json:
    raw_data = json.load(file_json)
print(len(raw_data))
# Shallow working copy used for the test/train split (records are shared).
data = raw_data.copy()
# To randomize which records land in the test split, shuffle here:
# random.shuffle(data)
def get_dangerous_pairs(thresh, records=None):
    """Print (and return) every pair of records whose features lie closer than *thresh*.

    Records are read via their ``"feature"``, ``"name"`` and ``"id"`` keys;
    ``"feature"`` is assumed to be a numeric vector of the same length for
    every record (TODO confirm against features_3k.json).

    Args:
        thresh: strict upper bound on the Euclidean distance for a pair to be
            reported (``dist < thresh``).
        records: records to scan; defaults to the module-level ``raw_data``.

    Returns:
        list of ``(i, j, dist)`` tuples with ``i < j`` indexing *records*, in
        the same order the pairs are printed.
    """
    if records is None:
        records = raw_data
    pairs = []
    if len(records) < 2:
        return pairs
    # Convert every feature vector exactly once instead of once per (i, j) pair.
    features = np.array([rec["feature"] for rec in records])
    for i in range(len(records) - 1):
        # Distances from record i to every later record, in one vectorized pass.
        dists = np.sqrt(np.sum(np.square(features[i] - features[i + 1:]), axis=1))
        for offset in np.flatnonzero(dists < thresh):
            j = i + 1 + int(offset)
            dist = dists[offset]
            print("Dangerous pairs:",
                  records[i]["name"], '(ID: ', records[i]["id"], ')',
                  '-',
                  records[j]["name"], '(ID: ', records[j]["id"], ')',
                  '(Dist: ', dist, ')')
            pairs.append((i, j, float(dist)))
    return pairs
def _nearest_train_distances(split_ratio, records):
    """Return, for each test record, the Euclidean distance to its nearest train record.

    The first ``int(len(records) * split_ratio)`` records form the test set and
    the remaining records form the train set; distances use ``"feature"``.

    Raises:
        ValueError: if the split leaves the test set or the train set empty.
    """
    n_test = int(len(records) * split_ratio)
    n_train = len(records) - n_test
    if n_test <= 0 or n_train <= 0:
        raise ValueError(
            f"split_ratio={split_ratio} gives {n_test} test / {n_train} train "
            "records; both sets must be non-empty"
        )
    test_data = np.array([x["feature"] for x in records[:n_test]])
    train_data = np.array([x["feature"] for x in records[n_test:]])
    # Broadcasting ``row - train_data`` gives the same values as tiling the
    # row nTrain times, without materialising the copy.
    return np.array([
        np.sqrt(np.sum(np.square(row - train_data), axis=1)).min()
        for row in test_data
    ])


def _report_ratio(min_dists, thresh_hold):
    """Print and return the fraction of *min_dists* strictly greater than *thresh_hold*."""
    t1 = min_dists[min_dists > thresh_hold].shape[0]
    t2 = min_dists.shape[0]
    print('\tthresh:', thresh_hold, '\tratio:', t1/t2, '(', t1, '/', t2 ,')')
    return t1 / t2


def calculate_ratio(split_ratio, thresh_hold, records=None):
    """Return the share of test records whose nearest train record is farther than *thresh_hold*.

    Args:
        split_ratio: fraction of *records* (taken from the front) used as the
            test set; the rest is the train set.
        thresh_hold: distance threshold; a test record counts when its nearest
            train distance is strictly greater than this.
        records: records to split; defaults to the module-level ``data``.

    Returns:
        The ratio as a float in [0, 1]; it is also printed.

    Raises:
        ValueError: if the split leaves the test or train set empty.
    """
    if records is None:
        records = data
    return _report_ratio(_nearest_train_distances(split_ratio, records), thresh_hold)


def find_best_threshold(records=None):
    """Sweep test splits 0.10..0.55 and thresholds 0.50..1.08, printing the best per split.

    For each split the nearest-train distances are computed once and reused
    for every threshold (they do not depend on the threshold). The "best"
    threshold is the largest one achieving the maximal ratio (ties go to the
    later, larger threshold because of ``>=``).

    Args:
        records: records to split; defaults to the module-level ``data``.

    Returns:
        dict mapping each split ratio to its ``(best_thresh, best_ratio)``.
    """
    if records is None:
        records = data
    results = {}
    for j in range(10, 60, 5):
        split = j / 100
        print('\nSplit test/train:', split*10, '/', 10-split*10, '====================================')
        min_dists = _nearest_train_distances(split, records)
        best_thresh = 0
        best_ratio = 0
        for i in range(50, 110, 2):
            thresh = i / 100
            ratio = _report_ratio(min_dists, thresh)
            if ratio >= best_ratio:
                best_ratio = ratio
                best_thresh = thresh
        print('\tBEST THRESH: ', best_thresh, 'with ratio ', best_ratio)
        results[split] = (best_thresh, best_ratio)
    return results
# Script entry point: first report near-duplicate record pairs, then sweep
# the test/train split and distance-threshold grid.
get_dangerous_pairs(thresh=0.8)
find_best_threshold()