Spaces:
Runtime error
Runtime error
"""The main entry point for performing comparison on analysis_gpt_mts.""" | |
from __future__ import annotations | |
import argparse | |
import os | |
import pandas as pd | |
from zeno_build.experiments import search_space | |
from zeno_build.experiments.experiment_run import ExperimentRun | |
from zeno_build.reporting.visualize import visualize | |
import config | |
from modeling import ( | |
GptMtInstance, | |
process_data, | |
process_output, | |
) | |
def analysis_gpt_mt_main(
    input_dir: str,
    results_dir: str,
) -> None:
    """Compare GPT-MT system outputs and launch the Zeno visualization.

    The test data is loaded with ``process_data``, one ``ExperimentRun`` is
    built per model preset from ``process_output``, and everything is handed
    to ``visualize``.

    Args:
        input_dir: Directory forwarded to ``process_data`` and
            ``process_output`` as the location of the inputs.
        results_dir: Directory under which the ``zeno_cache`` cache path
            is placed.

    Raises:
        ValueError: If the ``lang_pairs`` dimension of ``config.main_space``
            is not a ``search_space.Constant``, or if the ``model_preset``
            dimension is not a ``search_space.Categorical``.
    """
    # Resolve the single set of language pairs shared by every experiment.
    pair_dimension = config.main_space.dimensions["lang_pairs"]
    if isinstance(pair_dimension, search_space.Constant):
        lang_pairs = config.lang_pairs[pair_dimension.value]
    else:
        raise ValueError(
            "All experiments must be run on a single set of language pairs."
        )

    # Load and standardize the format of the necessary data.
    instances: list[GptMtInstance] = process_data(
        input_dir=input_dir,
        lang_pairs=lang_pairs,
    )

    # Collect one run per model preset listed in the search space.
    preset_dimension = config.main_space.dimensions["model_preset"]
    if not isinstance(preset_dimension, search_space.Categorical):
        raise ValueError("The model presets must be a categorical parameter.")
    runs: list[ExperimentRun] = [
        ExperimentRun(
            preset,
            {"model_preset": preset},
            process_output(
                input_dir=input_dir,
                lang_pairs=lang_pairs,
                model_preset=preset,
            ),
        )
        for preset in preset_dimension.choices
    ]

    # One DataFrame column per instance attribute, in this fixed order.
    column_names = ("data", "label", "lang_pair", "doc_id")
    frame = pd.DataFrame(
        {
            column: [getattr(instance, column) for instance in instances]
            for column in column_names
        }
    )
    reference_labels = [instance.label for instance in instances]

    # Settings forwarded verbatim to visualize() as its zeno_config.
    zeno_settings = {
        "cache_path": os.path.join(results_dir, "zeno_cache"),
        "port": 7860,
        "host": "0.0.0.0",
        "editable": False,
    }

    # Perform the visualization.
    visualize(
        frame,
        reference_labels,
        runs,
        "text-classification",
        "data",
        config.zeno_distill_and_metric_functions,
        zeno_config=zeno_settings,
    )
if __name__ == "__main__":
    # Parse the command line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input-dir",
        type=str,
        # There is no usable default for this path: without the flag the
        # value would be None and would be passed on where a str is
        # expected. Let argparse report the missing option up front.
        required=True,
        help="The directory of the GPT-MT repo.",
    )
    parser.add_argument(
        "--results-dir",
        type=str,
        default="results",
        help="The directory to store the results in.",
    )
    args = parser.parse_args()

    # Run the analysis with the parsed locations.
    analysis_gpt_mt_main(
        input_dir=args.input_dir,
        results_dir=args.results_dir,
    )