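"""Document-level COMET evaluation.

Scores machine translation output with a COMET model at the document level:
sentences are grouped into documents via a doc-ids file, and documents longer
than the configured context length are scored over sliding windows of
neighbouring sentences.
"""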
import numpy as np
import argparse
import json
import os
from comet import download_model, load_from_checkpoint
from transformers import AutoTokenizer

COMET_REF_MODELS = ["wmt20-comet-da", "wmt21-comet-mqm", "wmt22-comet-da"]
COMET_SRC_MODELS = ["wmt20-comet-qe-da", "wmt21-comet-qe-mqm", "wmt22-cometkiwi-da"]

os.environ['TOKENIZERS_PARALLELISM'] = 'false'
tokenizer = AutoTokenizer.from_pretrained("microsoft/infoxlm-large")

def _is_doc_boundary(doc_ids, idx):
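    """Return True if the sentence at `idx` is the last one of its document."""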
    after_idx = min(len(doc_ids) - 1, idx + 1)
    return (doc_ids[after_idx] != doc_ids[idx]) or (idx == len(doc_ids) - 1)

def _build_context(doc, current_idx, context_window, start_left=True):
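    """Return up to `context_window` sentences around `current_idx` and the position of the current sentence in that window.

    With `start_left=True` the window extends to the right of the current sentence,
    and any budget left over near the end of the document is shifted to the left;
    otherwise the window is roughly centred on the current sentence.
    """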
    balance = context_window
    low = current_idx if start_left else max([0, current_idx - (context_window // 2)])
    balance -= (current_idx - low)
    high = min([len(doc), current_idx + balance])
    balance -= (high - current_idx)
    low = max([0, low - balance])
    pos = current_idx - low
    return doc[low:high], pos

def _check_max_tokens(src_context, mt_context, ref_context=None, max_tokens=512):
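    """Return True if the concatenated (src, mt[, ref]) input stays under `max_tokens` InfoXLM tokens."""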
    src = " ".join(src_context).strip()
    mt = " ".join(mt_context).strip()
    if ref_context:
        ref = " ".join(ref_context).strip()
        full_input = tokenizer(" </s> ".join([src, mt, ref])).input_ids
    else:
        full_input = tokenizer(" </s> ".join([src, mt])).input_ids
    return len(full_input) < max_tokens

def _calculate_doc_comet(args, model, src_docs, hyp_docs, ref_docs=None):
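    """Score each document with the given COMET model.

    Documents that fit within `args.context_length` sentences are scored as a
    single segment; longer documents are scored over a sliding window of
    sentence contexts, shrinking each window until it fits the token limit and
    skipping windows identical to the previous one. Returns the per-document
    system scores and the document lengths in sentences.
    """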
    scores, doc_lengths = [], []
    if ref_docs:
        for s, h, r in zip(src_docs, hyp_docs, ref_docs):
            data_for_eval = []
            # Check whether the doc is shorter than the context length
            if len(s) <= args.context_length:
                data_for_eval.append({"src": " ".join(s).strip(), "mt": " ".join(h).strip(), "ref": " ".join(r).strip()})
            else:
                prev_context_src, prev_context_mt, prev_context_ref = [], [], []
                for i in range(len(s)):
                    src_context, _ = _build_context(s, i, args.context_length)
                    mt_context, _ = _build_context(h, i, args.context_length)
                    ref_context, _ = _build_context(r, i, args.context_length)
                    # Shrink the context until max_tokens is respected
                    reduce = 1
                    while (not _check_max_tokens(src_context, mt_context, ref_context=ref_context)) and (args.context_length - reduce > 1):
                        src_context, _ = _build_context(s, i, args.context_length - reduce)
                        mt_context, _ = _build_context(h, i, args.context_length - reduce)
                        ref_context, _ = _build_context(r, i, args.context_length - reduce)
                        reduce += 1
                    # Ensure the same context is not evaluated twice
                    if src_context != prev_context_src and mt_context != prev_context_mt and ref_context != prev_context_ref:
                        src, mt, ref = " ".join(src_context).strip(), " ".join(mt_context).strip(), " ".join(ref_context).strip()
                        data_for_eval.append({
                            "src": src, "mt": mt, "ref": ref
                        })
                    prev_context_src, prev_context_mt, prev_context_ref = src_context, mt_context, ref_context
            # Compute the document score
            pred = model.predict(data_for_eval, batch_size=8, gpus=1)
            scores.append(pred.system_score)
            doc_lengths.append(len(s))
    else:
        for s, h in zip(src_docs, hyp_docs):
            data_for_eval = []
            # Check whether the doc is shorter than the context length
            if len(s) <= args.context_length:
                data_for_eval.append({"src": " ".join(s).strip(), "mt": " ".join(h).strip()})
            else:
                prev_context_src, prev_context_mt = [], []
                for i in range(len(s)):
                    src_context, _ = _build_context(s, i, args.context_length)
                    mt_context, _ = _build_context(h, i, args.context_length)
                    # Shrink the context until max_tokens is respected
                    reduce = 1
                    while (not _check_max_tokens(src_context, mt_context)) and (args.context_length - reduce > 1):
                        src_context, _ = _build_context(s, i, args.context_length - reduce)
                        mt_context, _ = _build_context(h, i, args.context_length - reduce)
                        reduce += 1
                    # Ensure the same context is not evaluated twice
                    if src_context != prev_context_src and mt_context != prev_context_mt:
                        src, mt = " ".join(src_context).strip(), " ".join(mt_context).strip()
                        data_for_eval.append({
                            "src": src, "mt": mt
                        })
                    prev_context_src, prev_context_mt = src_context, mt_context
            # Compute the document score
            pred = model.predict(data_for_eval, batch_size=8, gpus=1)
            scores.append(pred.system_score)
            doc_lengths.append(len(s))
    return scores, doc_lengths

def _load_data(args):
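    """Read the source, hypothesis and (optionally) reference lines and group them into documents using the doc-ids file."""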
    with open(args.sources_file, 'r') as src_file, open(args.hypotheses_file, 'r') as hyp_file, open(args.docids_file, 'r') as docids_file:
        sources = src_file.readlines()
        hypotheses = hyp_file.readlines()
        docids = docids_file.readlines()
    src_docs, hyp_docs, ref_docs = [], [], None
    current_src_doc, current_hyp_doc = [], []
    i = 0
    while i < len(docids):
        current_src_doc.append(sources[i].strip())
        current_hyp_doc.append(hypotheses[i].strip())
        if _is_doc_boundary(docids, i):
            src_docs.append(current_src_doc)
            hyp_docs.append(current_hyp_doc)
            current_src_doc, current_hyp_doc = [], []
        i += 1
    if args.references_file:
        # Load the reference file and group it into the same documents
        with open(args.references_file, 'r') as ref_file:
            references = ref_file.readlines()
        ref_docs = []
        current_ref_doc = []
        i = 0
        while i < len(docids):
            current_ref_doc.append(references[i].strip())
            if _is_doc_boundary(docids, i):
                ref_docs.append(current_ref_doc)
                current_ref_doc = []
            i += 1
    return src_docs, hyp_docs, ref_docs

def main():
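    """Parse arguments, run document-level COMET evaluation, and print a JSON summary."""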
    parser = argparse.ArgumentParser()
    parser.add_argument('--sources-file', '-src', type=str, required=True, help='Path to the source file')
    parser.add_argument('--hypotheses-file', '-hyp', type=str, required=True, help='Path to the model output file')
    parser.add_argument('--references-file', '-ref', type=str, required=False, help='Path to the reference file')
    parser.add_argument('--docids-file', '-doc', type=str, required=True, help='Path to the doc-ids file')
    parser.add_argument('--model', type=str, required=True, help='The COMET model name used for automatic evaluation')
    parser.add_argument('--sliding-window', type=int, required=False, default=1, help='The stride of the sliding window over a document')
    parser.add_argument('--context-length', type=int, required=False, default=4, help='The number of sentences in a single context')
    args = parser.parse_args()

    comet_model_path = download_model(args.model)
    model = load_from_checkpoint(comet_model_path)

    if args.references_file:
        assert args.model in COMET_REF_MODELS, f"Reference files should not be passed when evaluating with {COMET_SRC_MODELS}"
    else:
        assert args.model not in COMET_REF_MODELS, f"Reference files are required when evaluating with {COMET_REF_MODELS}"

    src_docs, mt_docs, ref_docs = _load_data(args)
    scores, _ = _calculate_doc_comet(args, model, src_docs, mt_docs, ref_docs)

    ret = {
        'model': args.model,
        'sources_file': args.sources_file,
        'mt_file': args.hypotheses_file,
        'sliding_window': args.sliding_window,
        'context_length': args.context_length,
        'score': float(np.mean(scores))
    }
    print(json.dumps(ret, indent=2))

if __name__ == "__main__":
    main()
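
# Example invocation (illustrative only; the script and file names below are hypothetical):
#
#   python doc_comet.py \
#       --sources-file test.src \
#       --hypotheses-file test.hyp \
#       --references-file test.ref \
#       --docids-file test.docids \
#       --model wmt22-comet-da \
#       --context-length 4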