Commit 3d1c770 · Parent(s): 36f7534
Joshua Lochner committed: Update streamlit app to use new classifier
app.py CHANGED

@@ -12,11 +12,10 @@ from urllib.parse import quote
 sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
 
 from preprocess import get_words
-from predict import …
-from …
-from shared import seconds_to_time, CATGEGORY_OPTIONS
+from predict import PredictArguments, SegmentationArguments, predict as pred
+from shared import GeneralArguments, seconds_to_time, CATGEGORY_OPTIONS
 from utils import regex_search
-from model import …
+from model import get_model_tokenizer_classifier
 from errors import TranscriptError
 
 st.set_page_config(
@@ -104,7 +103,7 @@ for m in MODELS:
     prediction_cache[m] = {}
 
 
-CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
+CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier-v2'
 
 
 TRANSCRIPT_TYPES = {
@@ -122,15 +121,15 @@ TRANSCRIPT_TYPES = {
 }
 
 
-def predict_function(model_id, model, tokenizer, segmentation_args, …
+def predict_function(model_id, model, tokenizer, segmentation_args, classifier, video_id, words, ts_type_id):
     cache_id = f'{video_id}_{ts_type_id}'
 
     if cache_id not in prediction_cache[model_id]:
         prediction_cache[model_id][cache_id] = pred(
             video_id, model, tokenizer,
             segmentation_args=segmentation_args,
-            …
-            …
+            words=words,
+            classifier=classifier
         )
     return prediction_cache[model_id][cache_id]
 
@@ -140,15 +139,15 @@ def load_predict(model_id):
 
     if model_id not in prediction_function_cache:
         # Use default segmentation and classification arguments
-        …
+        predict_args = PredictArguments(model_name_or_path=model_info['repo_id'])
+        general_args = GeneralArguments()
         segmentation_args = SegmentationArguments()
-        classifier_args = ClassifierArguments(
-            min_probability=0)  # Filtering done later
 
-        model, tokenizer = …
+        model, tokenizer, classifier = get_model_tokenizer_classifier(predict_args, general_args)
 
         prediction_function_cache[model_id] = partial(
-            predict_function, model_id, model, tokenizer, segmentation_args, …
+            predict_function, model_id, model, tokenizer, segmentation_args, classifier)
+
 
     return prediction_function_cache[model_id]
 
@@ -252,7 +251,8 @@ def main():
 
     submit_segments = []
    for index, prediction in enumerate(predictions, start=1):
-        …
+        category_key = prediction['category'].upper()
+        if category_key not in categories:
             continue  # Skip
 
         confidence = prediction['probability'] * 100
@@ -262,13 +262,13 @@ def main():
 
         submit_segments.append({
             'segment': [prediction['start'], prediction['end']],
-            'category': prediction['category']
+            'category': prediction['category'],
             'actionType': 'skip'
         })
         start_time = seconds_to_time(prediction['start'])
         end_time = seconds_to_time(prediction['end'])
         with st.expander(
-            f"[{…
+            f"[{category_key}] Prediction #{index} ({start_time} \u2192 {end_time})"
         ):
 
             url = f"https://www.youtube-nocookie.com/embed/{video_id}?&start={floor(prediction['start'])}&end={ceil(prediction['end'])}"
@@ -280,7 +280,7 @@ def main():
             text = ' '.join(w['text'] for w in prediction['words'])
             st.write(f"**Times:** {start_time} \u2192 {end_time}")
             st.write(
-                f"**Category:** {CATGEGORY_OPTIONS[…
+                f"**Category:** {CATGEGORY_OPTIONS[category_key]}")
             st.write(f"**Confidence:** {confidence:.2f}%")
             st.write(f'**Text:** "{text}"')
 
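For reference, the reworked load_predict / predict_function pair is a two-level cache: the model, tokenizer, and classifier are loaded once per model_id and frozen into a partial, and each (video, transcript type) prediction is computed once and reused. Below is a minimal sketch of that pattern; the placeholder model object and the run_prediction stub are hypothetical stand-ins for the app's real get_model_tokenizer_classifier and predict, not its actual code.

from functools import partial

prediction_function_cache = {}
prediction_cache = {}


def run_prediction(video_id, model, ts_type_id):
    # Hypothetical stand-in for the app's predict(); returns fake segments.
    return [{'start': 0.0, 'end': 5.0, 'category': 'sponsor', 'probability': 0.9}]


def predict_function(model_id, model, video_id, ts_type_id):
    # Second cache level: one prediction per (video, transcript type) pair.
    cache_id = f'{video_id}_{ts_type_id}'
    if cache_id not in prediction_cache[model_id]:
        prediction_cache[model_id][cache_id] = run_prediction(video_id, model, ts_type_id)
    return prediction_cache[model_id][cache_id]


def load_predict(model_id):
    # First cache level: load the model (and, in the app, the tokenizer and
    # classifier) once per model_id, then freeze it into a partial.
    if model_id not in prediction_function_cache:
        prediction_cache.setdefault(model_id, {})
        model = object()  # placeholder for the real model load
        prediction_function_cache[model_id] = partial(predict_function, model_id, model)
    return prediction_function_cache[model_id]


predict = load_predict('base')
segments = predict('dQw4w9WgXcQ', ts_type_id=0)  # computed once
segments = predict('dQw4w9WgXcQ', ts_type_id=0)  # served from cache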
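The main() changes also shape the SponsorBlock-style submission payload: each prediction whose upper-cased category survives the user's filter becomes a dict of segment bounds, category, and an explicit 'skip' actionType. A self-contained sketch with made-up predictions follows; the pre-filled #segments link format at the end is an assumption, not something this diff confirms.

import json
from urllib.parse import quote

# Made-up example predictions in the shape the app's predict() returns.
predictions = [
    {'start': 12.3, 'end': 45.6, 'category': 'sponsor', 'probability': 0.97},
    {'start': 90.0, 'end': 95.5, 'category': 'interaction', 'probability': 0.55},
]
categories = {'SPONSOR'}  # upper-cased keys selected by the user
video_id = 'dQw4w9WgXcQ'  # hypothetical video

submit_segments = []
for prediction in predictions:
    if prediction['category'].upper() not in categories:
        continue  # skip categories the user deselected
    submit_segments.append({
        'segment': [prediction['start'], prediction['end']],
        'category': prediction['category'],
        'actionType': 'skip',  # explicit action type, added in this commit
    })

# Assumed pre-filled submission link (hash fragment carrying URL-encoded JSON).
url = f'https://www.youtube.com/watch?v={video_id}#segments={quote(json.dumps(submit_segments))}'
print(url)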