Joshua Lochner commited on
Commit
8a55e13
·
1 Parent(s): 8326048

Add transcript option to streamlit app and visual improvements

Browse files
Files changed (1) hide show
  1. app.py +97 -17
app.py CHANGED
@@ -11,11 +11,13 @@ from urllib.parse import quote
11
  # Allow direct execution
12
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
13
 
14
- from predict import SegmentationArguments, ClassifierArguments, predict as pred # noqa
 
15
  from evaluate import EvaluationArguments
16
  from shared import seconds_to_time, CATGEGORY_OPTIONS
17
  from utils import regex_search
18
- from model import get_model_tokenizer, get_classifier_vectorizer
 
19
 
20
  st.set_page_config(
21
  page_title='SponsorBlock ML',
@@ -105,14 +107,32 @@ for m in MODELS:
105
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
106
 
107
 
108
- def predict_function(model_id, model, tokenizer, segmentation_args, classifier_args, video_id):
109
- if video_id not in prediction_cache[model_id]:
110
- prediction_cache[model_id][video_id] = pred(
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  video_id, model, tokenizer,
112
  segmentation_args=segmentation_args,
113
- classifier_args=classifier_args
 
114
  )
115
- return prediction_cache[model_id][video_id]
116
 
117
 
118
  def load_predict(model_id):
@@ -133,7 +153,39 @@ def load_predict(model_id):
133
  return prediction_function_cache[model_id]
134
 
135
 
 
 
 
 
 
 
 
 
136
  def main():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
137
  top = st.container()
138
  output = st.empty()
139
 
@@ -143,12 +195,18 @@ def main():
143
  '##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
144
 
145
  # Add controls
146
- model_id = top.selectbox(
147
- 'Select model', MODELS.keys(), index=0, on_change=output.empty)
148
 
149
- video_input = top.text_input(
150
- 'Video URL/ID:', on_change=output.empty)
 
 
 
 
 
 
 
151
 
 
152
  categories = top.multiselect('Categories:',
153
  CATGEGORY_OPTIONS.keys(),
154
  CATGEGORY_OPTIONS.keys(),
@@ -172,8 +230,21 @@ def main():
172
  st.exception(ValueError('Invalid YouTube URL/ID'))
173
  return
174
 
 
 
 
 
 
 
 
 
 
 
 
 
 
175
  with st.spinner('Running model...'):
176
- predictions = predict(video_id)
177
 
178
  if len(predictions) == 0:
179
  st.success('No segments found!')
@@ -214,14 +285,23 @@ def main():
214
  st.write(f'**Text:** "{text}"')
215
 
216
  if len(submit_segments) == 0:
217
- st.success(f'No segments found! ({len(predictions)} ignored due to filters/settings)')
 
218
  return
219
 
 
 
 
 
 
220
  json_data = quote(json.dumps(submit_segments))
221
- link = f'[Submit Segments](https://www.youtube.com/watch?v={video_id}#segments={json_data})'
222
- st.markdown(link, unsafe_allow_html=True)
223
- wiki_link = '[Review generated segments before submitting!](https://wiki.sponsor.ajay.app/w/Automating_Submissions)'
224
- st.markdown(wiki_link, unsafe_allow_html=True)
 
 
 
225
 
226
 
227
  if __name__ == '__main__':
 
11
  # Allow direct execution
12
  sys.path.insert(0, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'src')) # noqa
13
 
14
+ from preprocess import get_words
15
+ from predict import SegmentationArguments, ClassifierArguments, predict as pred
16
  from evaluate import EvaluationArguments
17
  from shared import seconds_to_time, CATGEGORY_OPTIONS
18
  from utils import regex_search
19
+ from model import get_model_tokenizer
20
+ from errors import TranscriptError
21
 
22
  st.set_page_config(
23
  page_title='SponsorBlock ML',
 
107
  CLASSIFIER_PATH = 'Xenova/sponsorblock-classifier'
108
 
109
 
110
+ TRANSCRIPT_TYPES = {
111
+ 'AUTO_MANUAL': {
112
+ 'label': 'Auto-generated (fallback to manual)',
113
+ 'type': 'auto',
114
+ 'fallback': 'manual'
115
+ },
116
+ 'MANUAL_AUTO': {
117
+ 'label': 'Manual (fallback to auto-generated)',
118
+ 'type': 'manual',
119
+ 'fallback': 'auto'
120
+ },
121
+ # 'TRANSLATED': 'Translated to English' # Coming soon
122
+ }
123
+
124
+
125
+ def predict_function(model_id, model, tokenizer, segmentation_args, classifier_args, video_id, words, ts_type_id):
126
+ cache_id = f'{video_id}_{ts_type_id}'
127
+
128
+ if cache_id not in prediction_cache[model_id]:
129
+ prediction_cache[model_id][cache_id] = pred(
130
  video_id, model, tokenizer,
131
  segmentation_args=segmentation_args,
132
+ classifier_args=classifier_args,
133
+ words=words
134
  )
135
+ return prediction_cache[model_id][cache_id]
136
 
137
 
138
  def load_predict(model_id):
 
153
  return prediction_function_cache[model_id]
154
 
155
 
156
+ def create_button(text, url):
157
+ return f"""<div class="row-widget stButton" style="text-align: center">
158
+ <a href="{url}" target="_blank" rel="noopener noreferrer" class="btn-link">
159
+ <button kind="primary" class="btn">{text}</button>
160
+ </a>
161
+ </div>"""
162
+
163
+
164
  def main():
165
+ st.markdown("""<style>
166
+ .btn {
167
+ display: inline-flex;
168
+ -webkit-box-align: center;
169
+ align-items: center;
170
+ -webkit-box-pack: center;
171
+ justify-content: center;
172
+ font-weight: 600;
173
+ padding: 0.25rem 0.75rem;
174
+ border-radius: 0.25rem;
175
+ margin: 0px;
176
+ line-height: 1.5;
177
+ color: inherit;
178
+ width: auto;
179
+ user-select: none;
180
+ background-color: rgb(255, 255, 255);
181
+ border: 1px solid rgba(49, 51, 63, 0.2);
182
+ }
183
+ .btn-link {
184
+ color: inherit;
185
+ text-decoration: none;
186
+ }
187
+ </style>""", unsafe_allow_html=True)
188
+
189
  top = st.container()
190
  output = st.empty()
191
 
 
195
  '##### Automatically detect in-video YouTube sponsorships, self/unpaid promotions, and interaction reminders.')
196
 
197
  # Add controls
 
 
198
 
199
+ col1, col2 = top.columns(2)
200
+
201
+ with col1:
202
+ model_id = st.selectbox(
203
+ 'Select model', MODELS.keys(), index=0, on_change=output.empty)
204
+
205
+ with col2:
206
+ ts_type_id = st.selectbox(
207
+ 'Transcript type', TRANSCRIPT_TYPES.keys(), index=0, format_func=lambda x: TRANSCRIPT_TYPES[x]['label'], on_change=output.empty)
208
 
209
+ video_input = top.text_input('Video URL/ID:', on_change=output.empty)
210
  categories = top.multiselect('Categories:',
211
  CATGEGORY_OPTIONS.keys(),
212
  CATGEGORY_OPTIONS.keys(),
 
230
  st.exception(ValueError('Invalid YouTube URL/ID'))
231
  return
232
 
233
+ try:
234
+ with st.spinner('Downloading transcript...'):
235
+ words = get_words(video_id,
236
+ transcript_type=TRANSCRIPT_TYPES[ts_type_id]['type'],
237
+ fallback=TRANSCRIPT_TYPES[ts_type_id]['fallback']
238
+ )
239
+ except TranscriptError:
240
+ pass
241
+
242
+ if not words:
243
+ st.error('No transcript found!')
244
+ return
245
+
246
  with st.spinner('Running model...'):
247
+ predictions = predict(video_id, words, ts_type_id)
248
 
249
  if len(predictions) == 0:
250
  st.success('No segments found!')
 
285
  st.write(f'**Text:** "{text}"')
286
 
287
  if len(submit_segments) == 0:
288
+ st.success(
289
+ f'No segments found! ({len(predictions)} ignored due to filters/settings)')
290
  return
291
 
292
+ num_hidden = len(predictions) - len(submit_segments)
293
+ if num_hidden > 0:
294
+ st.info(
295
+ f'{num_hidden} predictions hidden (adjust the settings and filters to view them all).')
296
+
297
  json_data = quote(json.dumps(submit_segments))
298
+ link = f'https://www.youtube.com/watch?v={video_id}#segments={json_data}'
299
+ st.markdown(create_button('Submit Segments', link),
300
+ unsafe_allow_html=True)
301
+
302
+ st.markdown(f"""<div style="text-align: center;font-size: 16px;margin-top: 6px">
303
+ <a href="https://wiki.sponsor.ajay.app/w/Automating_Submissions" target="_blank" rel="noopener noreferrer">(Review before submitting!)</a>
304
+ </div>""", unsafe_allow_html=True)
305
 
306
 
307
  if __name__ == '__main__':