# Page header captured together with this file: "Spaces: Runtime error".
"""Config for analyzing GPT-MT."""
from __future__ import annotations

from collections.abc import Callable
from dataclasses import dataclass

from zeno_build.evaluation.text_features.capitalization import input_capital_char_ratio
from zeno_build.evaluation.text_features.exact_match import avg_exact_match, exact_match
from zeno_build.evaluation.text_features.frequency import output_max_word_freq
from zeno_build.evaluation.text_features.length import (
    doc_context_length,
    input_length,
    label_length,
    output_length,
)
from zeno_build.evaluation.text_metrics.critique import (
    avg_bert_score,
    avg_chrf,
    avg_comet,
    avg_length_ratio,
    bert_score,
    chrf,
    comet,
    length_ratio,
)
from zeno_build.experiments import search_space

from modeling import remove_leading_language

# Named groups of language pairs, keyed by group name. Every pair is a
# four-letter code made of two two-letter language codes (e.g. "csen");
# presumably source language followed by target language -- confirm against
# the code that consumes these codes.
lang_pairs: dict[str, list[str]] = {
    # All language pairs used in any experiment
    "all_lang_pairs": [
        "csen",
        "deen",
        "defr",
        "encs",
        "ende",
        "enha",
        "enis",
        "enja",
        "enru",
        "enuk",
        "enzh",
        "frde",
        "haen",
        "isen",
        "jaen",
        "ruen",
        "uken",
        "zhen",
    ],
    # Language pairs used in the experiments on a limited number of language pairs
    # (every entry below also appears in "all_lang_pairs")
    "limited_lang_pairs": [
        "deen",
        "defr",
        "ende",
        "enru",
        "enzh",
        "frde",
        "ruen",
        "zhen",
    ],
}
# The search space for the main experiments.
# "lang_pairs" is pinned to the "all_lang_pairs" group name (a key of
# `lang_pairs`), and every "model_preset" option below has a matching key in
# `model_configs`.
main_space = search_space.CombinatorialSearchSpace(
    {
        "lang_pairs": search_space.Constant("all_lang_pairs"),
        "model_preset": search_space.Categorical(
            [
                "text-davinci-003-zeroshot",
                "text-davinci-003-RR-1-shot",
                "text-davinci-003-RR-5-shot",
                "text-davinci-003-QR-1-shot",
                "text-davinci-003-QR-5-shot",
                "gpt-3.5-turbo-0301-zeroshot",
                "gpt-4-0314-zeroshot",
                "gpt-4-0314-zeroshot-postprocess",
                "MS-Translator",
                "google-cloud",
                "wmt-best",
            ]
        ),
    }
)
# @dataclass generates the __init__ that the positional
# `GptMtConfig(path, base_model, ...)` calls in `model_configs` rely on;
# without it those calls raise TypeError.
@dataclass
class GptMtConfig:
    """Config for gpt-MT models.

    Fields are accepted positionally in declaration order; only ``path`` and
    ``base_model`` are required.

    Attributes:
        path: Identifier string for the system's results, e.g.
            ``"text-davinci-003/RR/1-shot"`` (presumably a relative
            directory -- confirm against the code that loads it).
        base_model: Name of the underlying model or system, e.g.
            ``"gpt-4-0314"`` or ``"google-cloud"``.
        prompt_strategy: Prompting-strategy label (``"RR"`` or ``"QR"`` in
            this file), or ``None`` when no strategy is set.
        prompt_shots: Number of prompt examples (``0`` for the zero-shot
            presets in this file), or ``None`` when not set.
        post_processors: Optional list of ``str -> str`` callables
            (presumably applied to the system output -- confirm against the
            caller), or ``None`` for no post-processing.
    """

    path: str
    base_model: str
    prompt_strategy: str | None = None
    prompt_shots: int | None = None
    post_processors: list[Callable[[str], str]] | None = None
# The details of each model, keyed by preset name.
# Positional arguments follow GptMtConfig's field order:
# (path, base_model, prompt_strategy, prompt_shots, post_processors).
model_configs: dict[str, GptMtConfig] = {
    "text-davinci-003-RR-1-shot": GptMtConfig(
        "text-davinci-003/RR/1-shot", "text-davinci-003", "RR", 1
    ),
    "text-davinci-003-RR-5-shot": GptMtConfig(
        "text-davinci-003/RR/5-shot", "text-davinci-003", "RR", 5
    ),
    "text-davinci-003-QR-1-shot": GptMtConfig(
        "text-davinci-003/QR/1-shot", "text-davinci-003", "QR", 1
    ),
    "text-davinci-003-QR-5-shot": GptMtConfig(
        "text-davinci-003/QR/5-shot", "text-davinci-003", "QR", 5
    ),
    "text-davinci-003-zeroshot": GptMtConfig(
        "text-davinci-003/zeroshot", "text-davinci-003", None, 0
    ),
    "gpt-3.5-turbo-0301-zeroshot": GptMtConfig(
        "gpt-3.5-turbo-0301/zeroshot", "gpt-3.5-turbo-0301", None, 0
    ),
    "gpt-4-0314-zeroshot": GptMtConfig("gpt-4-0314/zeroshot", "gpt-4-0314", None, 0),
    # Same path as "gpt-4-0314-zeroshot"; differs only in its post-processors.
    "gpt-4-0314-zeroshot-postprocess": GptMtConfig(
        "gpt-4-0314/zeroshot", "gpt-4-0314", None, 0, [remove_leading_language]
    ),
    # The remaining systems set only path and base_model, leaving the prompt
    # fields and post-processors at their None defaults.
    "MS-Translator": GptMtConfig("MS-Translator", "MS-Translator"),
    "google-cloud": GptMtConfig("google-cloud", "google-cloud"),
    "wmt-best": GptMtConfig("wmt-best", "wmt-best"),
}
# chrF is the only metric selected for the sweep: `chrf` is the sole distill
# function and `avg_chrf` is the metric function.
# NOTE(review): presumably `chrf` scores individual examples and `avg_chrf`
# aggregates them -- confirm against zeno_build.
sweep_distill_functions = [chrf]
sweep_metric_function = avg_chrf
# The functions used for Zeno visualization.
# Order is preserved from the original list: the text-feature functions come
# first, then the critique functions, then the `avg_*` counterparts.
zeno_distill_and_metric_functions = [
    output_length,
    input_length,
    label_length,
    doc_context_length,
    input_capital_char_ratio,
    output_max_word_freq,
    chrf,
    comet,
    length_ratio,
    bert_score,
    exact_match,
    avg_chrf,
    avg_comet,
    avg_length_ratio,
    avg_bert_score,
    avg_exact_match,
]