lvwerra (HF staff) committed
Commit 9b085fb · Parent(s): 5929372

Update Space (evaluate main: e4a27243)
Files changed (2):
  1. requirements.txt +1 -1
  2. xtreme_s.py +20 -7
requirements.txt CHANGED
@@ -1,2 +1,2 @@
-git+https://github.com/huggingface/evaluate@80448674f5447a9682afe051db243c4a13bfe4ff
+git+https://github.com/huggingface/evaluate@e4a2724377909fe2aeb4357e3971e5a569673b39
 sklearn
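The bumped commit pin is what makes the xtreme_s.py change below work: the new code relies on the Config API (evaluate.info.Config, CONFIG_CLASS) present on evaluate main at that revision. A minimal guard sketch, assuming only that older revisions lack the Config attribute; the error message is illustrative:

import evaluate

# Hedged sketch: fail fast if the installed evaluate revision predates
# the Config refactor that xtreme_s.py now depends on.
if not hasattr(evaluate.info, "Config"):
    raise ImportError("xtreme_s requires the evaluate revision pinned in requirements.txt")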
xtreme_s.py CHANGED
@@ -13,7 +13,8 @@
 # limitations under the License.
 """ XTREME-S benchmark metric. """
 
-from typing import List
+from dataclasses import dataclass
+from typing import List, Optional
 
 import datasets
 from datasets.config import PY_VERSION
@@ -218,11 +219,22 @@ def wer_and_cer(preds, labels, concatenate_texts, config_name):
     return {"wer": compute_score(preds, labels, "wer"), "cer": compute_score(preds, labels, "cer")}
 
 
+@dataclass
+class XtremeSConfig(evaluate.info.Config):
+
+    name: str = "default"
+
+    bleu_kwargs: Optional[dict] = None
+    wer_kwargs: Optional[dict] = None
+
+
 @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
 class XtremeS(evaluate.Metric):
-    def _info(self):
-        if self.config_name not in _CONFIG_NAMES:
-            raise KeyError(f"You should supply a configuration name selected in {_CONFIG_NAMES}")
+
+    CONFIG_CLASS = XtremeSConfig
+    ALLOWED_CONFIG_NAMES = _CONFIG_NAMES
+
+    def _info(self, config):
 
         pred_type = "int64" if self.config_name in ["fleurs-lang_id", "minds14"] else "string"
 
@@ -230,6 +242,7 @@ class XtremeS(evaluate.Metric):
             description=_DESCRIPTION,
             citation=_CITATION,
             inputs_description=_KWARGS_DESCRIPTION,
+            config=config,
             features=datasets.Features(
                 {"predictions": datasets.Value(pred_type), "references": datasets.Value(pred_type)}
             ),
@@ -238,10 +251,10 @@ class XtremeS(evaluate.Metric):
             format="numpy",
         )
 
-    def _compute(self, predictions, references, bleu_kwargs=None, wer_kwargs=None):
+    def _compute(self, predictions, references):
 
-        bleu_kwargs = bleu_kwargs if bleu_kwargs is not None else {}
-        wer_kwargs = wer_kwargs if wer_kwargs is not None else {}
+        bleu_kwargs = self.config.bleu_kwargs if self.config.bleu_kwargs is not None else {}
+        wer_kwargs = self.config.wer_kwargs if self.config.wer_kwargs is not None else {}
 
         if self.config_name == "fleurs-lang_id":
             return {"accuracy": simple_accuracy(predictions, references)}