|
import colorsys |
|
import json |
|
import os |
|
import random |
|
from concurrent.futures import ThreadPoolExecutor |
|
from dataclasses import dataclass, make_dataclass |
|
from datetime import datetime |
|
from io import BytesIO |
|
|
|
import aiohttp |
|
import evaluate |
|
import numpy as np |
|
import pandas as pd |
|
import plotly.graph_objects as go |
|
from huggingface_hub import hf_hub_download, list_repo_files |
|
from pydub import AudioSegment |
|
|
|
from constants import WHISPER_OPEN_AI_LINK |
|
|
|
|
|
# Word Error Rate metric object, loaded once when this module is imported and
# reused by compute_average_wer for every WER computation.
wer_metric = evaluate.load("wer")
|
|
|
|
|
def compute_average_wer(results):
    """
    Compute the Word Error Rate (WER) for a list of transcription results.

    :param results: List of dictionaries, each containing 'reference' and 'prediction' keys
    :return: WER as a percentage, rounded to 2 decimal places

    All reference/prediction pairs are scored in a single call to the
    module-level ``wer_metric`` and the result is scaled to a percentage.
    If no predictions are provided, 100.0 (i.e. 100% WER) is returned so the
    empty case is on the same percentage scale as every other return value.
    """
    references = [result["reference"] for result in results]
    predictions = [result["prediction"] for result in results]
    if not predictions:
        # Nothing was transcribed: report the worst score on the percentage scale.
        return 100.0
    return round(
        wer_metric.compute(references=references, predictions=predictions) * 100.0,
        2,
    )
|
|
|
|
|
def read_json_line_by_line(file_path):
    """
    Read a JSON-lines file, parsing each non-empty line as a separate JSON value.

    :param file_path: Path to the JSON-lines file
    :return: List of parsed JSON objects, in file order

    The file is decoded as UTF-8 so the result does not depend on the
    platform's default locale. Blank lines are skipped silently; lines that
    are not valid JSON are reported with ``print`` and skipped.
    """
    data = []
    with open(file_path, "r", encoding="utf-8") as f:
        for line in f:
            stripped = line.strip()
            if not stripped:
                # An empty line is not an error, just padding between records.
                continue
            try:
                data.append(json.loads(stripped))
            except json.JSONDecodeError:
                print(f"Skipping invalid JSON in {file_path}: {line}")
    return data
|
|
|
|
|
def group_wer(group):
    """
    Compute the WER of one DataFrame group of transcriptions.

    :param group: DataFrame (or groupby group) with 'normalized_reference' and
        'normalized_prediction' columns
    :return: The value returned by compute_average_wer for the group's rows

    The two normalized columns are renamed to 'reference' / 'prediction' and
    converted to a list of per-row dictionaries, which is the input format
    compute_average_wer expects.
    """
    column_map = {
        "normalized_reference": "reference",
        "normalized_prediction": "prediction",
    }
    pairs = group[list(column_map)].rename(columns=column_map)
    return compute_average_wer(pairs.to_dict("records"))
|
|
|
|
|
def load_multilingual_results(csv_file):
    """
    Load multilingual results into a pandas DataFrame.

    :param csv_file: Path to the CSV file containing multilingual results.
        For backward compatibility, already-loaded records (a dict or a list
        of dicts) are also accepted and flattened with ``pd.json_normalize``.
    :return: DataFrame with the loaded results, or None if the file is not found

    ``pd.json_normalize`` rejects plain strings, so a file path is read with
    ``pd.read_csv`` instead; only in-memory records go through json_normalize.
    """
    try:
        if isinstance(csv_file, (str, os.PathLike)):
            return pd.read_csv(csv_file)
        return pd.json_normalize(csv_file)
    except FileNotFoundError:
        return None
