# Copyright (c) 2020, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Speech Data Explorer: a Dash web app for browsing a JSON-lines speech manifest.

The app shows global statistics (duration, vocabulary, alphabet and, when a
prediction field is present, WER/CER/WMR/word accuracy), a paginated table of
utterances with audio playback and waveform/spectrogram plots, and, when
``--names_compared`` is given, a per-word comparison tool for two models.
"""

import argparse
import base64
import csv
import datetime
import difflib
import io
import json
import logging
import math
import operator
import os
import pickle
import random
import sys
from collections import defaultdict
from os.path import expanduser
from pathlib import Path

import dash
import dash_bootstrap_components as dbc
import diff_match_patch
import editdistance
import jiwer
import librosa
import numpy as np
import pandas as pd
import soundfile as sf
import tqdm
from dash import dash_table, dcc, html
from dash.dependencies import Input, Output, State
from dash.exceptions import PreventUpdate
from plotly import express as px
from plotly import graph_objects as go
from plotly.subplots import make_subplots

# number of items in a table per page
DATA_PAGE_SIZE = 10

# operators for filtering items; they are tried in this order, so operators made
# of two characters come before the single characters they contain
# ('>=' before '>', '!=' before '=')
filter_operators = {
    '>=': 'ge',
    '<=': 'le',
    '<': 'lt',
    '>': 'gt',
    '!=': 'ne',
    '=': 'eq',
    'contains ': 'contains',
}

comparison_mode = False


# parse table filter queries
def split_filter_part(filter_part):
    """Parse one Dash DataTable filter expression such as ``{count} >= 5``.

    Args:
        filter_part: a single expression (one ``&&``-separated piece of the table's
            ``filter_query``).

    Returns:
        ``(column_name, operator_name, value)`` where ``operator_name`` is a value of
        ``filter_operators`` and ``value`` is a float when the operand parses as a
        number, the unquoted string when it is quoted, or the raw string otherwise.
        ``[None, None, None]`` when no operator is found or the operand is empty.
    """
    for op in filter_operators:
        if op in filter_part:
            name_part, value_part = filter_part.split(op, 1)
            name = name_part[name_part.find('{') + 1 : name_part.rfind('}')]

            value_part = value_part.strip()
            if not value_part:
                # operator without operand, e.g. "{count} >": ignore the expression
                return [None] * 3
            v0 = value_part[0]
            if v0 == value_part[-1] and v0 in ("'", '"', '`'):
                value = value_part[1:-1].replace('\\' + v0, v0)
            else:
                try:
                    value = float(value_part)
                except ValueError:
                    value = value_part

            return name, filter_operators[op], value

    return [None] * 3


def _filter_and_sort(rows, sort_by, filter_query):
    """Apply a DataTable custom ``filter_query`` and single-column ``sort_by`` to a list of dicts.

    Rows lacking a filtered column are dropped by that filter; rows lacking the sort
    column are kept and placed after the sorted ones.
    """
    view = rows
    for filter_part in (filter_query or '').split(' && '):
        col_name, op, filter_value = split_filter_part(filter_part)
        if op in ('eq', 'ne', 'lt', 'le', 'gt', 'ge'):
            compare = getattr(operator, op)
            view = [x for x in view if col_name in x and compare(x[col_name], filter_value)]
        elif op == 'contains':
            view = [x for x in view if col_name in x and filter_value in str(x[col_name])]

    if sort_by:
        col = sort_by[0]['column_id']
        descending = sort_by[0]['direction'] == 'desc'
        present = [x for x in view if col in x]
        missing = [x for x in view if col not in x]
        view = sorted(present, key=lambda x: x[col], reverse=descending) + missing
    return view


def _paginate(view, page_current):
    """Return ``[rows_of_page, page_count]`` for ``view``, clamping to the last non-empty page."""
    page_count = math.ceil(len(view) / DATA_PAGE_SIZE)
    page_current = min(page_current or 0, max(0, page_count - 1))
    return [
        view[page_current * DATA_PAGE_SIZE : (page_current + 1) * DATA_PAGE_SIZE],
        page_count,
    ]


# standard command-line arguments parser
def parse_args():
    """Parse the command-line arguments.

    Returns:
        ``(args, comparison_mode)``; ``comparison_mode`` is True when
        ``--names_compared`` was given. ``args.audio_base_path`` defaults to the
        directory containing the manifest.
    """
    parser = argparse.ArgumentParser(description='Speech Data Explorer')
    parser.add_argument(
        'manifest', help='path to JSON manifest file',
    )
    parser.add_argument('--vocab', help='optional vocabulary to highlight OOV words')
    parser.add_argument('--port', default='8050', help='serving port for establishing connection')
    parser.add_argument(
        '--disable-caching-metrics', action='store_true', help='disable caching metrics for errors analysis'
    )
    parser.add_argument(
        '--estimate-audio-metrics',
        '-a',
        action='store_true',
        help='estimate frequency bandwidth and signal level of audio recordings',
    )
    parser.add_argument(
        '--audio-base-path',
        default=None,
        type=str,
        help='A base path for the relative paths in manifest. It defaults to manifest path.',
    )
    parser.add_argument('--debug', '-d', action='store_true', help='enable debug mode')
    parser.add_argument(
        '--names_compared',
        '-nc',
        nargs=2,
        type=str,
        help='names of the two fields that will be compared, example: pred_text_contextnet pred_text_conformer. "pred_text_" prefix IS IMPORTANT!',
    )
    parser.add_argument(
        '--show_statistics',
        '-shst',
        type=str,
        help='field name for which you want to see statistics (optional). Example: pred_text_contextnet.',
    )
    args = parser.parse_args()

    # assume audio_filepath is relative to the directory where the manifest is stored
    if args.audio_base_path is None:
        args.audio_base_path = os.path.dirname(args.manifest)

    # automatically switch to comparison mode if names_compared is given
    comparison_mode = args.names_compared is not None

    print(args, comparison_mode)
    return args, comparison_mode


# estimate frequency bandwidth of signal
def eval_bandwidth(signal, sr, threshold=-50):
    """Estimate the bandwidth of ``signal`` in Hz.

    The power spectrum is averaged over time (512-point STFT, 10 ms hop,
    Blackman-Harris window) and converted to dB relative to its maximum; the
    result is the frequency of the highest bin whose level exceeds ``threshold``
    dB, or 0 if no bin does.
    """
    time_stride = 0.01
    hop_length = int(sr * time_stride)
    n_fft = 512
    spectrogram = np.mean(
        np.abs(librosa.stft(y=signal, n_fft=n_fft, hop_length=hop_length, window='blackmanharris')) ** 2, axis=1
    )
    power_spectrum = librosa.power_to_db(S=spectrogram, ref=np.max, top_db=100)
    freqband = 0
    # scan from the highest frequency bin downwards
    for idx in range(len(power_spectrum) - 1, -1, -1):
        if power_spectrum[idx] > threshold:
            freqband = idx / n_fft * sr
            break
    return freqband


def _add_audio_metrics(item, audio_base_path):
    """Load the audio of a manifest ``item`` and store ``freq_bandwidth`` (Hz) and ``level_db`` (peak) in it."""
    filepath = absolute_audio_filepath(item['audio_filepath'], audio_base_path)
    signal, sr = librosa.load(path=filepath, sr=None)
    bw = eval_bandwidth(signal, sr)
    item['freq_bandwidth'] = int(bw)
    item['level_db'] = 20 * np.log10(np.max(np.abs(signal)))


# load data from JSON manifest file
def load_data(
    data_filename,
    disable_caching=False,
    estimate_audio=False,
    vocab=None,
    audio_base_path=None,
    comparison_mode=False,
    names=None,
    field_name='pred_text',
):
    """Load a JSON-lines manifest and compute dataset statistics and ASR metrics.

    Args:
        data_filename: path to the manifest, one JSON object per line with at least
            ``audio_filepath``, ``duration`` and ``text``.
        disable_caching: never read or write the ``.pkl`` metrics cache.
        estimate_audio: load every audio file to estimate bandwidth and peak level.
        vocab: optional external vocabulary file (one word per line, or TSV with the
            word in the first column) used to mark OOV words.
        audio_base_path: directory that relative ``audio_filepath`` values are
            resolved against.
        comparison_mode: additionally compute metrics for the two fields in ``names``.
        names: pair of prediction field names compared in comparison mode.
        field_name: prediction field used for the main statistics.

    Returns:
        ``(data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet,
        metrics_available)``. In comparison mode these 9 values are followed by the
        same 9 values for ``names[0]`` and then for ``names[1]`` (27 in total).

    Raises:
        ValueError: if ``comparison_mode`` is set and ``names`` is None.
    """
    name_1 = name_2 = None
    if comparison_mode:
        if names is None:
            logging.error('Please, specify names of compared models')
            raise ValueError('Please, specify names of compared models')
        name_1, name_2 = names

    # load external vocab (only used outside comparison mode)
    vocabulary_ext = None
    if not comparison_mode and vocab is not None:
        vocabulary_ext = set()
        with open(vocab, 'r', encoding='utf8') as f:
            for line in f:
                if '\t' in line:
                    # parse word from TSV file
                    word = line.split('\t')[0]
                else:
                    # assume each line contains just a single word
                    word = line.strip()
                vocabulary_ext.add(word)

    # The cache holds only single-model results, so it is bypassed in comparison
    # mode (a cache hit would return 9 values where the caller unpacks 27).
    use_cache = not comparison_mode and not disable_caching
    pickle_filename = None
    if use_cache:
        pickle_filename = data_filename.split('.json')[0]
        if field_name != 'pred_text':
            # keep metrics of different prediction fields in separate cache files
            pickle_filename += '_' + field_name
        json_mtime = datetime.datetime.fromtimestamp(os.path.getmtime(data_filename))
        timestamp = json_mtime.strftime('%Y%m%d_%H%M')
        pickle_filename += '_' + timestamp + '.pkl'
        if os.path.exists(pickle_filename):
            # NOTE(review): pickle.load can execute arbitrary code; the cache is
            # assumed to be a trusted file written by an earlier run of this tool.
            with open(pickle_filename, 'rb') as f:
                data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available = pickle.load(f)
            if vocabulary_ext is not None:
                for item in vocabulary_data:
                    item['OOV'] = item['word'] not in vocabulary_ext
            if estimate_audio:
                for item in data:
                    _add_audio_metrics(item, audio_base_path)
                with open(pickle_filename, 'wb') as f:
                    pickle.dump(
                        [data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available],
                        f,
                        pickle.HIGHEST_PROTOCOL,
                    )
            return data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available

    def append_data(data_filename, estimate_audio, field_name='pred_text'):
        """Scan the manifest once and accumulate statistics for prediction ``field_name``.

        Returns a dict with the per-utterance rows (``data``), accumulated word/char
        edit distances and reference lengths, total duration in seconds, vocabulary
        counts, alphabet and per-word match counts.
        """
        data = []
        wer_dist = 0.0
        wer_count = 0
        cer_dist = 0.0
        cer_count = 0
        wmr_count = 0
        num_seconds = 0
        vocabulary = defaultdict(lambda: 0)
        alphabet = set()
        match_vocab = defaultdict(lambda: 0)

        sm = difflib.SequenceMatcher()
        metrics_available = False
        with open(data_filename, 'r', encoding='utf8') as f:
            for line in tqdm.tqdm(f):
                if not line.strip():
                    # tolerate blank lines (e.g. a trailing newline)
                    continue
                item = json.loads(line)
                if not isinstance(item['text'], str):
                    item['text'] = ''
                num_chars = len(item['text'])
                orig = item['text'].split()
                num_words = len(orig)
                for word in orig:
                    vocabulary[word] += 1
                alphabet.update(item['text'])
                num_seconds += item['duration']

                # decided per item: some utterances may lack a prediction
                has_pred = field_name in item
                if has_pred:
                    metrics_available = True
                    pred_text = item[field_name]
                    pred = pred_text.split()
                    if orig:
                        measures = jiwer.compute_measures(item['text'], pred_text)
                    else:
                        # jiwer rejects empty references; every predicted word is an insertion
                        measures = {'hits': 0, 'substitutions': 0, 'deletions': 0, 'insertions': len(pred)}
                    word_dist = measures['substitutions'] + measures['insertions'] + measures['deletions']
                    char_dist = editdistance.eval(item['text'], pred_text)
                    wer_dist += word_dist
                    cer_dist += char_dist
                    wer_count += num_words
                    cer_count += num_chars

                    # count, per reference word, how often it is matched in the prediction
                    sm.set_seqs(orig, pred)
                    for m in sm.get_matching_blocks():
                        for word_idx in range(m[0], m[0] + m[2]):
                            match_vocab[orig[word_idx]] += 1
                    wmr_count += measures['hits']
                elif comparison_mode and field_name != 'pred_text' and field_name in (name_1, name_2):
                    logging.error(f"The .json file has no field with name: {field_name}")
                    sys.exit(1)

                row = {
                    'audio_filepath': item['audio_filepath'],
                    'duration': round(item['duration'], 2),
                    'num_words': num_words,
                    'num_chars': num_chars,
                    'word_rate': round(num_words / item['duration'], 2),
                    'char_rate': round(num_chars / item['duration'], 2),
                    'text': item['text'],
                }
                if has_pred:
                    row[field_name] = pred_text
                    # avoid division by zero for empty references
                    ref_words = num_words or 1e-9
                    ref_chars = num_chars or 1e-9
                    row['WER'] = round(word_dist / ref_words * 100.0, 2)
                    row['CER'] = round(char_dist / ref_chars * 100.0, 2)
                    row['WMR'] = round(measures['hits'] / ref_words * 100.0, 2)
                    row['I'] = measures['insertions']
                    row['D'] = measures['deletions']
                    row['D-I'] = measures['deletions'] - measures['insertions']
                if estimate_audio:
                    _add_audio_metrics(item, audio_base_path)
                # keep any extra manifest fields as additional columns
                for k in item:
                    if k not in row:
                        row[k] = item[k]
                data.append(row)

        return {
            'data': data,
            'vocabulary_data': [{'word': word, 'count': count} for word, count in vocabulary.items()],
            'metrics_available': metrics_available,
            'wer_dist': wer_dist,
            'wer_count': wer_count,
            'cer_dist': cer_dist,
            'cer_count': cer_count,
            'wmr_count': wmr_count,
            'num_hours': num_seconds / 3600.0,
            'vocabulary': vocabulary,
            'alphabet': alphabet,
            'match_vocab': match_vocab,
        }

    def percent(numerator, denominator):
        """Return ``numerator / denominator`` in percent, or 0 when the denominator is 0."""
        return numerator / denominator * 100.0 if denominator else 0.0

    def error_rates(stats):
        """Return corpus-level ``(wer, cer, wmr)`` in percent."""
        return (
            percent(stats['wer_dist'], stats['wer_count']),
            percent(stats['cer_dist'], stats['cer_count']),
            percent(stats['wmr_count'], stats['wer_count']),
        )

    def word_accuracy(stats, key):
        """Store each word's accuracy (percent, 1 decimal) under ``key``; return the mean accuracy."""
        acc_sum = 0.0
        for entry in stats['vocabulary_data']:
            w = entry['word']
            acc = stats['match_vocab'][w] / stats['vocabulary'][w] * 100.0
            acc_sum += acc
            entry[key] = round(acc, 1)
        return acc_sum / len(stats['vocabulary_data']) if stats['vocabulary_data'] else 0.0

    main = append_data(data_filename, estimate_audio, field_name=field_name)
    model_1 = model_2 = None
    if comparison_mode:
        model_1 = append_data(data_filename, estimate_audio, field_name=name_1)
        model_2 = append_data(data_filename, estimate_audio, field_name=name_2)

    data = main['data']
    vocabulary_data = main['vocabulary_data']
    alphabet = main['alphabet']
    metrics_available = main['metrics_available']
    num_hours = main['num_hours']

    if vocabulary_ext is not None:
        for item in vocabulary_data:
            item['OOV'] = item['word'] not in vocabulary_ext

    wer = cer = wmr = mwa = 0
    wer_1 = cer_1 = wmr_1 = mwa_1 = 0
    wer_2 = cer_2 = wmr_2 = mwa_2 = 0
    if metrics_available:
        wer, cer, wmr = error_rates(main)
    if comparison_mode and model_1['metrics_available'] and model_2['metrics_available']:
        wer_1, cer_1, wmr_1 = error_rates(model_1)
        wer_2, cer_2, wmr_2 = error_rates(model_2)
        mwa_1 = word_accuracy(model_1, 'accuracy_1')
        mwa_2 = word_accuracy(model_2, 'accuracy_2')
    if metrics_available or comparison_mode:
        mwa = word_accuracy(main, 'accuracy')

    if use_cache:
        with open(pickle_filename, 'wb') as f:
            pickle.dump(
                [data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available],
                f,
                pickle.HIGHEST_PROTOCOL,
            )

    if comparison_mode:
        return (
            data,
            wer,
            cer,
            wmr,
            mwa,
            num_hours,
            vocabulary_data,
            alphabet,
            metrics_available,
            model_1['data'],
            wer_1,
            cer_1,
            wmr_1,
            mwa_1,
            model_1['num_hours'],
            model_1['vocabulary_data'],
            model_1['alphabet'],
            model_1['metrics_available'],
            model_2['data'],
            wer_2,
            cer_2,
            wmr_2,
            mwa_2,
            model_2['num_hours'],
            model_2['vocabulary_data'],
            model_2['alphabet'],
            model_2['metrics_available'],
        )
    return data, wer, cer, wmr, mwa, num_hours, vocabulary_data, alphabet, metrics_available


