Joshua Lochner committed on
Commit
3d1c770
·
1 Parent(s): 36f7534

Update streamlit app to use new classifier

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -12,11 +12,10 @@ from urllib.parse import quote
12
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
13
 
14
  from preprocess import get_words
15
- from predict import SegmentationArguments, ClassifierArguments, predict as pred
16
- from evaluate import EvaluationArguments
17
- from shared import seconds_to_time, CATGEGORY_OPTIONS
18
  from utils import regex_search
19
- from model import get_model_tokenizer
20
  from errors import TranscriptError
21
 
22
  st.set_page_config(
@@ -104,7 +103,7 @@ for m in MODELS:
104
  prediction_cache[m] = {}
105
 
106
 
107
- CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
108
 
109
 
110
  TRANSCRIPT_TYPES = {
@@ -122,15 +121,15 @@ TRANSCRIPT_TYPES = {
122
  }
123
 
124
 
125
- def predict_function(model_id, model, tokenizer, segmentation_args, classifier_args, video_id, words, ts_type_id):
126
  cache_id = f'{video_id}_{ts_type_id}'
127
 
128
  if cache_id not in prediction_cache[model_id]:
129
  prediction_cache[model_id][cache_id] = pred(
130
  video_id, model, tokenizer,
131
  segmentation_args=segmentation_args,
132
- classifier_args=classifier_args,
133
- words=words
134
  )
135
  return prediction_cache[model_id][cache_id]
136
 
@@ -140,15 +139,15 @@ def load_predict(model_id):
140
 
141
  if model_id not in prediction_function_cache:
142
  # Use default segmentation and classification arguments
143
- evaluation_args = EvaluationArguments(model_path=model_info['repo_id'])
 
144
  segmentation_args = SegmentationArguments()
145
- classifier_args = ClassifierArguments(
146
- min_probability=0) # Filtering done later
147
 
148
- model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
149
 
150
  prediction_function_cache[model_id] = partial(
151
- predict_function, model_id, model, tokenizer, segmentation_args, classifier_args)
 
152
 
153
  return prediction_function_cache[model_id]
154
 
@@ -252,7 +251,8 @@ def main():
252
 
253
  submit_segments = []
254
  for index, prediction in enumerate(predictions, start=1):
255
- if prediction['category'] not in categories:
 
256
  continue # Skip
257
 
258
  confidence = prediction['probability'] * 100
@@ -262,13 +262,13 @@ def main():
262
 
263
  submit_segments.append({
264
  'segment': [prediction['start'], prediction['end']],
265
- 'category': prediction['category'].lower(),
266
  'actionType': 'skip'
267
  })
268
  start_time = seconds_to_time(prediction['start'])
269
  end_time = seconds_to_time(prediction['end'])
270
  with st.expander(
271
- f"[{prediction['category']}] Prediction #{index} ({start_time} \u2192 {end_time})"
272
  ):
273
 
274
  url = f"https://www.youtube-nocookie.com/embed/{video_id}?&start={floor(prediction['start'])}&end={ceil(prediction['end'])}"
@@ -280,7 +280,7 @@ def main():
280
  text = ' '.join(w['text'] for w in prediction['words'])
281
  st.write(f"**Times:** {start_time} \u2192 {end_time}")
282
  st.write(
283
- f"**Category:** {CATGEGORY_OPTIONS[prediction['category']]}")
284
  st.write(f"**Confidence:** {confidence:.2f}%")
285
  st.write(f'**Text:** "{text}"')
286
 
 
12
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
13
 
14
  from preprocess import get_words
15
+ from predict import PredictArguments, SegmentationArguments, predict as pred
16
+ from shared import GeneralArguments, seconds_to_time, CATGEGORY_OPTIONS
 
17
  from utils import regex_search
18
+ from model import get_model_tokenizer_classifier
19
  from errors import TranscriptError
20
 
21
  st.set_page_config(
 
103
  prediction_cache[m] = {}
104
 
105
 
106
+ CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier-v2'
107
 
108
 
109
  TRANSCRIPT_TYPES = {
 
121
  }
122
 
123
 
124
+ def predict_function(model_id, model, tokenizer, segmentation_args, classifier, video_id, words, ts_type_id):
125
  cache_id = f'{video_id}_{ts_type_id}'
126
 
127
  if cache_id not in prediction_cache[model_id]:
128
  prediction_cache[model_id][cache_id] = pred(
129
  video_id, model, tokenizer,
130
  segmentation_args=segmentation_args,
131
+ words=words,
132
+ classifier=classifier
133
  )
134
  return prediction_cache[model_id][cache_id]
135
 
 
139
 
140
  if model_id not in prediction_function_cache:
141
  # Use default segmentation and classification arguments
142
+ predict_args = PredictArguments(model_name_or_path=model_info['repo_id'])
143
+ general_args = GeneralArguments()
144
  segmentation_args = SegmentationArguments()
 
 
145
 
146
+ model, tokenizer, classifier = get_model_tokenizer_classifier(predict_args, general_args)
147
 
148
  prediction_function_cache[model_id] = partial(
149
+ predict_function, model_id, model, tokenizer, segmentation_args, classifier)
150
+
151
 
152
  return prediction_function_cache[model_id]
153
 
 
251
 
252
  submit_segments = []
253
  for index, prediction in enumerate(predictions, start=1):
254
+ category_key = prediction['category'].upper()
255
+ if category_key not in categories:
256
  continue # Skip
257
 
258
  confidence = prediction['probability'] * 100
 
262
 
263
  submit_segments.append({
264
  'segment': [prediction['start'], prediction['end']],
265
+ 'category': prediction['category'],
266
  'actionType': 'skip'
267
  })
268
  start_time = seconds_to_time(prediction['start'])
269
  end_time = seconds_to_time(prediction['end'])
270
  with st.expander(
271
+ f"[{category_key}] Prediction #{index} ({start_time} \u2192 {end_time})"
272
  ):
273
 
274
  url = f"https://www.youtube-nocookie.com/embed/{video_id}?&start={floor(prediction['start'])}&end={ceil(prediction['end'])}"
 
280
  text = ' '.join(w['text'] for w in prediction['words'])
281
  st.write(f"**Times:** {start_time} \u2192 {end_time}")
282
  st.write(
283
+ f"**Category:** {CATGEGORY_OPTIONS[category_key]}")
284
  st.write(f"**Confidence:** {confidence:.2f}%")
285
  st.write(f'**Text:** "{text}"')
286