|
|
|
|
|
def download_dataset(repo_id, local_dir, remote_dir, path_includes=""):
    """
    Download benchmark result files from a Hugging Face dataset repository.

    :param repo_id: ID of the Hugging Face repository
    :param local_dir: Local directory where downloaded files will be saved
    :param remote_dir: Only files whose repository path starts with this prefix are downloaded
    :param path_includes: Optional substring that must also appear in the file path
        (default: "" which matches every file)

    The repository file list is filtered, then the matching files are fetched
    concurrently with ``force_download=True`` so local copies are refreshed.
    Downloads are best-effort: a failed file is reported with ``print`` and
    the remaining files are still downloaded.
    """
    files = list_repo_files(repo_id, repo_type="dataset")
    directory_files = [
        file for file in files if file.startswith(remote_dir) and path_includes in file
    ]

    def _download(filename):
        return hf_hub_download(
            repo_id=repo_id,
            repo_type="dataset",
            filename=filename,
            local_dir=local_dir,
            force_download=True,
        )

    with ThreadPoolExecutor() as executor:
        futures = {
            executor.submit(_download, filename): filename
            for filename in directory_files
        }
        for future, filename in futures.items():
            # exception() waits for the task; without this check a failed
            # download would disappear without any trace.
            error = future.exception()
            if error is not None:
                print(f"Failed to download {filename} from {repo_id}: {error}")
|
|
|
|
|
def process_file(file_path):
    """
    Process a file containing JSON objects delimited by new lines.

    :param file_path: Path to the file to be processed
    :return: List of parsed JSON objects, in file order

    The file is decoded as UTF-8 so parsing does not depend on the platform's
    default locale. Blank lines are skipped; lines that fail to parse are
    reported with ``print`` (line content and decoder message) and skipped.
    """
    data = []
    with open(file_path, "r", encoding="utf-8") as file:
        for line in file:
            line = line.strip()
            if not line:
                continue
            try:
                data.append(json.loads(line))
            except json.JSONDecodeError as e:
                print(f"Error decoding JSON in line: {line}")
                print(f"Error message: {str(e)}")
    return data
|
|
|
|
|
def dir_to_json(root_dir, output_file):
    """
    Flatten a directory tree of benchmark result files into one JSON-lines file.

    :param root_dir: Root directory containing the benchmark result files
    :param output_file: Output file; one JSON object is written per line

    Every file under root_dir is visited, except paths ending in ".DS_Store"
    or containing "summary". Metadata is taken from fixed positions of the
    file path split on os.sep: index 2 is the model, 3 the device, 4 the OS,
    5 the dataset, and 6 the file name, which must have the form
    "<timestamp>_<commit hash>_<commit timestamp>.json". Each record parsed by
    process_file is written together with that metadata.
    """
    with open(output_file, "w") as sink:
        for current_dir, _, filenames in os.walk(root_dir):
            for filename in filenames:
                path = os.path.join(current_dir, filename)
                if path.endswith(".DS_Store") or "summary" in path:
                    continue

                # Metadata is positional: the path layout fixes which
                # component carries which piece of information.
                segments = path.split(os.sep)
                model_version = segments[2]
                device_name = segments[3].replace("_", " ")
                os_type_version = segments[4]
                dataset_name = segments[5]
                stem = segments[6].replace(".json", "")
                run_timestamp, commit_hash, commit_timestamp = stem.split("_")

                model_name = model_version.replace("_", "/")
                os_name = os_type_version.replace("_", " ")
                # Turn every "-" into ":" and then restore the first two,
                # which leaves the date dashes and colon-separates the rest.
                display_timestamp = run_timestamp.replace("-", ":").replace(
                    ":", "-", 2
                )

                for record in process_file(path):
                    entry = {
                        "model": model_name,
                        "device": device_name,
                        "os": os_name,
                        "wer": record["wer"],
                        "dataset_name": dataset_name,
                        "reference_transcription": record["reference_transcription"],
                        "prediction_transcription": record["prediction_transcription"],
                        "difference_transcription": record["difference_transcription"],
                        "audio_file_url": record["audio_file_url"],
                        "timestamp": display_timestamp,
                        "commit_hash": commit_hash,
                        "commit_timestamp": commit_timestamp,
                    }
                    sink.write(json.dumps(entry) + "\n")
