synthesis-test / utils.py
theodotus's picture
Added missing name
6598449
raw
history blame
2.6 kB
import json
class Test:
    """A multiple-choice test whose answers are credited to a fixed set of names.

    Questions are loaded from ``test.json`` and image descriptions from
    ``images/descriptions.txt``; both paths are relative to the current working
    directory. Each question is a dict with a ``"question"`` text and a
    ``"responces"`` list (the key is spelled this way because that is what the
    data file uses). Each response has an ``"answer"`` string and a ``"name"``
    collection of the names that choosing it credits.
    """

    # Cyrillic initials that answers may credit. Every entry of a response's
    # "name" collection must be one of these, otherwise the count methods raise
    # KeyError. Keep this file UTF-8: these literals were previously mangled
    # (UTF-8 bytes decoded as GBK, e.g. "袛小" instead of "ДС"), so they could
    # not match the data. NOTE(review): confirm spellings against test.json.
    names = [
        "ДС", "БМ", "ОШ", "ІС", "ЕС", "АК", "АШ",
        "ФС", "НП", "НГ", "МВ", "АД", "ОС",
    ]

    def __init__(self):
        """Load the questions and image descriptions from disk."""
        self.data = self.read_test_data()
        self.descriptions = self.read_image_descriptions()

    def __len__(self):
        """Return the number of questions."""
        return len(self.data)

    @staticmethod
    def read_test_data():
        """Parse ``test.json`` and return its contents."""
        # Explicit UTF-8: the names in the data are Cyrillic, and the platform
        # default encoding (e.g. cp1252 on Windows) would mis-decode them.
        with open("test.json", encoding="utf-8") as f:
            return json.load(f)

    @staticmethod
    def read_image_descriptions():
        """Return the lines of ``images/descriptions.txt``, whitespace-stripped."""
        # NOTE(review): the descriptions file is assumed to be UTF-8 as well.
        with open("images/descriptions.txt", encoding="utf-8") as f:
            return [line.strip() for line in f]

    def get_question(self, idx):
        """Return the question text of question ``idx``."""
        return self.data[idx]["question"]

    def get_answers(self, idx):
        """Return the answer strings of question ``idx``, in file order."""
        return [responce["answer"] for responce in self.data[idx]["responces"]]

    def get_description(self, idx):
        """Return line ``idx`` of the image descriptions file."""
        return self.descriptions[idx]

    def get_image_path(self, idx):
        """Return the relative path ``images/<idx>.jpg``."""
        return f"images/{idx}.jpg"

    def convert_to_answer_ids(self, responces):
        """Map chosen answer strings to their positions within each question.

        Args:
            responces: ``responces[i]`` is the answer text chosen for question ``i``.

        Returns:
            A list whose ``i``-th item is the index of ``responces[i]`` in
            ``get_answers(i)`` (the first match if an answer text repeats).

        Raises:
            ValueError: if a chosen text is not one of that question's answers.
        """
        return [
            self.get_answers(idx).index(responce)
            for idx, responce in enumerate(responces)
        ]

    def _count_names(self, answers):
        """Tally the ``"name"`` entries of ``answers``; every known name starts at 0."""
        name_count = {name: 0 for name in self.names}
        for answer in answers:
            for name in answer["name"]:
                # A KeyError here means the data credits a name missing from ``names``.
                name_count[name] += 1
        return name_count

    def total_name_count(self):
        """Count, per name, how many answers across all questions credit it."""
        return self._count_names(
            answer
            for question in self.data
            for answer in question["responces"]
        )

    def current_name_count(self, answer_ids):
        """Count, per name, how many of the chosen answers credit it.

        ``answer_ids[i]`` is the index of the answer chosen for question ``i``.
        Only the first ``len(answer_ids)`` questions are counted, because the
        pairing uses ``zip``.
        """
        return self._count_names(
            question["responces"][answer_idx]
            for question, answer_idx in zip(self.data, answer_ids)
        )

    def select_best(self, answer_ids):
        """Return the name with the highest chosen-to-available answer ratio.

        For each name the score is ``current_name_count / total_name_count``.
        A name credited by no answer scores 0.0. Ties go to the earliest name
        in ``names``.
        """
        total_name_count = self.total_name_count()
        current_name_count = self.current_name_count(answer_ids)
        name_percent = {}
        for name in self.names:
            total = total_name_count[name]
            # Guard against ZeroDivisionError: a name no answer credits can
            # never have been chosen, so 0.0 is its exact score.
            name_percent[name] = current_name_count[name] / total if total else 0.0
        return self.dict_max(name_percent)

    @staticmethod
    def dict_max(dict):
        """Return the key with the largest value (the first such key on ties).

        Raises:
            ValueError: if the mapping is empty.
        """
        # The parameter shadows the builtin ``dict``. The name is kept so that
        # existing keyword callers keep working.
        return max(dict, key=dict.get)