# plot histogram of specified field in data list
def plot_histogram(data, key, label):
    """Return a 50-bin, log-y histogram of ``key`` over the rows of ``data`` that have it."""
    fig = px.histogram(
        data_frame=[item[key] for item in data if key in item],
        nbins=50,
        log_y=True,
        labels={'value': label},
        opacity=0.5,
        color_discrete_sequence=['green'],
        height=200,
    )
    fig.update_layout(showlegend=False, margin=dict(l=0, r=0, t=0, b=0, pad=0))
    return fig


def plot_word_accuracy(vocabulary_data):
    """Return a bar chart of words never / sometimes / always recognized (by their ``accuracy``)."""
    labels = ['Unrecognized', 'Sometimes recognized', 'Always recognized']
    counts = [0, 0, 0]
    for word in vocabulary_data:
        if word['accuracy'] == 0:
            counts[0] += 1
        elif word['accuracy'] < 100:
            counts[1] += 1
        else:
            counts[2] += 1
    total = sum(counts)
    colors = ['red', 'orange', 'green']
    fig = go.Figure(
        data=[
            go.Bar(
                x=labels,
                y=counts,
                marker_color=colors,
                text=['{:.2%}'.format(count / total if total else 0.0) for count in counts],
                textposition='auto',
            )
        ]
    )
    fig.update_layout(
        showlegend=False, margin=dict(l=0, r=0, t=0, b=0, pad=0), height=200, yaxis={'title_text': '#words'}
    )

    return fig