|
|
|
|
|
async def download_audio_to_ndarray(url):
    """
    Download an MP3 file from a URL and return it as a NumPy array.

    :param url: The URL of the audio file to download
    :return: Tuple (frame_rate, samples); (None, None) when the HTTP status is not 200

    The response body is decoded as MP3 through pydub's AudioSegment and its
    samples are copied into a NumPy array. For two-channel audio the array is
    reshaped to two columns, one per channel.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as response:
            if response.status != 200:
                return None, None

            payload = await response.read()
            segment = AudioSegment.from_file(BytesIO(payload), format="mp3")
            samples = np.array(segment.get_array_of_samples())
            if segment.channels == 2:
                samples = samples.reshape((-1, 2))
            return segment.frame_rate, samples
|
|
|
|
|
async def play_audio(url):
    """
    Fetch audio from a URL in the form expected by an audio player widget.

    :param url: The URL of the audio file to play
    :return: Tuple (sample_rate, audio_data) on success, otherwise the string
        "Error downloading the file"

    Delegates the download and decoding to download_audio_to_ndarray and
    translates its "no data" result into an error message.
    """
    sample_rate, audio_data = await download_audio_to_ndarray(url)
    if audio_data is not None:
        return sample_rate, audio_data
    return "Error downloading the file"
|
|
|
|
|
def get_filter_cond(df, model, device, os, dataset, timestamp=None):
    """
    Build a boolean row mask selecting one model/device/OS/dataset combination.

    :param df: DataFrame with 'model', 'device', 'os', 'dataset_name' columns
        (and 'timestamp' when a timestamp is given)
    :param model: Model name to match
    :param device: Device name to match
    :param os: OS name to match
    :param dataset: Dataset name to match
    :param timestamp: Optional timestamp to match; ignored when falsy (default: None)
    :return: Boolean mask aligned with df
    """
    mask = df["model"] == model
    mask = mask & (df["device"] == device)
    mask = mask & (df["os"] == os)
    mask = mask & (df["dataset_name"] == dataset)
    if timestamp:
        return mask & (df["timestamp"] == timestamp)
    return mask
|
|
|
|
|
def get_filtered_transcript(df, model, device, os, dataset, timestamp):
    """
    Return the transcription columns for one benchmark configuration.

    :param df: DataFrame containing the transcription data
    :param model: Model name to match
    :param device: Device name to match
    :param os: OS name to match
    :param dataset: Dataset name to match
    :param timestamp: Timestamp to match
    :return: DataFrame of the matching rows, restricted to the reference,
        prediction and difference transcriptions plus the audio file URL
    """
    transcript_columns = [
        "reference_transcription",
        "prediction_transcription",
        "difference_transcription",
        "audio_file_url",
    ]
    mask = get_filter_cond(df, model, device, os, dataset, timestamp)
    matching_rows = df[mask]
    return matching_rows[transcript_columns]
|
|
|
|
|
def get_filtered_timestamps(df, model, device, os, dataset):
    """
    List the distinct timestamps recorded for one benchmark configuration.

    :param df: DataFrame containing the transcription data
    :param model: Model name to match
    :param device: Device name to match
    :param os: OS name to match
    :param dataset: Dataset name to match
    :return: Single-column ('timestamp') DataFrame with duplicate rows removed
    """
    mask = get_filter_cond(df, model, device, os, dataset)
    timestamps = df[mask][["timestamp"]]
    return timestamps.drop_duplicates()
|
|
|
|
|
def make_model_name_clickable_link(model):
    """
    Build an HTML anchor pointing at the model's folder on Hugging Face.

    :param model: Model name; any "/" is replaced by "_" in the link target
    :return: HTML string with the model name as link text, opening in a new tab
    """
    folder = model.replace("/", "_")
    link_style = (
        "color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;"
    )
    target_url = (
        "https://huggingface.co/argmaxinc/whisperkit-coreml/tree/main/" + folder
    )
    return f'<a style="{link_style}" href="{target_url}" target="_blank">{model}</a>'
|
|
|
|
|
def make_dataset_wer_clickable_link(row, dataset):
    """
    Wrap a dataset's WER value from a table row in an HTML link.

    :param row: Mapping-like row with a 'Model' entry and an entry keyed by the dataset name
    :param dataset: Dataset name; used both as the row key and in the link target
    :return: HTML anchor whose text is the row's value for the dataset

    The link target is WHISPER_OPEN_AI_LINK formatted with the model name
    (with "/" replaced by "_") and the dataset name.
    """
    model_slug = row["Model"].replace("/", "_")
    href = WHISPER_OPEN_AI_LINK.format(model_slug, dataset)
    cell_value = row[f"{dataset}"]
    return f'<a style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" href="{href}">{cell_value}</a>'
|
|
|
|
|
def make_timestamp_clickable_link(model, dataset, timestamp):
    """
    Build a clickable HTML div that displays a timestamp.

    :param model: Model name
    :param dataset: Dataset name
    :param timestamp: Timestamp shown as the div's text
    :return: HTML string for a div whose onclick handler clicks another page element

    The target element id is "<dataset>-<model>-<timestamp>" with spaces
    turned into underscores and double quotes, single quotes and commas removed.
    """
    # One translation pass: space -> "_", and drop the characters " ' ,
    sanitizer = str.maketrans(" ", "_", "\"',")
    elem_id = f"{dataset}-{model}-{timestamp}".translate(sanitizer)
    onclick = f"onclick=\"document.getElementById('{elem_id}').click();\""
    return (
        '<div style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" '
        f'{onclick} href="#">{timestamp}</div>'
    )
|
|
|
|
|
def make_multilingual_model_clickable_link(model):
    """
    Build a clickable HTML div that displays a multilingual model name.

    :param model: Model name shown as the div's text
    :return: HTML string for a div whose onclick handler clicks another page element

    The target element id is the model name with spaces turned into
    underscores and double quotes, single quotes and commas removed. The
    handler only forwards the click; it does not log anything to the
    browser console.
    """
    elem_id = (
        f"{model}".replace(" ", "_").replace('"', "").replace("'", "").replace(",", "")
    )
    onclick = f"onclick=\"document.getElementById('{elem_id}').click();\""
    return f'<div style="color: #3B82F6; text-decoration: underline; text-decoration-style: dotted;" {onclick} href="#">{model}</div>'
|
|
|
|
|
def plot_metric(
    df, y_axis_col, y_axis_title, fig_title, filter_input=None, exclude_input=None
):
    """
    Plot one line per model-device-OS combination against commit timestamp.

    :param df: DataFrame with 'model', 'device', 'os', 'commit_timestamp' columns
        and the column named by y_axis_col
    :param y_axis_col: DataFrame column to use as the y-axis
    :param y_axis_title: Display name for the y-axis
    :param fig_title: Display title for the figure
    :param filter_input: Optional ";"-separated terms; a combination is kept when its
        "model-device-os" label contains any term (case-insensitive)
    :param exclude_input: Optional ";"-separated terms; a combination is dropped when its
        label contains any term (case-insensitive)
    :return: A Plotly figure object

    Commit timestamps are expected in the "%Y-%m-%dT%H%M%S" format and are
    re-rendered as "%Y-%m-%d %H:%M:%S" on the x-axis. Line colors come from
    generate_random_colors.
    """

    def _label(frame):
        # All rows of a group share model/device/os, so the first row names it.
        return f"{frame['model'].iloc[0]}-{frame['device'].iloc[0]}-{frame['os'].iloc[0]}"

    def _terms(raw):
        return [term.strip().lower() for term in raw.split(";")]

    def _contains_any(label, terms):
        lowered = label.lower()
        return any(term in lowered for term in terms)

    def _pretty_timestamp(raw):
        parsed = datetime.strptime(raw, "%Y-%m-%dT%H%M%S")
        return parsed.strftime("%Y-%m-%d %H:%M:%S")

    series = []
    for _, frame in df.groupby(["model", "device", "os"]):
        ordered = frame.sort_values("commit_timestamp")
        series.append((_label(ordered), ordered))

    if filter_input:
        wanted = _terms(filter_input)
        series = [item for item in series if _contains_any(item[0], wanted)]

    if exclude_input:
        unwanted = _terms(exclude_input)
        series = [item for item in series if not _contains_any(item[0], unwanted)]

    base_colors = ["#4542f4", "#0e0c06", "#ccf0a7", "#ff7f4e", "#ffd15a"]
    random_colors = generate_random_colors(base_colors, len(series))

    fig = go.Figure()
    for position, (model_device_os, frame) in enumerate(series):
        color = random_colors[position % len(random_colors)]
        hover_text = (
            f"<b>{model_device_os}</b><br>"
            "Timestamp: %{x}<br>"
            f"{y_axis_title}: %{{y:.2f}}<br>"
            "<extra></extra>"
        )
        trace = go.Scatter(
            x=frame["commit_timestamp"].apply(_pretty_timestamp),
            y=frame[y_axis_col],
            mode="lines+markers",
            name=model_device_os,
            line={"color": color},
            marker={"color": color},
            hovertemplate=hover_text,
        )
        fig.add_trace(trace)

    fig.update_layout(
        title=fig_title,
        xaxis_title="Commit Timestamp",
        yaxis_title=y_axis_title,
        legend_title="Model-Device-OS",
        width=1100,
        height=600,
        plot_bgcolor="rgb(250,249,244)",
    )
    return fig
|
|
|
|
|
def fields(raw_class):
    """
    Collect the values of a class's non-dunder attributes.

    :param raw_class: The class to inspect
    :return: List of attribute values from the class's own ``__dict__``, in
        definition order

    An attribute is skipped when its name starts with "__" or ends with "__".
    For a dataclass whose fields have defaults, the returned values are those
    default objects.
    """
    members = []
    for attr_name, attr_value in raw_class.__dict__.items():
        if attr_name.startswith("__") or attr_name.endswith("__"):
            continue
        members.append(attr_value)
    return members
|
|
|
|
|
def get_os_name_and_version(os_string):
    """
    Reduce an "<name> <version>" string to the name and major version.

    :param os_string: OS name and version separated by whitespace, e.g. "macOS 14.5.1"
    :return: "<name> <major>", where <major> is the version text before the first "."
    :raises ValueError: if os_string does not consist of exactly two
        whitespace-separated parts
    """
    os_name, full_version = os_string.split()
    major_version, _, _ = full_version.partition(".")
    return " ".join((os_name, major_version))
|
|
|
|
|
def create_initial_quality_column_dict():
    """
    Create the base column specification for the quality table.

    :return: List of [field_name, ColumnContent, default] entries for the
        model, average WER and QoI columns

    Each entry has the three-item shape accepted by dataclasses.make_dataclass;
    the model column is marked never_hidden.
    """
    column_defaults = (
        ("model", ColumnContent("Model", "html", True, never_hidden=True)),
        ("average_wer", ColumnContent("Average WER", "html", True)),
        ("qoi", ColumnContent("QoI", "html", True)),
    )
    return [
        [field_name, ColumnContent, default] for field_name, default in column_defaults
    ]
|
|
|
|
|
def calculate_parity(m2_ultra_wer, row):
    """
    Compute the WER gap between the M2 Ultra reference and the current row.

    :param m2_ultra_wer: WER values indexed by model name
        (presumably a pandas Series; verify against caller)
    :param row: Row with 'Model' and 'Average WER' entries
    :return: M2 Ultra WER minus the row's average WER, rounded to 2 decimal
        places, or None when the row's model has no M2 Ultra entry

    The result is an absolute difference in WER, not a ratio.
    """
    model = row["Model"]
    if model not in m2_ultra_wer.index:
        return None
    return round(m2_ultra_wer[model] - row["Average WER"], 2)
|
|
|
|
|
def create_initial_performance_column_dict():
    """
    Create the base column specification for the performance table.

    :return: List of [field_name, ColumnContent, default] entries for the
        model, device, OS, average WER, QoI, speed and tokens-per-second columns

    Each entry has the three-item shape accepted by dataclasses.make_dataclass.
    Model, device and OS are never hidden; QoI, speed and Tok / s are not
    displayed by default.
    """
    column_defaults = (
        ("model", ColumnContent("Model", "html", True, never_hidden=True)),
        ("device", ColumnContent("Device", "html", True, never_hidden=True)),
        ("os", ColumnContent("OS", "html", True, never_hidden=True)),
        ("average_wer", ColumnContent("Average WER", "html", True)),
        ("qoi", ColumnContent("QoI", "html", False)),
        ("speed", ColumnContent("Speed", "html", False)),
        ("toks", ColumnContent("Tok / s", "html", False)),
    )
    return [
        [field_name, ColumnContent, default] for field_name, default in column_defaults
    ]
|
|
|
|
|
def add_datasets_to_quality_columns(column_dict, datasets):
    """
    Extend the quality table columns with one column per dataset.

    :param column_dict: The initial column specification (left unmodified; a
        shallow copy is extended)
    :param datasets: Dataset names to add; "-" is stripped to form the field name
    :return: Dictionary with the extended specification ("column_dict"), the
        generated frozen dataclass ("AutoEvalColumn") and the derived name/type
        lists "COLS", "TYPES", "ALWAYS_HERE_COLS", "TOGGLE_COLS", "SELECTED_COLS"
    """
    updated_column_dict = column_dict.copy()
    updated_column_dict.extend(
        [dataset.replace("-", ""), ColumnContent, ColumnContent(dataset, "html", True)]
        for dataset in datasets
    )

    AutoEvalColumn = make_dataclass("AutoEvalColumn", updated_column_dict, frozen=True)

    # The class attributes of the generated dataclass are the ColumnContent defaults.
    columns = fields(AutoEvalColumn)
    visible = [column for column in columns if not column.hidden]
    toggleable = [column for column in columns if not column.never_hidden]

    return {
        "column_dict": updated_column_dict,
        "AutoEvalColumn": AutoEvalColumn,
        "COLS": [column.name for column in visible],
        "TYPES": [column.type for column in visible],
        "ALWAYS_HERE_COLS": [column.name for column in columns if column.never_hidden],
        "TOGGLE_COLS": [column.name for column in toggleable],
        "SELECTED_COLS": [
            column.name for column in toggleable if column.displayed_by_default
        ],
    }
|
|
|
|
|
def add_datasets_to_performance_columns(column_dict, datasets):
    """
    Extend the performance table columns with speed and Tok/s columns per dataset.

    :param column_dict: The initial column specification (left unmodified; a
        shallow copy is extended)
    :param datasets: Dataset names to add; "-" is stripped to form the field name prefix
    :return: Dictionary with the extended specification ("column_dict"), the
        generated frozen dataclass ("AutoEvalColumn") and the derived name/type
        lists "COLS", "TYPES", "ALWAYS_HERE_COLS", "TOGGLE_COLS", "SELECTED_COLS"

    For every dataset a "<field>_speed" and a "<field>_toks" column are added.
    Their display names start with "Short-Form" for 'librispeech-10mins' and
    with "Long-Form" for any other dataset.
    """
    updated_column_dict = column_dict.copy()

    for dataset in datasets:
        field_name = dataset.replace("-", "")
        form = "Short-Form" if dataset == "librispeech-10mins" else "Long-Form"
        for suffix, unit in (("speed", "Speed"), ("toks", "Tok/s")):
            updated_column_dict.append(
                [
                    f"{field_name}_{suffix}",
                    ColumnContent,
                    ColumnContent(f"{form} {unit}", "html", True),
                ]
            )

    AutoEvalColumn = make_dataclass("AutoEvalColumn", updated_column_dict, frozen=True)

    # The class attributes of the generated dataclass are the ColumnContent defaults.
    columns = fields(AutoEvalColumn)
    visible = [column for column in columns if not column.hidden]
    toggleable = [column for column in columns if not column.never_hidden]

    return {
        "column_dict": updated_column_dict,
        "AutoEvalColumn": AutoEvalColumn,
        "COLS": [column.name for column in visible],
        "TYPES": [column.type for column in visible],
        "ALWAYS_HERE_COLS": [column.name for column in columns if column.never_hidden],
        "TOGGLE_COLS": [column.name for column in toggleable],
        "SELECTED_COLS": [
            column.name for column in toggleable if column.displayed_by_default
        ],
    }
|
|
|
|
|
def create_confusion_matrix_plot(matrix, labels, is_forced):
    """
    Render a language-detection confusion matrix as a Plotly heatmap.

    :param matrix: Confusion matrix values passed to the heatmap as z
    :param labels: Language labels used for both axes
    :param is_forced: When truthy the title reads "... with Language Hint",
        otherwise "... with Language Prediction by Model"
    :return: A 600x600 Plotly figure with predicted language on the x-axis and
        true language on the y-axis
    """
    title_suffix = "Language Hint" if is_forced else "Language Prediction by Model"
    color_stops = [
        [0, "rgb(250,249,244)"],
        [0.5, "rgb(69,66,244)"],
        [1.0, "rgb(14,12,6)"],
    ]
    heatmap = go.Heatmap(
        z=matrix,
        x=labels,
        y=labels,
        colorscale=color_stops,
        hoverongaps=False,
        hovertemplate="True: %{y}<br>Predicted: %{x}<br>Value: %{z}<extra></extra>",
    )

    fig = go.Figure(data=heatmap)
    fig.update_layout(
        title=f"Language Detection Confusion Matrix with {title_suffix}",
        xaxis_title="Predicted Language",
        yaxis_title="True Language",
        xaxis={"tickangle": -45},
        width=600,
        height=600,
        margin={"l": 50, "r": 50, "t": 50, "b": 50},
    )
    return fig
|
|
|
|
|
def hex_to_rgb(hex_color):
    """
    Convert a six-digit hexadecimal color code to an RGB tuple.

    :param hex_color: Color such as "#4542f4"; leading "#" characters are optional
    :return: Tuple of three integers (red, green, blue)
    """
    digits = hex_color.lstrip("#")
    channels = []
    for start in range(0, 6, 2):
        # Each channel is one pair of hex digits.
        channels.append(int(digits[start : start + 2], 16))
    return tuple(channels)
|
|
|
|
|
def rgb_to_hex(rgb):
    """
    Convert RGB values to a lowercase hexadecimal color code.

    :param rgb: Iterable whose first three items are the integer red, green
        and blue values
    :return: String of the form "#rrggbb", each channel zero-padded to two digits
    """
    channels = tuple(rgb)
    red, green, blue = channels[0], channels[1], channels[2]
    return f"#{red:02x}{green:02x}{blue:02x}"
|
|
|
|
|
def interpolate_colors(color1, color2, factor):
    """
    Blend two hexadecimal colors by linear interpolation in HSV space.

    :param color1: First color in hexadecimal format (returned for factor 0)
    :param color2: Second color in hexadecimal format (approached as factor nears 1)
    :param factor: Interpolation weight, normally between 0 and 1
    :return: Interpolated color in hexadecimal format

    Hue, saturation and value are each interpolated linearly; the hue is
    wrapped into [0, 1). Channel values are truncated, not rounded, when
    converted back to integers.
    """

    def _mix(start, end):
        return start + factor * (end - start)

    red1, green1, blue1 = (channel / 255.0 for channel in hex_to_rgb(color1))
    red2, green2, blue2 = (channel / 255.0 for channel in hex_to_rgb(color2))

    hue1, sat1, val1 = colorsys.rgb_to_hsv(red1, green1, blue1)
    hue2, sat2, val2 = colorsys.rgb_to_hsv(red2, green2, blue2)

    blended = colorsys.hsv_to_rgb(
        _mix(hue1, hue2) % 1.0,
        _mix(sat1, sat2),
        _mix(val1, val2),
    )
    return rgb_to_hex(tuple(int(channel * 255) for channel in blended))
|
|
|
|
|
def color_distance(color1, color2):
    """
    Measure how far apart two colors are in RGB space.

    :param color1: First color in hexadecimal format
    :param color2: Second color in hexadecimal format
    :return: Euclidean distance between the two RGB triples (0 for identical colors)
    """
    squared_total = 0
    for channel1, channel2 in zip(hex_to_rgb(color1), hex_to_rgb(color2)):
        squared_total += (channel1 - channel2) ** 2
    return squared_total**0.5
|
|
|
|
|
def generate_random_colors(base_colors, num_colors, min_distance=30):
    """
    Generate random colors by blending random pairs of base colors.

    :param base_colors: List of at least two base colors in hexadecimal format
        (two distinct entries are sampled per candidate)
    :param num_colors: Number of colors to generate
    :param min_distance: Minimum RGB distance a candidate should keep from
        every color already generated (default: 30)
    :return: List of generated colors in hexadecimal format

    A candidate is accepted when it is at least min_distance away from all
    accepted colors. After more than 100 consecutive rejections, each further
    rejected candidate is accepted anyway with probability 0.1, so the
    distance guarantee is best-effort. The loop gives up once 1000
    consecutive rejections are reached.
    """
    max_attempts = 1000
    palette = []
    rejections = 0

    while len(palette) < num_colors and rejections < max_attempts:
        start_color, end_color = random.sample(base_colors, 2)
        blend = random.random()
        candidate = interpolate_colors(start_color, end_color, blend)

        distinct = all(
            color_distance(candidate, existing) >= min_distance for existing in palette
        )
        if distinct:
            palette.append(candidate)
            rejections = 0
            continue

        rejections += 1
        # Escape hatch: once the palette is crowded, occasionally accept a
        # near-duplicate rather than keep rejecting.
        if rejections > 100 and random.random() < 0.1:
            palette.append(candidate)
            rejections = 0

    return palette
|
|
|
|
|
@dataclass
class Task:
    """
    Description of a single benchmark task.

    :param benchmark: Name of the benchmark
    :param metric: Name of the metric used for evaluation
    :param col_name: Name of the matching column in the results DataFrame
    """

    # Benchmark identifier
    benchmark: str
    # Evaluation metric identifier
    metric: str
    # Column name in the results table
    col_name: str
|
|
|
|
|
@dataclass(frozen=True)
class ColumnContent:
    """
    Immutable description of one column in a results table.

    :param name: Column name
    :param type: Data type of the column, as a string
    :param displayed_by_default: Whether the column is displayed by default
    :param hidden: Whether the column is hidden (default: False)
    :param never_hidden: Whether the column can never be hidden (default: False)
    :param dummy: Whether this is a dummy column (default: False)
    """

    # Required attributes
    name: str
    type: str
    displayed_by_default: bool
    # Optional flags, all off unless requested
    hidden: bool = False
    never_hidden: bool = False
    dummy: bool = False
|
|
|
|
|
# Stylesheet for the dashboard UI.
# The @import rules are placed first because CSS ignores an @import that
# follows any other rule (such as @font-face).
# NOTE(review): the Sora URL had been corrupted to "Sora:[email protected]" by
# an e-mail obfuscation filter; it is reconstructed here with the two weights
# this stylesheet actually uses (300 and 400) - confirm against the original.
css = """
@import url('https://fonts.googleapis.com/css2?family=Noto+Color+Emoji&display=swap');
@import url('https://fonts.googleapis.com/css2?family=Sora:wght@300;400&display=swap');

