Upload evaluate.py
evaluate.py · ADDED · +152 -0
@@ -0,0 +1,152 @@
#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
import json
from itertools import chain

import numpy as np
import torch
import torch.distributed as dist
from fairseq import distributed_utils, options, tasks, utils
from fairseq.dataclass.utils import convert_namespace_to_omegaconf
from fairseq.logging import progress_bar
from fairseq.utils import reset_logging
from omegaconf import DictConfig

from utils import checkpoint_utils
from utils.eval_utils import eval_step

logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("ofa.evaluate")

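# Cast float32 tensors to half precision; applied to each sample via
# utils.apply_to_sample in the evaluation loop when --fp16 is set.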
def apply_half(t):
    if t.dtype is torch.float32:
        return t.to(dtype=torch.half)
    return t

def main(cfg: DictConfig):
    utils.import_user_module(cfg.common)

    reset_logging()
    logger.info(cfg)

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"

    # Fix seed for stochastic decoding
    if cfg.common.seed is not None and not cfg.generation.no_seed_provided:
        np.random.seed(cfg.common.seed)
        utils.set_torch_seed(cfg.common.seed)

    use_fp16 = cfg.common.fp16
    use_cuda = torch.cuda.is_available() and not cfg.common.cpu

    if use_cuda:
        torch.cuda.set_device(cfg.distributed_training.device_id)

    # Load ensemble
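    # --model-overrides arrives as a string holding a Python dict literal
    # (fairseq convention), hence the eval() below.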
    overrides = eval(cfg.common_eval.model_overrides)
    logger.info("loading model(s) from {}".format(cfg.common_eval.path))
    models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(
        utils.split_paths(cfg.common_eval.path),
        arg_overrides=overrides,
        suffix=cfg.checkpoint.checkpoint_suffix,
        strict=(cfg.checkpoint.checkpoint_shard_count == 1),
        num_shards=cfg.checkpoint.checkpoint_shard_count,
    )

    # Loading the dataset should happen after the checkpoint has been loaded,
    # so we can give it the saved task config.
    task.load_dataset(cfg.dataset.gen_subset, task_cfg=saved_cfg.task)

    # Move models to GPU
    for model in models:
        model.eval()
        if use_fp16:
            model.half()
        if use_cuda and not cfg.distributed_training.pipeline_model_parallel:
            model.cuda()
        model.prepare_for_inference_(cfg)

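    # Each rank iterates over a disjoint shard of the eval data
    # (num_shards = world size, shard_id = rank), so per-rank results
    # are gathered again after the loop.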
    # Load dataset (possibly sharded)
    itr = task.get_batch_iterator(
        dataset=task.dataset(cfg.dataset.gen_subset),
        max_tokens=cfg.dataset.max_tokens,
        max_sentences=cfg.dataset.batch_size,
        max_positions=utils.resolve_max_positions(
            task.max_positions(), *[m.max_positions() for m in models]
        ),
        ignore_invalid_inputs=cfg.dataset.skip_invalid_size_inputs_valid_test,
        required_batch_size_multiple=cfg.dataset.required_batch_size_multiple,
        seed=cfg.common.seed,
        num_shards=cfg.distributed_training.distributed_world_size,
        shard_id=cfg.distributed_training.distributed_rank,
        num_workers=cfg.dataset.num_workers,
        data_buffer_size=cfg.dataset.data_buffer_size,
    ).next_epoch_itr(shuffle=False)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_interval=cfg.common.log_interval,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
    )

    # Initialize generator
    generator = task.build_generator(models, cfg.generation)

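    # Collect per-sample outputs in `results`; keep the score accumulators as
    # tensors so they can be all-reduced across ranks after the loop.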
    results = []
    score_sum = torch.FloatTensor([0])
    score_cnt = torch.FloatTensor([0])
    if use_cuda:
        score_sum, score_cnt = score_sum.cuda(), score_cnt.cuda()
    for sample in progress:
        if "net_input" not in sample:
            continue
        sample = utils.move_to_cuda(sample) if use_cuda else sample
        sample = utils.apply_to_sample(apply_half, sample) if use_fp16 else sample
        with torch.no_grad():
            result, scores = eval_step(task, generator, models, sample)
        results += result
        score_sum += sum(scores) if scores is not None else 0
        score_cnt += len(scores) if scores is not None else 0
        progress.log({"sentences": sample["nsentences"]})

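    # Merge per-rank outputs: all_gather_object collects every rank's result
    # list, and all_reduce sums the score accumulators so the global score
    # can be logged.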
    gather_results = None
    if cfg.distributed_training.distributed_world_size > 1:
        gather_results = [None for _ in range(dist.get_world_size())]
        dist.all_gather_object(gather_results, results)
        dist.all_reduce(score_sum.data)
        dist.all_reduce(score_cnt.data)
    if score_cnt.item() > 0:
        logger.info("score_sum: {}, score_cnt: {}, score: {}".format(
            score_sum, score_cnt, round(score_sum.item() / score_cnt.item(), 4)
        ))

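    # Only one process writes the flattened predictions to
    # <gen_subset>_predict.json under --results-path.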
    if cfg.distributed_training.distributed_world_size == 1 or dist.get_rank() == 0:
        os.makedirs(cfg.common_eval.results_path, exist_ok=True)
        output_path = os.path.join(cfg.common_eval.results_path, "{}_predict.json".format(cfg.dataset.gen_subset))
        gather_results = list(chain(*gather_results)) if gather_results is not None else results
        with open(output_path, 'w') as fw:
            json.dump(gather_results, fw)


def cli_main():
    parser = options.get_generation_parser()
    args = options.parse_args_and_arch(parser)
    cfg = convert_namespace_to_omegaconf(args)
    distributed_utils.call_main(cfg, main)


if __name__ == "__main__":
    cli_main()
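For reference, a typical invocation would look roughly like the following. This is a sketch, not something recorded in this commit: the flags are standard fairseq generation options, while the data path, checkpoint path, user directory, and task name are illustrative placeholders.

    python evaluate.py ${data} \
        --path=${checkpoint} \
        --user-dir=${user_dir} \
        --task=caption \
        --gen-subset=valid \
        --batch-size=16 \
        --beam=5 \
        --fp16 \
        --results-path=${results_dir}

The script can also run under torch.distributed (e.g. via torchrun); each rank then evaluates its own shard, and rank 0 writes the merged predictions file.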