def absolute_audio_filepath(audio_filepath, audio_base_path):
    """Return the path to an audio file as a string.

    ``~`` is expanded first. If the result is neither an existing file nor an
    absolute path, it is taken to be relative to ``audio_base_path``.
    """
    audio_filepath = Path(expanduser(audio_filepath))
    if audio_filepath.is_file() or audio_filepath.is_absolute():
        return str(audio_filepath)
    return str(Path(expanduser(audio_base_path or '')) / audio_filepath)


# parse the CLI arguments
args, comparison_mode = parse_args()
if args.show_statistics is not None:
    fld_nm = args.show_statistics
else:
    fld_nm = 'pred_text'

# parse names of compared models, if any
if comparison_mode:
    name_1, name_2 = args.names_compared
    print(name_1, name_2)

print('Loading data...')
if not comparison_mode:
    data, wer, cer, wmr, mwa, num_hours, vocabulary, alphabet, metrics_available = load_data(
        args.manifest,
        args.disable_caching_metrics,
        args.estimate_audio_metrics,
        args.vocab,
        args.audio_base_path,
        comparison_mode,
        args.names_compared,
        field_name=fld_nm,
    )
else:
    (
        data,
        wer,
        cer,
        wmr,
        mwa,
        num_hours,
        vocabulary,
        alphabet,
        metrics_available,
        data_1,
        wer_1,
        cer_1,
        wmr_1,
        mwa_1,
        num_hours_1,
        vocabulary_1,
        alphabet_1,
        metrics_available_1,
        data_2,
        wer_2,
        cer_2,
        wmr_2,
        mwa_2,
        num_hours_2,
        vocabulary_2,
        alphabet_2,
        metrics_available_2,
    ) = load_data(
        args.manifest,
        args.disable_caching_metrics,
        args.estimate_audio_metrics,
        args.vocab,
        args.audio_base_path,
        comparison_mode,
        args.names_compared,
        field_name=fld_nm,
    )

