Alex Cabrera
initial
382191a
raw
history blame
7.6 kB
import numpy as np
import argparse
import json
import os
from comet import download_model, load_from_checkpoint
from transformers import AutoTokenizer
COMET_REF_MODELS = ["wmt20-comet-da", "wmt21-comet-mqm", "wmt22-comet-da"]
COMET_SRC_MODELS = ["wmt20-comet-qe-da", "wmt21-comet-qe-mqm", "wmt22-cometkiwi-da"]
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
tokenizer = AutoTokenizer.from_pretrained("microsoft/infoxlm-large")
def _is_doc_boundary(doc_ids, idx):
after_idx = min(len(doc_ids) - 1, idx + 1)
return (not doc_ids[after_idx] == doc_ids[idx]) or (idx == len(doc_ids) - 1)
def _build_context(doc, current_idx, context_window, start_left=True):
balance = context_window
low = current_idx if start_left else max([0, current_idx - (context_window // 2)])
balance -= (current_idx - low)
high = min([len(doc), current_idx + balance])
balance -= (high - current_idx)
low = max([0, low - balance])
pos = current_idx - low
return doc[low:high], pos
def _check_max_tokens(src_context, mt_context, ref_context=None, max_tokens=512):
src = " ".join(src_context).strip()
mt = " ".join(mt_context).strip()
if ref_context:
ref = " ".join(ref_context).strip()
full_input = tokenizer(" </s> ".join([src, mt, ref])).input_ids
else:
full_input = tokenizer(" </s> ".join([src, mt])).input_ids
return len(full_input) < max_tokens
def _calculate_doc_comet(args, model, src_docs, hyp_docs, ref_docs=None):
scores, doc_lengths = [], []
if ref_docs:
for s, h, r in zip(src_docs, hyp_docs, ref_docs):
data_for_eval = []
# Check if the doc has length shorter than the context length
if len(s) <= args.context_length:
data_for_eval.append({"src": " ".join(s).strip(), "mt": " ".join(h).strip(), "ref": " ".join(r).strip()})
else:
prev_context_src, prev_context_mt, prev_context_ref = [], [], []
for i in range(len(s)):
src_context, _ = _build_context(s, i, args.context_length)
mt_context, _ = _build_context(h, i, args.context_length)
ref_context, _ = _build_context(r, i, args.context_length)
# Ensure max_tokens is respected
reduce = 1
while (not _check_max_tokens(src_context, mt_context, ref_context=ref_context)) and (args.context_length - reduce > 1):
src_context, _ = _build_context(s, i, args.context_length - reduce)
mt_context, _ = _build_context(h, i, args.context_length - reduce)
ref_context, _ = _build_context(r, i, args.context_length - reduce)
reduce += 1
# Ensure same context is not evaluated twice
if not src_context == prev_context_src and not mt_context == prev_context_mt and not ref_context == prev_context_ref:
src, mt, ref = " ".join(src_context).strip(), " ".join(mt_context).strip(), " ".join(ref_context).strip()
data_for_eval.append({
"src": src, "mt": mt, "ref": ref
})
prev_context_src, prev_context_mt, prev_context_ref = src_context, mt_context, ref_context
# Compute the score
pred = model.predict(data_for_eval, batch_size=8, gpus=1)
scores.append(pred.system_score)
doc_lengths.append(len(s))
else:
for s, h in zip(src_docs, hyp_docs):
data_for_eval = []
# Check if the doc has length shorter than the context length
if len(s) <= args.context_length:
data_for_eval.append({"src": " ".join(s).strip(), "mt": " ".join(h).strip()})
else:
prev_context_src, prev_context_mt = [], []
for i in range(len(s)):
src_context, _ = _build_context(s, i, args.context_length)
mt_context, _ = _build_context(h, i, args.context_length)
# Ensure max_tokens is respected
reduce = 1
while (not _check_max_tokens(src_context, mt_context)) and (args.context_length - reduce > 1):
src_context, _ = _build_context(s, i, args.context_length - reduce)
mt_context, _ = _build_context(h, i, args.context_length - reduce)
reduce += 1
# Ensure same context is not evaluated twice
if not src_context == prev_context_src and not mt_context == prev_context_mt:
src, mt = " ".join(src_context).strip(), " ".join(mt_context).strip()
data_for_eval.append({
"src": src, "mt": mt
})
prev_context_src, prev_context_mt = src_context, mt_context
# Compute the score
pred = model.predict(data_for_eval, batch_size=8, gpus=1)
scores.append(pred.system_score) # type: ignore
doc_lengths.append(len(s))
return scores, doc_lengths
def _load_data(args):
with open(args.sources_file, 'r') as src_file, open(args.hypotheses_file, 'r') as hyp_file, open(args.docids_file, 'r') as docids_file:
sources = src_file.readlines()
hypotheses = hyp_file.readlines()
docids = docids_file.readlines()
src_docs, hyp_docs, ref_docs = [], [], None
current_src_doc, current_hyp_doc = [], []
i = 0
while i < len(docids):
current_src_doc.append(sources[i].strip())
current_hyp_doc.append(hypotheses[i].strip())
if _is_doc_boundary(docids, i):
src_docs.append(current_src_doc)
hyp_docs.append(current_hyp_doc)
current_src_doc, current_hyp_doc = [], []
i += 1
if args.references_file:
# Load reference files
with open(args.references_file, 'r') as ref_file:
references = ref_file.readlines()
ref_docs = []
current_ref_doc = []
i = 0
while i < len(docids):
current_ref_doc.append(references[i].strip())
if _is_doc_boundary(docids, i):
ref_docs.append(current_ref_doc)
current_ref_doc = []
i += 1
return src_docs, hyp_docs, ref_docs
def main():
parser = argparse.ArgumentParser()
parser.add_argument('--sources-file', '-src', type=str, required=True, help='A path to the source file')
parser.add_argument('--hypotheses-file', '-hyp', type=str, required=True, help='A path to the model output file')
parser.add_argument('--references-file', '-ref', type=str, required=False, help='A path to the reference file')
parser.add_argument('--docids-file', '-doc', type=str, required=True, help='A path to the doc-ids file')
parser.add_argument('--model', type=str, required=True, help='The COMET model name used for automatic evaluation')
parser.add_argument('--sliding-window', type=int, required=False, default=1, help='The stride step over document')
parser.add_argument('--context-length', type=int, required=False, default=4, help='The number of sentences in a single context')
args = parser.parse_args()
comet_model_path = download_model(args.model)
model = load_from_checkpoint(comet_model_path)
if args.references_file:
assert args.model in COMET_REF_MODELS, f"Reference files should not be passed for evaluating {COMET_SRC_MODELS}"
else:
assert args.model not in COMET_REF_MODELS, f"Reference files are required for evaluating {COMET_REF_MODELS}"
src_docs, mt_docs, ref_docs = _load_data(args)
scores, _ = _calculate_doc_comet(args, model, src_docs, mt_docs, ref_docs)
ret = {
'model': args.model,
'sources_file': args.sources_file,
'mt_file': args.hypotheses_file,
'sliding_window': args.sliding_window,
'context_length': args.context_length,
'score': np.mean(scores)
}
print(json.dumps(ret, indent=2))
if __name__ == "__main__":
main()