@font-face {
    font-family: 'Zwizz Regular';
    font-style: normal;
    font-weight: normal;
    src: local('Zwizz Regular'), url('static/Zwizz-Regular.woff') format('woff');
}
@font-face {
    font-family: 'Zwizz Medium';
    font-style: normal;
    font-weight: normal;
    src: local('Zwizz Medium'), url('static/Zwizz-Medium.woff') format('woff');
}
@font-face {
    font-family: 'Zwizz SemiBold';
    font-style: normal;
    font-weight: normal;
    src: local('Zwizz SemiBold'), url('static/Zwizz-SemiBold.woff') format('woff');
}

/* Typography Scale */
h1, .h1 {
    font-family: 'Sora', sans-serif;
    font-weight: 300;
    font-size: 2em;
    letter-spacing: -0.05em;
}
h2, .h2 {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}
h3, h4, h5, .h3, .h4, .h5 {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}
h6, .h6, pre, code, .monospace {
    font-family: 'IBM Plex Mono', monospace;
    font-weight: 400;
    letter-spacing: 0.01em;
}
/* Add strong tag styling */
strong, b {
    font-family: 'Zwizz SemiBold', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
    letter-spacing: -0.02em;
}
/* Global Zwizz styles */
:root {
    --zwizz-spacing: -0.02em;
}
/* All Gradio elements should have Zwizz spacing */
.gradio-container * {
    letter-spacing: var(--zwizz-spacing);
    line-height: 1.7;
}
/* UI Elements */
.tab-buttons button, #models-to-add-text, .gradio-button {
    font-family: 'Sora', sans-serif;
    font-weight: 400;
    letter-spacing: -0.05em;
}
/* Specific Table Styling */
table, .table, th, td {
    font-family: 'IBM Plex Mono', 'Noto Color Emoji', sans-serif, monospace !important;
    font-weight: 400;
    letter-spacing: 0.01em;
}
/* Technical/Code Elements */
.code-block, .technical-text {
    font-family: 'IBM Plex Mono', monospace;
    font-weight: 400;
    letter-spacing: 0.01em;
}
/* Additional Elements */
#methodology-text p, #methodology-text li, .markdown-text {
    font-family: 'Zwizz Regular', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
    font-size: 16px !important;
    letter-spacing: var(--zwizz-spacing);
    line-height: 1.7;
}
/* Font weight utilities */
.zwizz-medium {
    font-family: 'Zwizz Medium', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
}
.zwizz-semibold {
    font-family: 'Zwizz SemiBold', -apple-system, BlinkMacSystemFont, system-ui, sans-serif;
}
/* Maintaining Original Layout Rules */
.gradio-container {
    max-width: 95% !important;
}
/* Table Layouts */
.large-table,
.large-table .table-wrap,
#multilingual-model-table .table-wrap,
#lookup-table .table-wrap {
    height: 35em !important;
    overflow-y: scroll !important;
}
/* SVG Container Rules */
.svg-container,
.main-svg {
    width: 100% !important;
}
.large-table, .large-table .table-wrap, #multilingual-model-table .table-wrap, #lookup-table .table-wrap {
    height: 35em !important;
    overflow-y: scroll !important;
}
.left-side-table .table-wrap {
    height: 15em !important;
    overflow-y: scroll !important;
}
#average-wer-table .table-wrap {
    height: 8em !important;
    overflow-y: scroll !important;
}
#general-wer-table .table-wrap {
    height: 35em !important;
    overflow-y: scroll !important;
}
"""
|
|