# columns of the samples table / detail rows, fixed from the first utterance
data_keys = list(data[0])

print('Starting server...')
app = dash.Dash(
    __name__,
    suppress_callback_exceptions=True,
    external_stylesheets=[dbc.themes.BOOTSTRAP],
    title=os.path.basename(args.manifest),
)

figures_labels = {
    'duration': ['Duration', 'Duration, sec'],
    'num_words': ['Number of Words', '#words'],
    'num_chars': ['Number of Characters', '#chars'],
    'word_rate': ['Word Rate', '#words/sec'],
    'char_rate': ['Character Rate', '#chars/sec'],
    'WER': ['Word Error Rate', 'WER, %'],
    'CER': ['Character Error Rate', 'CER, %'],
    'WMR': ['Word Match Rate', 'WMR, %'],
    'I': ['# Insertions (I)', '#words'],
    'D': ['# Deletions (D)', '#words'],
    'D-I': ['# Deletions - # Insertions (D-I)', '#words'],
    'freq_bandwidth': ['Frequency Bandwidth', 'Bandwidth, Hz'],
    'level_db': ['Peak Level', 'Level, dB'],
}
# one histogram per numeric (non-bool) field of the first utterance
figures_hist = {}
for k in data[0]:
    val = data[0][k]
    if isinstance(val, (int, float)) and not isinstance(val, bool):
        if k in figures_labels:
            ylabel = figures_labels[k][0]
            xlabel = figures_labels[k][1]
        else:
            title = k.replace('_', ' ')
            title = title[0].upper() + title[1:].lower()
            ylabel = title
            xlabel = title
        figures_hist[k] = [ylabel + ' (per utterance)', plot_histogram(data, k, xlabel)]

if metrics_available:
    figure_word_acc = plot_word_accuracy(vocabulary)

