"""Task abstract class for evaluation and results."""

import logging
from abc import ABC, abstractmethod
from enum import Enum
from importlib.metadata import version
from typing import Any, Dict, List, Literal, Optional

import datasets
from pydantic import BaseModel, model_validator

# HACK: if Modality is not defined, then import it from modality.py
try:
    from ..modality import Modality
except Exception:
    # if not, super hack to get the leaderboard working.
    # SHOULD MATCH the code exactly in modality.py
    # can we read the file and run that code?
    from enum import Enum

    class Modality(Enum):
        """Data modality, either DNA or protein sequence."""

        PROTEIN = "protein"
        DNA = "dna"


logging.basicConfig(level=logging.INFO)

# Closed set of task kinds accepted by `TaskMetadata.type`.
TaskType = Literal[
    "classification",
    "pair_classification",
    "clustering",
    "eds",
    "bigene_mining",
    "retrieval",
]


class TaskMetric(BaseModel):
    """A single named metric value."""

    id: str
    display_name: str
    description: Optional[str] = None
    value: float = 0.0


class LayerResult(BaseModel):
    """The metrics recorded for one layer."""

    layer_number: int
    layer_display_name: str
    metrics: List[TaskMetric]


class DGEBModel(BaseModel):
    """Descriptive fields for a model, stored alongside its results."""

    hf_name: str
    num_layers: int
    num_params: int
    embed_dim: int


class Dataset(BaseModel):
    """A dataset reference: a path plus a revision."""

    path: str
    revision: str

    def load(self) -> datasets.DatasetDict:
        """Load the dataset with `datasets.load_dataset()` at `self.revision`.

        Returns:
            The loaded `datasets.DatasetDict`.

        Raises:
            ValueError: If the loaded object is not a `datasets.DatasetDict`.
        """
        ds = datasets.load_dataset(self.path, revision=self.revision)
        if not isinstance(ds, datasets.DatasetDict):
            raise ValueError(
                f"Dataset {self.path} is not a datasets.DatasetDict object."
            )
        return ds


class TaskMetadata(BaseModel):
    """Static description of a task."""

    id: str
    display_name: str
    description: str
    modality: Modality
    type: TaskType
    # List of datasets used by the task.
    # Each entry holds the `path` and `revision` that `Dataset.load()` passes
    # to `datasets.load_dataset()`.
    datasets: List[Dataset]
    # Must match the `id` of a metric in every layer of a `TaskResult`
    # (enforced by `TaskResult.check_valid_primary_metric`).
    primary_metric_id: str


# tasks.py
class TaskResult(BaseModel):
    """Per-layer results of one task for one model."""

    dgeb_version: str
    task: "TaskMetadata"
    # TODO: Convert model to ModelMetadata
    model: DGEBModel
    results: List[LayerResult]

    @model_validator(mode="after")
    def check_valid_primary_metric(self):
        """Ensure every layer result contains the task's primary metric.

        Raises:
            ValueError: If any entry of `results` has no metric whose `id`
                equals `task.primary_metric_id` (this includes a layer with
                an empty metric list).
        """
        primary_id = self.task.primary_metric_id
        for result in self.results:
            if not any(metric.id == primary_id for metric in result.metrics):
                raise ValueError(
                    f"Primary metric {primary_id} not found in results.metrics"
                )
        return self

    @staticmethod
    def from_dict(
        task_metadata: "TaskMetadata",
        layer_results: Dict[str, Any],
        model_metadata: DGEBModel,
    ) -> "TaskResult":
        """Build a `TaskResult` from a plain mapping of per-layer metrics.

        Args:
            task_metadata: Metadata of the task the results belong to.
            layer_results: Mapping with a ``"layers"`` key. Its value maps
                each layer identifier (anything `int()` accepts) to a
                ``{metric id: value}`` mapping.
            model_metadata: Metadata of the model that produced the results.

        Returns:
            A `TaskResult` stamped with the installed ``dgeb`` version. Each
            metric's `display_name` is set to its id, and each layer's
            `layer_display_name` is the string form of its identifier.
        """
        return TaskResult(
            dgeb_version=version("dgeb"),
            task=task_metadata,
            model=model_metadata,
            results=[
                LayerResult(
                    layer_number=int(layer),
                    layer_display_name=str(layer),
                    metrics=[
                        TaskMetric(id=metric, display_name=metric, value=value)
                        for metric, value in metrics.items()
                    ],
                )
                for layer, metrics in layer_results["layers"].items()
            ],
        )


# move to model.py?
class Task(ABC):
    """Abstract base class for tasks; subclasses implement `run`."""

    metadata: TaskMetadata

    # using Any instead of "BioSeqTransformer" to avoid installing all deps in leaderboard
    @abstractmethod
    def run(self, model: Any, layers: Optional[List[int]] = None) -> TaskResult:
        """Run the task with `model` and return its `TaskResult`.

        Args:
            model: The model to run the task with (typed `Any`; see the
                comment above this method).
            layers: Optional list of layer numbers; defaults to None.
        """
        pass