stats_layout = [
    dbc.Row(dbc.Col(html.H5(children='Global Statistics'), class_name='text-secondary'), class_name='mt-3'),
    dbc.Row(
        [
            dbc.Col(html.Div('Number of hours', className='text-secondary'), width=3, class_name='border-end'),
            dbc.Col(html.Div('Number of utterances', className='text-secondary'), width=3, class_name='border-end'),
            dbc.Col(html.Div('Vocabulary size', className='text-secondary'), width=3, class_name='border-end'),
            dbc.Col(html.Div('Alphabet size', className='text-secondary'), width=3),
        ],
        class_name='bg-light mt-2 rounded-top border-top border-start border-end',
    ),
    dbc.Row(
        [
            dbc.Col(
                html.H5(
                    '{:.2f} hours'.format(num_hours),
                    className='text-center p-1',
                    style={'color': 'green', 'opacity': 0.7},
                ),
                width=3,
                class_name='border-end',
            ),
            dbc.Col(
                html.H5(len(data), className='text-center p-1', style={'color': 'green', 'opacity': 0.7}),
                width=3,
                class_name='border-end',
            ),
            dbc.Col(
                html.H5(
                    '{} words'.format(len(vocabulary)),
                    className='text-center p-1',
                    style={'color': 'green', 'opacity': 0.7},
                ),
                width=3,
                class_name='border-end',
            ),
            dbc.Col(
                html.H5(
                    '{} chars'.format(len(alphabet)),
                    className='text-center p-1',
                    style={'color': 'green', 'opacity': 0.7},
                ),
                width=3,
            ),
        ],
        class_name='bg-light rounded-bottom border-bottom border-start border-end',
    ),
]
if metrics_available:
    stats_layout += [
        dbc.Row(
            [
                dbc.Col(
                    html.Div('Word Error Rate (WER), %', className='text-secondary'), width=3, class_name='border-end'
                ),
                dbc.Col(
                    html.Div('Character Error Rate (CER), %', className='text-secondary'),
                    width=3,
                    class_name='border-end',
                ),
                dbc.Col(
                    html.Div('Word Match Rate (WMR), %', className='text-secondary'),
                    width=3,
                    class_name='border-end',
                ),
                dbc.Col(html.Div('Mean Word Accuracy, %', className='text-secondary'), width=3),
            ],
            class_name='bg-light mt-2 rounded-top border-top border-start border-end',
        ),
        dbc.Row(
            [
                dbc.Col(
                    html.H5(
                        '{:.2f}'.format(wer), className='text-center p-1', style={'color': 'green', 'opacity': 0.7},
                    ),
                    width=3,
                    class_name='border-end',
                ),
                dbc.Col(
                    html.H5(
                        '{:.2f}'.format(cer), className='text-center p-1', style={'color': 'green', 'opacity': 0.7}
                    ),
                    width=3,
                    class_name='border-end',
                ),
                dbc.Col(
                    html.H5(
                        '{:.2f}'.format(wmr), className='text-center p-1', style={'color': 'green', 'opacity': 0.7},
                    ),
                    width=3,
                    class_name='border-end',
                ),
                dbc.Col(
                    html.H5(
                        '{:.2f}'.format(mwa), className='text-center p-1', style={'color': 'green', 'opacity': 0.7},
                    ),
                    width=3,
                ),
            ],
            class_name='bg-light rounded-bottom border-bottom border-start border-end',
        ),
    ]
stats_layout += [
    dbc.Row(dbc.Col(html.H5(children='Alphabet'), class_name='text-secondary'), class_name='mt-3'),
    dbc.Row(dbc.Col(html.Div('{}'.format(sorted(alphabet))),), class_name='mt-2 bg-light font-monospace rounded border'),
]
for graph_idx, k in enumerate(figures_hist):
    stats_layout += [
        dbc.Row(dbc.Col(html.H5(figures_hist[k][0]), class_name='text-secondary'), class_name='mt-3'),
        # component ids must be unique within a page
        dbc.Row(dbc.Col(dcc.Graph(id=f'histogram-graph-{graph_idx}', figure=figures_hist[k][1]),),),
    ]

if metrics_available:
    stats_layout += [
        dbc.Row(dbc.Col(html.H5('Word accuracy distribution'), class_name='text-secondary'), class_name='mt-3'),
        dbc.Row(dbc.Col(dcc.Graph(id='word-acc-graph', figure=figure_word_acc),),),
    ]

wordstable_columns = [{'name': 'Word', 'id': 'word'}, {'name': 'Count', 'id': 'count'}]

if vocabulary and 'OOV' in vocabulary[0]:
    wordstable_columns.append({'name': 'OOV', 'id': 'OOV'})

if metrics_available:
    wordstable_columns.append({'name': 'Accuracy, %', 'id': 'accuracy'})

stats_layout += [
    dbc.Row(dbc.Col(html.H5('Vocabulary'), class_name='text-secondary'), class_name='mt-3'),
    dbc.Row(
        dbc.Col(
            dash_table.DataTable(
                id='wordstable',
                columns=wordstable_columns,
                filter_action='custom',
                filter_query='',
                sort_action='custom',
                sort_mode='single',
                page_action='custom',
                page_current=0,
                page_size=DATA_PAGE_SIZE,
                cell_selectable=False,
                page_count=math.ceil(len(vocabulary) / DATA_PAGE_SIZE),
                sort_by=[{'column_id': 'word', 'direction': 'asc'}],
                style_cell={'maxWidth': 0, 'textAlign': 'left'},
                style_header={'color': 'text-primary'},
                css=[{'selector': '.dash-filter--case', 'rule': 'display: none'},],
            ),
        ),
        class_name='m-2',
    ),
    dbc.Row(dbc.Col([html.Button('Download Vocabulary', id='btn_csv'), dcc.Download(id='download-vocab-csv'),]),),
]


@app.callback(
    Output('download-vocab-csv', 'data'),
    [Input('btn_csv', 'n_clicks'), State('wordstable', 'sort_by'), State('wordstable', 'filter_query')],
    prevent_initial_call=True,
)
def download_vocabulary(n_clicks, sort_by, filter_query):
    """Send the currently filtered and sorted vocabulary as ``sde_vocab.csv``."""
    vocabulary_view = _filter_and_sort(vocabulary, sort_by, filter_query)
    if vocabulary_view:
        header = list(vocabulary_view[0].keys())
    else:
        # empty view: still emit a header row
        header = [column['id'] for column in wordstable_columns]

    # Build the CSV in memory rather than in a shared file in the working
    # directory, which concurrent downloads could overwrite.
    with io.StringIO() as buf:
        writer = csv.writer(buf)
        writer.writerow(header)
        for item in vocabulary_view:
            writer.writerow([str(item[k]) for k in item])
        content = buf.getvalue()
    return dcc.send_string(content, 'sde_vocab.csv')


@app.callback(
    [Output('wordstable', 'data'), Output('wordstable', 'page_count')],
    [Input('wordstable', 'page_current'), Input('wordstable', 'sort_by'), Input('wordstable', 'filter_query')],
)
def update_wordstable(page_current, sort_by, filter_query):
    """Return the requested page of the filtered/sorted vocabulary and the page count."""
    vocabulary_view = _filter_and_sort(vocabulary, sort_by, filter_query)
    return _paginate(vocabulary_view, page_current)


samples_layout = [
    dbc.Row(dbc.Col(html.H5('Data'), class_name='text-secondary'), class_name='mt-3'),
    html.Hr(),
    dbc.Row(
        dbc.Col(
            dash_table.DataTable(
                id='datatable',
                columns=[{'name': k.replace('_', ' '), 'id': k, 'hideable': True} for k in data_keys],
                filter_action='custom',
                filter_query='',
                sort_action='custom',
                sort_mode='single',
                sort_by=[],
                row_selectable='single',
                selected_rows=[0],
                page_action='custom',
                page_current=0,
                page_size=DATA_PAGE_SIZE,
                page_count=math.ceil(len(data) / DATA_PAGE_SIZE),
                style_cell={'overflow': 'hidden', 'textOverflow': 'ellipsis', 'maxWidth': 0, 'textAlign': 'center'},
                style_header={
                    'color': 'text-primary',
                    'text_align': 'center',
                    'height': 'auto',
                    'whiteSpace': 'normal',
                },
                css=[
                    {'selector': '.dash-spreadsheet-menu', 'rule': 'position:absolute; bottom: 8px'},
                    {'selector': '.dash-filter--case', 'rule': 'display: none'},
                    {'selector': '.column-header--hide', 'rule': 'display: none'},
                ],
            ),
        )
    ),
] + [
    dbc.Row(
        [
            dbc.Col(
                html.Div(children=k.replace('_', ' ')),
                width=2,
                class_name='mt-1 bg-light font-monospace text-break small rounded border',
            ),
            dbc.Col(html.Div(id='_' + k), class_name='mt-1 bg-light font-monospace text-break small rounded border'),
        ]
    )
    for k in data_keys
]

if metrics_available:
    samples_layout += [
        dbc.Row(
            [
                dbc.Col(
                    html.Div(children='text diff'),
                    width=2,
                    class_name='mt-1 bg-light font-monospace text-break small rounded border',
                ),
                dbc.Col(
                    html.Iframe(
                        id='_diff',
                        sandbox='',
                        srcDoc='',
                        style={'border': 'none', 'width': '100%', 'height': '100%'},
                        className='bg-light font-monospace text-break small',
                    ),
                    class_name='mt-1 bg-light font-monospace text-break small rounded border',
                ),
            ]
        )
    ]
samples_layout += [
    dbc.Row(dbc.Col(html.Audio(id='player', controls=True),), class_name='mt-3 '),
    dbc.Row(dbc.Col(dcc.Graph(id='signal-graph')), class_name='mt-3'),
]

# updating vocabulary to show
wordstable_columns_tool = [{'name': 'Word', 'id': 'word'}, {'name': 'Count', 'id': 'count'}]
wordstable_columns_tool.append({'name': 'Accuracy_1, %', 'id': 'accuracy_1'})
wordstable_columns_tool.append({'name': 'Accuracy_2, %', 'id': 'accuracy_2'})


# Everything below up to ``comparison_layout`` depends on the compared model names,
# which only exist in comparison mode.
if comparison_mode:
    model_name_1, model_name_2 = name_1, name_2

    # merge per-word accuracies of both models into one record per word (both
    # vocabularies come from the same manifest, so their word order matches)
    for i in range(len(vocabulary_1)):
        vocabulary_1[i].update(vocabulary_2[i])

    def _short_model_name(name):
        """Strip the conventional ``pred_text_`` prefix from a prediction field name."""
        prefix = 'pred_text_'
        return name[len(prefix) :] if name.startswith(prefix) else name

    def prepare_data(df, name1=model_name_1, name2=model_name_2):
        """Build the scatter-plot frame (word, count, 1/count, per-model accuracy, diff) from table rows."""
        res = pd.DataFrame()
        tmp = df['word']
        res.insert(0, 'word', tmp)
        res.insert(1, 'count', [float(i) for i in df['count']])
        res.insert(2, 'accuracy_model_' + name1, df['accuracy_1'])
        res.insert(3, 'accuracy_model_' + name2, df['accuracy_2'])
        res.insert(4, 'accuracy_diff ' + '(' + name1 + ' - ' + name2 + ')', df['accuracy_1'] - df['accuracy_2'])
        res.insert(2, 'count^(-1)', 1 / df['count'])
        return res

    # one-row frame used only to list the selectable columns in the dropdowns
    for_col_names = pd.DataFrame()
    for_col_names.insert(0, 'word', ['a'])
    for_col_names.insert(1, 'count', [0])
    for_col_names.insert(2, 'accuracy_model_' + model_name_1, [0])
    for_col_names.insert(3, 'accuracy_model_' + model_name_2, [0])
    for_col_names.insert(4, 'accuracy_diff ' + '(' + model_name_1 + ' - ' + model_name_2 + ')', [0])
    for_col_names.insert(5, 'count^(-1)', [0])

    @app.callback(
        Output('voc_graph', 'figure'),
        [
            Input('xaxis-column', 'value'),
            Input('yaxis-column', 'value'),
            Input('color-column', 'value'),
            Input('size-column', 'value'),
            Input("datatable-advanced-filtering", "derived_virtual_data"),
            Input("dot_spacing", 'value'),
            Input("radius", 'value'),
        ],
        prevent_initial_call=False,
    )
    def draw_vocab(Ox, Oy, color, size, data, dot_spacing='no', rad=0.01):
        """Scatter-plot per-word statistics of the two compared models.

        Args:
            Ox, Oy: column names of ``prepare_data``'s frame for the x and y axes.
            color, size: optional numeric columns encoding point colour and size.
            data: rows currently visible in the comparison table.
            dot_spacing: 'yes' to randomly jitter points so overlapping words separate.
            rad: jitter radius as typed by the user; 0.01 when empty or invalid.
        """
        if not data or not Ox or not Oy:
            return go.Figure()

        df = pd.DataFrame.from_records(data)
        res = prepare_data(df)
        if dot_spacing == 'yes':
            try:
                rad = float(rad) if rad not in (None, '') else 0.01
            except (TypeError, ValueError):
                rad = 0.01
            res_spacing = res.copy(deep=True)
            n = len(res)
            # randrange(1, n) would be an empty range for a single row
            upper = max(n, 2)
            # only numeric columns (accuracy_*, count, count^(-1)) can be jittered; 'word' cannot
            if Ox[0] in ('a', 'c'):
                res_spacing[Ox] = [
                    res[Ox][i]
                    + rad * random.randrange(1, 10) * math.cos(random.randrange(1, upper) * 2 * math.pi / n)
                    for i in range(n)
                ]
            if Oy[0] in ('a', 'c'):
                res_spacing[Oy] = [
                    res[Oy][i]
                    + rad * random.randrange(1, 10) * math.sin(random.randrange(1, upper) * 2 * math.pi / n)
                    for i in range(n)
                ]
            res = res_spacing

        fig = px.scatter(
            res,
            x=Ox,
            y=Oy,
            color=color,
            size=size,
            hover_data={'word': True, Ox: True, Oy: True, 'count': True},
            width=1300,
            height=1000,
        )
        # diagonal reference line when the two model accuracies are plotted against each other
        if (Ox == 'accuracy_model_' + model_name_1 and Oy == 'accuracy_model_' + model_name_2) or (
            Oy == 'accuracy_model_' + model_name_1 and Ox == 'accuracy_model_' + model_name_2
        ):
            fig.add_shape(
                type="line", x0=0, y0=0, x1=100, y1=100, line=dict(color="MediumPurple", width=1, dash="dot",)
            )

        return fig

    @app.callback(
        Output('filter-query-input', 'style'),
        Output('filter-query-output', 'style'),
        Input('filter-query-read-write', 'value'),
    )
    def query_input_output(val):
        """Show the filter-query input box and hide the read-only query display."""
        # NOTE(review): no 'filter-query-read-write' component exists in this layout,
        # so this callback presumably never fires.
        input_style = {'width': '100%'}
        output_style = {}
        input_style.update(display='inline-block')
        output_style.update(display='none')
        return input_style, output_style

    @app.callback(Output('datatable-advanced-filtering', 'filter_query'), Input('filter-query-input', 'value'))
    def write_query(query):
        """Copy the typed filter query into the comparison table."""
        if query is None:
            return ''
        return query

    @app.callback(Output('filter-query-output', 'children'), Input('datatable-advanced-filtering', 'filter_query'))
    def read_query(query):
        """Render the table's active filter query."""
        if query is None:
            return "No filter query"
        return dcc.Markdown('`filter_query = "{}"`'.format(query))

    def display_query(query):
        """Render a derived filter query structure as collapsible JSON (currently not wired to a callback)."""
        if query is None:
            return ''
        return html.Details(
            [
                html.Summary('Derived filter query structure'),
                html.Div(dcc.Markdown('```json\n{}\n```'.format(json.dumps(query, indent=4)))),
            ]
        )

    comparison_layout = [
        html.Div(
            [
                dcc.Markdown("model 1:" + ' ' + _short_model_name(model_name_1)),
                dcc.Markdown("model 2:" + ' ' + _short_model_name(model_name_2)),
            ]
        ),
        html.Hr(),
        html.Div(
            [
                dcc.Dropdown(for_col_names.columns[::], 'accuracy_model_' + model_name_1, id='xaxis-column'),
                dcc.Dropdown(for_col_names.columns[::], 'accuracy_model_' + model_name_2, id='yaxis-column'),
                dcc.Dropdown(
                    for_col_names.select_dtypes(include='number').columns[::],
                    placeholder='Select what will encode color of points',
                    id='color-column',
                ),
                dcc.Dropdown(
                    for_col_names.select_dtypes(include='number').columns[::],
                    placeholder='Select what will encode size of points',
                    id='size-column',
                ),
                dcc.Dropdown(['yes', 'no'], placeholder='if you want to enable dot spacing', id='dot_spacing'),
                dcc.Input(id='radius', placeholder='Enter radius of spacing (std is 0.01)'),
                html.Hr(),
                dcc.Input(id='filter-query-input', placeholder='Enter filter query'),
            ],
            style={'width': '50%', 'display': 'inline-block', 'float': 'middle'},
        ),
        html.Hr(),
        html.Div(id='filter-query-output'),
        dash_table.DataTable(
            id='datatable-advanced-filtering',
            columns=wordstable_columns_tool,
            data=vocabulary_1,
            editable=False,
            page_action='native',
            page_size=5,
            filter_action="native",
        ),
        html.Hr(),
        html.Div(id='datatable-query-structure', style={'whiteSpace': 'pre'}),
        html.Hr(),
        dbc.Row(dbc.Col(dcc.Graph(id='voc_graph'),),),
        html.Hr(),
    ]


@app.callback(
    [Output('datatable', 'data'), Output('datatable', 'page_count')],
    [Input('datatable', 'page_current'), Input('datatable', 'sort_by'), Input('datatable', 'filter_query')],
)
def update_datatable(page_current, sort_by, filter_query):
    """Return the requested page of the filtered/sorted utterances and the page count."""
    data_view = _filter_and_sort(data, sort_by, filter_query)
    return _paginate(data_view, page_current)


if comparison_mode:
    app.layout = html.Div(
        [
            dcc.Location(id='url', refresh=False),
            dbc.NavbarSimple(
                children=[
                    dbc.NavItem(dbc.NavLink('Statistics', id='stats_link', href='/', active=True)),
                    dbc.NavItem(dbc.NavLink('Samples', id='samples_link', href='/samples')),
                    dbc.NavItem(dbc.NavLink('Comparison tool', id='comp_tool', href='/comparison')),
                ],
                brand='Speech Data Explorer',
                sticky='top',
                color='green',
                dark=True,
            ),
            dbc.Container(id='page-content'),
        ]
    )
else:
    app.layout = html.Div(
        [
            dcc.Location(id='url', refresh=False),
            dbc.NavbarSimple(
                children=[
                    dbc.NavItem(dbc.NavLink('Statistics', id='stats_link', href='/', active=True)),
                    dbc.NavItem(dbc.NavLink('Samples', id='samples_link', href='/samples')),
                ],
                brand='Speech Data Explorer',
                sticky='top',
                color='green',
                dark=True,
            ),
            dbc.Container(id='page-content'),
        ]
    )

if comparison_mode:

    @app.callback(
        [
            Output('page-content', 'children'),
            Output('stats_link', 'active'),
            Output('samples_link', 'active'),
            Output('comp_tool', 'active'),
        ],
        [Input('url', 'pathname')],
    )
    def nav_click(url):
        """Switch page content and highlight the active navbar link."""
        if url == '/samples':
            return [samples_layout, False, True, False]
        elif url == '/comparison':
            return [comparison_layout, False, False, True]
        else:
            return [stats_layout, True, False, False]


else:

    @app.callback(
        [Output('page-content', 'children'), Output('stats_link', 'active'), Output('samples_link', 'active'),],
        [Input('url', 'pathname')],
    )
    def nav_click(url):
        """Switch page content and highlight the active navbar link."""
        if url == '/samples':
            return [samples_layout, False, True]
        else:
            return [stats_layout, True, False]


def _selected_row(idx, rows):
    """Return the selected row of the current table page, or abort the callback if there is none."""
    if not idx or not rows or idx[0] >= len(rows):
        raise PreventUpdate
    return rows[idx[0]]


@app.callback(
    [Output('_' + k, 'children') for k in data_keys], [Input('datatable', 'selected_rows'), Input('datatable', 'data')]
)
def show_item(idx, data):
    """Fill the per-field detail rows with the selected utterance (None for fields it lacks)."""
    row = _selected_row(idx, data)
    return [row.get(k) for k in data_keys]


@app.callback(Output('_diff', 'srcDoc'), [Input('datatable', 'selected_rows'), Input('datatable', 'data'),])
def show_diff(
    idx, data,
):
    """Return an HTML word-level diff between the reference text and the prediction of the selected row."""
    row = _selected_row(idx, data)
    orig_words = row['text']
    orig_words = '\n'.join(orig_words.split()) + '\n'

    pred_words = row.get(fld_nm) or ''
    pred_words = '\n'.join(pred_words.split()) + '\n'

    # diff line-by-line with one word per line, i.e. a word-level diff
    diff = diff_match_patch.diff_match_patch()
    diff.Diff_Timeout = 0
    orig_enc, pred_enc, enc = diff.diff_linesToChars(orig_words, pred_words)
    diffs = diff.diff_main(orig_enc, pred_enc, False)
    diff.diff_charsToLines(diffs, enc)
    diffs_post = []
    for d in diffs:
        diffs_post.append((d[0], d[1].replace('\n', ' ')))

    diff_html = diff.diff_prettyHtml(diffs_post)

    return diff_html


@app.callback(Output('signal-graph', 'figure'), [Input('datatable', 'selected_rows'), Input('datatable', 'data')])
def plot_signal(idx, data):
    """Plot the waveform and linear-frequency spectrogram of the selected utterance."""
    row = _selected_row(idx, data)
    figs = make_subplots(rows=2, cols=1, subplot_titles=('Waveform', 'Spectrogram'))
    try:
        filename = absolute_audio_filepath(row['audio_filepath'], args.audio_base_path)
        audio, fs = librosa.load(path=filename, sr=None)
        if 'offset' in row:
            audio = audio[int(row['offset'] * fs) : int((row['offset'] + row['duration']) * fs)]
        time_stride = 0.01
        hop_length = int(fs * time_stride)
        n_fft = 512
        # linear scale spectrogram
        s = librosa.stft(y=audio, n_fft=n_fft, hop_length=hop_length)
        s_db = librosa.power_to_db(S=np.abs(s) ** 2, ref=np.max, top_db=100)
        figs.add_trace(
            go.Scatter(
                x=np.arange(audio.shape[0]) / fs,
                y=audio,
                line={'color': 'green'},
                name='Waveform',
                hovertemplate='Time: %{x:.2f} s<br>Amplitude: %{y:.2f}<br><extra></extra>',
            ),
            row=1,
            col=1,
        )
        figs.add_trace(
            go.Heatmap(
                z=s_db,
                colorscale=[[0, 'rgb(30,62,62)'], [0.5, 'rgb(30,128,128)'], [1, 'rgb(30,255,30)'],],
                colorbar=dict(yanchor='middle', lenmode='fraction', y=0.2, len=0.5, ticksuffix=' dB'),
                dx=time_stride,
                # frequency bin spacing is fs / n_fft Hz, shown in kHz
                dy=fs / n_fft / 1000,
                name='Spectrogram',
                hovertemplate='Time: %{x:.2f} s<br>Frequency: %{y:.2f} kHz<br>Magnitude: %{z:.2f} dB<extra></extra>',
            ),
            row=2,
            col=1,
        )
        figs.update_layout({'margin': dict(l=0, r=0, t=20, b=0, pad=0), 'height': 500})
        figs.update_xaxes(title_text='Time, s', row=1, col=1)
        figs.update_yaxes(title_text='Amplitude', row=1, col=1)
        figs.update_xaxes(title_text='Time, s', row=2, col=1)
        figs.update_yaxes(title_text='Frequency, kHz', row=2, col=1)
    except Exception as ex:
        # best effort: log and return whatever was plotted so far
        app.logger.error(f'ERROR in plot signal: {ex}')

    return figs


@app.callback(Output('player', 'src'), [Input('datatable', 'selected_rows'), Input('datatable', 'data')])
def update_player(idx, data):
    """Return the selected utterance (respecting ``offset``/``duration``) as a base64 WAV data URI."""
    row = _selected_row(idx, data)
    try:
        filename = absolute_audio_filepath(row['audio_filepath'], args.audio_base_path)
        signal, sr = librosa.load(path=filename, sr=None)
        if 'offset' in row:
            signal = signal[int(row['offset'] * sr) : int((row['offset'] + row['duration']) * sr)]
        with io.BytesIO() as buf:
            # convert to PCM .wav
            sf.write(buf, signal, sr, format='WAV')
            buf.seek(0)
            encoded = base64.b64encode(buf.read())
        return 'data:audio/wav;base64,{}'.format(encoded.decode())
    except Exception as ex:
        # best effort: an empty source leaves the player blank
        app.logger.error(f'ERROR in audio player: {ex}')
        return ''


if __name__ == '__main__':
    # Dash 3 removed ``run_server`` in favour of ``run``; support both.
    run = getattr(app, 'run', None) or app.run_server
    run(host='', port=args.port, debug=args.debug)