Joshua Lochner committed on
Commit
36f7534
·
1 Parent(s): 15626e5

Upgrade classifier to transformer-based model

Files changed (12)
  1. src/classify.py +41 -0
  2. src/errors.py +2 -6
  3. src/evaluate.py +10 -10
  4. src/model.py +179 -65
  5. src/moderate.py +104 -0
  6. src/predict.py +26 -203
  7. src/preprocess.py +89 -45
  8. src/segment.py +2 -0
  9. src/shared.py +153 -0
  10. src/train.py +120 -298
  11. src/train_classifier.py +287 -0
  12. src/utils.py +0 -4
src/classify.py ADDED
@@ -0,0 +1,41 @@
+ from transformers import TextClassificationPipeline
+ import preprocess
+ import segment
+
+
+ class SponsorBlockClassificationPipeline(TextClassificationPipeline):
+     def __init__(self, model, tokenizer):
+         device = next(model.parameters()).device.index
+         super().__init__(model=model, tokenizer=tokenizer,
+                          return_all_scores=True, truncation=True, device=device)
+
+     def preprocess(self, data, **tokenizer_kwargs):
+         # TODO add support for lists
+         texts = []
+
+         if not isinstance(data, list):
+             data = [data]
+
+         for d in data:
+             if isinstance(d, dict):  # Otherwise, get data from transcript
+                 words = preprocess.get_words(d['video_id'])
+                 segment_words = segment.extract_segment(
+                     words, d['start'], d['end'])
+                 text = preprocess.clean_text(
+                     ' '.join(x['text'] for x in segment_words))
+                 texts.append(text)
+             elif isinstance(d, str):  # If string, assume this is what user wants to classify
+                 texts.append(d)
+             else:
+                 raise ValueError(f'Invalid input type: "{type(d)}"')
+
+         return self.tokenizer(
+             texts, return_tensors=self.framework, **tokenizer_kwargs)
+
+
+ def main():
+     pass
+
+
+ if __name__ == '__main__':
+     main()
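Usage sketch (illustrative, not part of this commit): driving the new pipeline directly, assuming the Xenova/sponsorblock-classifier-v2 checkpoint that src/model.py uses as its default classifier and a CUDA-capable setup (the pipeline reads its device from the model); the sample text and video ID are placeholders.

from transformers import AutoModelForSequenceClassification, AutoTokenizer
from classify import SponsorBlockClassificationPipeline

checkpoint = 'Xenova/sponsorblock-classifier-v2'  # default classifier in InferenceArguments (src/model.py)
model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model.to('cuda')  # __init__ derives the pipeline device from the model, so place the model first

pipeline = SponsorBlockClassificationPipeline(model, tokenizer)

# Classify raw text directly...
print(pipeline('Huge thanks to our sponsor for supporting the channel'))

# ...or classify a transcript span by video ID and time range (this fetches the transcript).
print(pipeline({'video_id': 'dQw4w9WgXcQ', 'start': 30.0, 'end': 45.0}))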
src/errors.py CHANGED
@@ -1,9 +1,10 @@
+
  class SponsorBlockException(Exception):
      """Base class for all sponsor block exceptions"""
      pass


- class PredictionException(SponsorBlockException):
+ class InferenceException(SponsorBlockException):
      """An exception occurred while predicting sponsor segments"""
      pass

@@ -21,8 +22,3 @@ class ModelError(SponsorBlockException):
  class ModelLoadError(ModelError):
      """An exception occurred while loading the model"""
      pass
-
-
- class ClassifierLoadError(ModelError):
-     """An exception occurred while loading the classifier"""
-     pass
src/evaluate.py CHANGED
@@ -1,10 +1,10 @@
 
- from model import get_model_tokenizer
+ from model import get_model_tokenizer_classifier, InferenceArguments
  from utils import jaccard
  from transformers import HfArgumentParser
- from preprocess import DatasetArguments, get_words
- from shared import GeneralArguments
- from predict import ClassifierArguments, predict, InferenceArguments
+ from preprocess import get_words
+ from shared import GeneralArguments, DatasetArguments
+ from predict import predict
  from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
  import pandas as pd
  from dataclasses import dataclass, field
@@ -134,11 +134,10 @@ def main():
          EvaluationArguments,
          DatasetArguments,
          SegmentationArguments,
-         ClassifierArguments,
          GeneralArguments
      ))

-     evaluation_args, dataset_args, segmentation_args, classifier_args, general_args = hf_parser.parse_args_into_dataclasses()
+     evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()

      # Load labelled data:
      final_path = os.path.join(
@@ -149,8 +148,8 @@
              f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
          return

-     model, tokenizer = get_model_tokenizer(
-         evaluation_args.model_path, evaluation_args.cache_dir, general_args.no_cuda)
+     model, tokenizer, classifier = get_model_tokenizer_classifier(
+         evaluation_args, general_args)

      with open(final_path) as fp:
          final_data = json.load(fp)
@@ -187,8 +186,9 @@
              continue

          # Make predictions
-         predictions = predict(video_id, model, tokenizer,
-                               segmentation_args, words, classifier_args)
+         predictions = predict(video_id, model, tokenizer, segmentation_args,
+                               classifier=classifier,
+                               min_probability=evaluation_args.min_probability)

          # Get labels
          sponsor_segments = final_data.get(video_id)
src/model.py CHANGED
@@ -1,13 +1,68 @@
1
- from huggingface_hub import hf_hub_download
2
- from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
3
- from shared import CustomTokens
4
- from errors import ClassifierLoadError, ModelLoadError
5
  from functools import lru_cache
6
- import pickle
7
- import os
8
  from dataclasses import dataclass, field
9
- from typing import Optional
10
  import torch
11
 
12
 
13
  @dataclass
@@ -17,34 +72,24 @@ class ModelArguments:
17
  """
18
 
19
  model_name_or_path: str = field(
20
- default=None,
21
- # default='google/t5-v1_1-small', # t5-small
22
  metadata={
23
  'help': 'Path to pretrained model or model identifier from huggingface.co/models'
24
  }
25
  )
26
 
27
- # config_name: Optional[str] = field( # TODO remove?
28
- # default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
29
- # )
30
- # tokenizer_name: Optional[str] = field(
31
- # default=None, metadata={
32
- # 'help': 'Pretrained tokenizer name or path if not the same as model_name'
33
- # }
34
- # )
35
  cache_dir: Optional[str] = field(
36
  default='models',
37
  metadata={
38
  'help': 'Where to store the pretrained models downloaded from huggingface.co'
39
  },
40
  )
41
- use_fast_tokenizer: bool = field( # TODO remove?
42
  default=True,
43
  metadata={
44
  'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
45
  },
46
  )
47
- model_revision: str = field( # TODO remove?
48
  default='main',
49
  metadata={
50
  'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
@@ -57,62 +102,131 @@ class ModelArguments:
57
  'with private models).'
58
  },
59
  )
60
- resize_position_embeddings: Optional[bool] = field(
61
  default=None,
62
  metadata={
63
- 'help': "Whether to automatically resize the position embeddings if `max_source_length` exceeds the model's position embeddings."
64
- },
65
  )
66
 
 
 
67
 
68
- @lru_cache(maxsize=None)
69
- def get_classifier_vectorizer(classifier_args):
70
- # Classifier
71
- classifier_path = os.path.join(
72
- classifier_args.classifier_dir, classifier_args.classifier_file)
73
- if not os.path.exists(classifier_path):
74
- hf_hub_download(repo_id=classifier_args.classifier_model,
75
- filename=classifier_args.classifier_file,
76
- cache_dir=classifier_args.classifier_dir,
77
- force_filename=classifier_args.classifier_file,
78
- )
79
- with open(classifier_path, 'rb') as fp:
80
- classifier = pickle.load(fp)
81
-
82
- # Vectorizer
83
- vectorizer_path = os.path.join(
84
- classifier_args.classifier_dir, classifier_args.vectorizer_file)
85
- if not os.path.exists(vectorizer_path):
86
- hf_hub_download(repo_id=classifier_args.classifier_model,
87
- filename=classifier_args.vectorizer_file,
88
- cache_dir=classifier_args.classifier_dir,
89
- force_filename=classifier_args.vectorizer_file,
90
- )
91
- with open(vectorizer_path, 'rb') as fp:
92
- vectorizer = pickle.load(fp)
93
-
94
- return classifier, vectorizer
95
-
96
-
97
- @lru_cache(maxsize=None)
98
- def get_model_tokenizer(model_name_or_path, cache_dir=None, no_cuda=False):
99
- if model_name_or_path is None:
100
- raise ModelLoadError('Invalid model_name_or_path.')
101
-
102
- # Load pretrained model and tokenizer
103
- model = AutoModelForSeq2SeqLM.from_pretrained(
104
- model_name_or_path, cache_dir=cache_dir)
105
- if not no_cuda:
106
- model.to('cuda' if torch.cuda.is_available() else 'cpu')
107
 
108
  tokenizer = AutoTokenizer.from_pretrained(
109
- model_name_or_path, cache_dir=cache_dir)
110
 
111
- # Ensure model and tokenizer contain the custom tokens
112
  CustomTokens.add_custom_tokens(tokenizer)
113
  model.resize_token_embeddings(len(tokenizer))
114
 
115
- # TODO find a way to adjust based on model's input size
116
- # print('tokenizer.model_max_length', tokenizer.model_max_length)
 
117
 
118
  return model, tokenizer
 
1
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments
2
+ from shared import CustomTokens, GeneralArguments
 
 
3
  from functools import lru_cache
 
 
4
  from dataclasses import dataclass, field
5
+ from typing import Optional, Union
6
  import torch
7
+ import classify
8
+ import base64
9
+ import re
10
+ import requests
11
+ import json
12
+ import logging
13
+
14
+ logging.basicConfig()
15
+ logger = logging.getLogger(__name__)
16
+
17
+ # Public innertube key (b64 encoded so that it is not incorrectly flagged)
18
+ INNERTUBE_KEY = base64.b64decode(
19
+ b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
20
+
21
+ YT_CONTEXT = {
22
+ 'client': {
23
+ 'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
24
+ 'clientName': 'WEB',
25
+ 'clientVersion': '2.20211221.00.00',
26
+ }
27
+ }
28
+ _YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
29
+
30
+
31
+ def get_all_channel_vids(channel_id):
32
+ continuation = None
33
+ while True:
34
+ if continuation is None:
35
+ params = {'list': channel_id.replace('UC', 'UU', 1)}
36
+ response = requests.get(
37
+ 'https://www.youtube.com/playlist', params=params)
38
+ items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
39
+ 'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
40
+ else:
41
+ params = {'key': INNERTUBE_KEY}
42
+ data = {
43
+ 'context': YT_CONTEXT,
44
+ 'continuation': continuation
45
+ }
46
+ response = requests.post(
47
+ 'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
48
+ items = response.json()[
49
+ 'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
50
+
51
+ new_token = None
52
+ for vid in items:
53
+ info = vid.get('playlistVideoRenderer')
54
+ if info:
55
+ yield info['videoId']
56
+ continue
57
+
58
+ info = vid.get('continuationItemRenderer')
59
+ if info:
60
+ new_token = info['continuationEndpoint']['continuationCommand']['token']
61
+
62
+ if new_token is None:
63
+ break
64
+ continuation = new_token
65
+
66
 
67
 
68
  @dataclass
 
72
  """
73
 
74
  model_name_or_path: str = field(
 
 
75
  metadata={
76
  'help': 'Path to pretrained model or model identifier from huggingface.co/models'
77
  }
78
  )
79
 
80
  cache_dir: Optional[str] = field(
81
  default='models',
82
  metadata={
83
  'help': 'Where to store the pretrained models downloaded from huggingface.co'
84
  },
85
  )
86
+ use_fast_tokenizer: bool = field(
87
  default=True,
88
  metadata={
89
  'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
90
  },
91
  )
92
+ model_revision: str = field(
93
  default='main',
94
  metadata={
95
  'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
 
102
  'with private models).'
103
  },
104
  )
105
+
106
+ import itertools
107
+ from errors import InferenceException
108
+
109
+ @dataclass
110
+ class InferenceArguments(ModelArguments):
111
+
112
+ model_name_or_path: str = field(
113
+ default='Xenova/sponsorblock-small',
114
+ metadata={
115
+ 'help': 'Path to pretrained model used for prediction'
116
+ }
117
+ )
118
+ classifier_model_name_or_path: str = field(
119
+ default='Xenova/sponsorblock-classifier-v2',
120
+ metadata={
121
+ 'help': 'Use a pretrained classifier'
122
+ }
123
+ )
124
+
125
+ max_videos: Optional[int] = field(
126
  default=None,
127
  metadata={
128
+ 'help': 'The number of videos to test on'
129
+ }
130
+ )
131
+ start_index: int = field(default=None, metadata={
132
+ 'help': 'Video to start the evaluation at.'})
133
+ channel_id: Optional[str] = field(
134
+ default=None,
135
+ metadata={
136
+ 'help': 'Used to evaluate a channel'
137
+ }
138
+ )
139
+ video_ids: str = field(
140
+ default_factory=lambda: [],
141
+ metadata={
142
+ 'nargs': '+'
143
+ }
144
  )
145
 
146
+ output_as_json: bool = field(default=False, metadata={
147
+ 'help': 'Output evaluations as JSON'})
148
 
149
+ min_probability: float = field(
150
+ default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
151
+
152
+ def __post_init__(self):
153
+
154
+ self.video_ids = list(map(str.strip, self.video_ids))
155
+
156
+ if any(len(video_id) != 11 for video_id in self.video_ids):
157
+ raise InferenceException('Invalid video IDs (length not 11)')
158
+
159
+ if self.channel_id is not None:
160
+ start = self.start_index or 0
161
+ end = None if self.max_videos is None else start + self.max_videos
162
+
163
+ channel_video_ids = list(itertools.islice(get_all_channel_vids(
164
+ self.channel_id), start, end))
165
+ logger.info(
166
+ f'Found {len(channel_video_ids)} for channel {self.channel_id}')
167
+
168
+ self.video_ids += channel_video_ids
169
+
170
+
171
+
172
+ def get_model_tokenizer_classifier(inference_args: InferenceArguments, general_args: GeneralArguments):
173
+
174
+ original_path = inference_args.model_name_or_path
175
+
176
+ # Load main model and tokenizer
177
+ model, tokenizer = get_model_tokenizer(inference_args, general_args)
178
+
179
+ # Load classifier
180
+ inference_args.model_name_or_path = inference_args.classifier_model_name_or_path
181
+ classifier_model, classifier_tokenizer = get_model_tokenizer(
182
+ inference_args, general_args, model_type='classifier')
183
+
184
+ classifier = classify.SponsorBlockClassificationPipeline(
185
+ classifier_model, classifier_tokenizer)
186
+
187
+ # Reset to original model_name_or_path
188
+ inference_args.model_name_or_path = original_path
189
+
190
+ return model, tokenizer, classifier
191
+
192
+
193
+ def get_model_tokenizer(model_args: ModelArguments, general_args: Union[GeneralArguments, TrainingArguments] = None, config_args=None, model_type='seq2seq'):
194
+ if config_args is None:
195
+ config_args = {}
196
+
197
+ use_auth_token = True if model_args.use_auth_token else None
198
+
199
+ config = AutoConfig.from_pretrained(
200
+ model_args.model_name_or_path,
201
+ cache_dir=model_args.cache_dir,
202
+ revision=model_args.model_revision,
203
+ use_auth_token=use_auth_token,
204
+ **config_args
205
+ )
206
 
207
  tokenizer = AutoTokenizer.from_pretrained(
208
+ model_args.model_name_or_path,
209
+ cache_dir=model_args.cache_dir,
210
+ use_fast=model_args.use_fast_tokenizer,
211
+ revision=model_args.model_revision,
212
+ use_auth_token=use_auth_token,
213
+ )
214
 
215
+ model_type = AutoModelForSeq2SeqLM if model_type == 'seq2seq' else AutoModelForSequenceClassification
216
+ model = model_type.from_pretrained(
217
+ model_args.model_name_or_path,
218
+ config=config,
219
+ cache_dir=model_args.cache_dir,
220
+ revision=model_args.model_revision,
221
+ use_auth_token=use_auth_token,
222
+ )
223
+
224
+ # Add custom tokens
225
  CustomTokens.add_custom_tokens(tokenizer)
226
  model.resize_token_embeddings(len(tokenizer))
227
 
228
+ # Potentially move model to gpu
229
+ if general_args is not None and not general_args.no_cuda:
230
+ model.to('cuda' if torch.cuda.is_available() else 'cpu')
231
 
232
  return model, tokenizer
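Usage sketch (illustrative, not part of this commit): the new loading path with the defaults declared in InferenceArguments above (Xenova/sponsorblock-small for extraction plus Xenova/sponsorblock-classifier-v2 for classification); the sample text is a placeholder and the snippet assumes it is run from src/.

from model import InferenceArguments, get_model_tokenizer_classifier
from shared import GeneralArguments

inference_args = InferenceArguments()  # default seq2seq model + classifier checkpoints
general_args = GeneralArguments()      # set no_cuda to keep everything on the CPU

model, tokenizer, classifier = get_model_tokenizer_classifier(inference_args, general_args)

# The classifier is a SponsorBlockClassificationPipeline that returns a score per category.
print(classifier('Use code EXAMPLE for 20% off your first order'))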
src/moderate.py ADDED
@@ -0,0 +1,104 @@
+ import torch
+ from transformers import (
+     AutoModelForSequenceClassification,
+     AutoTokenizer,
+     HfArgumentParser
+ )
+
+ from train_classifier import ClassifierModelArguments
+ from shared import CATEGORIES, DatasetArguments
+ from tqdm import tqdm
+
+ from preprocess import get_words, clean_text
+ from segment import extract_segment
+ import os
+ import json
+ import numpy as np
+
+
+ def softmax(_outputs):
+     maxes = np.max(_outputs, axis=-1, keepdims=True)
+     shifted_exp = np.exp(_outputs - maxes)
+     return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+ def main():
+     # See all possible arguments in src/transformers/training_args.py
+     # or by passing the --help flag to this script.
+     # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+     parser = HfArgumentParser((ClassifierModelArguments, DatasetArguments))
+     model_args, dataset_args = parser.parse_args_into_dataclasses()
+
+     model = AutoModelForSequenceClassification.from_pretrained(
+         model_args.model_name_or_path)
+     tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+
+     processed_db_path = os.path.join(
+         dataset_args.data_dir, dataset_args.processed_database)
+     with open(processed_db_path) as fp:
+         data = json.load(fp)
+
+     mapped_categories = {
+         str(v).lower(): k for k, v in enumerate(CATEGORIES)
+     }
+
+     for video_id, segments in tqdm(data.items()):
+
+         words = get_words(video_id)
+
+         if not words:
+             continue  # No/empty transcript for video_id
+
+         valid_segments = []
+         texts = []
+         for segment in segments:
+             segment_words = extract_segment(
+                 words, segment['start'], segment['end'])
+             text = clean_text(' '.join(x['text'] for x in segment_words))
+
+             duration = segment['end'] - segment['start']
+             wps = len(segment_words)/duration if duration > 0 else 0
+             if wps < 1.5:
+                 continue
+
+             # Do not worry about those that are locked or have enough votes
+             if segment['locked']:  # or segment['votes'] > 5:
+                 continue
+
+             texts.append(text)
+             valid_segments.append(segment)
+
+         if not texts:
+             continue  # No valid segments
+
+         model_inputs = tokenizer(
+             texts, return_tensors='pt', padding=True, truncation=True)
+
+         with torch.no_grad():
+             model_outputs = model(**model_inputs)
+             outputs = list(map(lambda x: x.numpy(), model_outputs['logits']))
+
+         scores = softmax(outputs)
+
+         for segment, text, score in zip(valid_segments, texts, scores):
+             predicted_index = score.argmax().item()
+
+             if predicted_index == mapped_categories[segment['category']]:
+                 continue  # Ignore correct segments
+
+             a = {k: round(float(score[i]), 3)
+                  for i, k in enumerate(CATEGORIES)}
+
+             del segment['submission_time']
+             segment.update({
+                 'predicted': str(CATEGORIES[predicted_index]).lower(),
+                 'text': text,
+                 'scores': a
+             })
+
+             print(json.dumps(segment))
+
+
+ if __name__ == "__main__":
+     main()
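For illustration only (not part of this commit): the softmax helper above is the standard numerically stabilised softmax; a quick sanity check on made-up logits for the four categories [None, SPONSOR, SELFPROMO, INTERACTION].

import numpy as np

logits = np.array([[0.1, 2.3, 0.4, -1.0]])  # illustrative values only
maxes = np.max(logits, axis=-1, keepdims=True)
shifted_exp = np.exp(logits - maxes)
probs = shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
print(probs.round(3))  # each row sums to 1; here the SPONSOR column dominates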
src/predict.py CHANGED
@@ -1,17 +1,8 @@
1
- import itertools
2
- import base64
3
- import re
4
- import requests
5
- import json
6
  from transformers import HfArgumentParser
7
- from transformers.trainer_utils import get_last_checkpoint
8
  from dataclasses import dataclass, field
9
  import logging
10
- import os
11
- import itertools
12
- from utils import re_findall
13
- from shared import CustomTokens, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, OutputArguments, seconds_to_time
14
- from typing import Optional
15
  from segment import (
16
  generate_segments,
17
  extract_segment,
@@ -22,129 +13,12 @@ from segment import (
22
  SegmentationArguments
23
  )
24
  import preprocess
25
- from errors import PredictionException, TranscriptError, ModelLoadError, ClassifierLoadError
26
- from model import ModelArguments, get_classifier_vectorizer, get_model_tokenizer
27
 
28
  logging.basicConfig()
29
  logger = logging.getLogger(__name__)
30
 
31
- # Public innertube key (b64 encoded so that it is not incorrectly flagged)
32
- INNERTUBE_KEY = base64.b64decode(
33
- b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
34
-
35
- YT_CONTEXT = {
36
- 'client': {
37
- 'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
38
- 'clientName': 'WEB',
39
- 'clientVersion': '2.20211221.00.00',
40
- }
41
- }
42
- _YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
43
-
44
-
45
- def get_all_channel_vids(channel_id):
46
- continuation = None
47
- while True:
48
- if continuation is None:
49
- params = {'list': channel_id.replace('UC', 'UU', 1)}
50
- response = requests.get(
51
- 'https://www.youtube.com/playlist', params=params)
52
- items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
53
- 'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
54
- else:
55
- params = {'key': INNERTUBE_KEY}
56
- data = {
57
- 'context': YT_CONTEXT,
58
- 'continuation': continuation
59
- }
60
- response = requests.post(
61
- 'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
62
- items = response.json()[
63
- 'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
64
-
65
- new_token = None
66
- for vid in items:
67
- info = vid.get('playlistVideoRenderer')
68
- if info:
69
- yield info['videoId']
70
- continue
71
-
72
- info = vid.get('continuationItemRenderer')
73
- if info:
74
- new_token = info['continuationEndpoint']['continuationCommand']['token']
75
-
76
- if new_token is None:
77
- break
78
- continuation = new_token
79
-
80
-
81
- @dataclass
82
- class InferenceArguments:
83
-
84
- model_path: str = field(
85
- default='Xenova/sponsorblock-small',
86
- metadata={
87
- 'help': 'Path to pretrained model used for prediction'
88
- }
89
- )
90
- cache_dir: Optional[str] = ModelArguments.__dataclass_fields__['cache_dir']
91
-
92
- output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
93
- 'output_dir']
94
-
95
- max_videos: Optional[int] = field(
96
- default=None,
97
- metadata={
98
- 'help': 'The number of videos to test on'
99
- }
100
- )
101
- start_index: int = field(default=None, metadata={
102
- 'help': 'Video to start the evaluation at.'})
103
- channel_id: Optional[str] = field(
104
- default=None,
105
- metadata={
106
- 'help': 'Used to evaluate a channel'
107
- }
108
- )
109
- video_ids: str = field(
110
- default_factory=lambda: [],
111
- metadata={
112
- 'nargs': '+'
113
- }
114
- )
115
-
116
- output_as_json: bool = field(default=False, metadata={
117
- 'help': 'Output evaluations as JSON'})
118
-
119
- def __post_init__(self):
120
- # Try to load model from latest checkpoint
121
- if self.model_path is None:
122
- if os.path.exists(self.output_dir):
123
- last_checkpoint = get_last_checkpoint(self.output_dir)
124
- if last_checkpoint is not None:
125
- self.model_path = last_checkpoint
126
- else:
127
- raise ModelLoadError(
128
- 'Unable to load model from checkpoint, explicitly set `--model_path`')
129
- else:
130
- raise ModelLoadError(
131
- f'Unable to find model in {self.output_dir}, explicitly set `--model_path`')
132
-
133
- if any(len(video_id) != 11 for video_id in self.video_ids):
134
- raise PredictionException('Invalid video IDs (length not 11)')
135
-
136
- if self.channel_id is not None:
137
- start = self.start_index or 0
138
- end = None if self.max_videos is None else start + self.max_videos
139
-
140
- channel_video_ids = list(itertools.islice(get_all_channel_vids(
141
- self.channel_id), start, end))
142
- logger.info(
143
- f'Found {len(channel_video_ids)} for channel {self.channel_id}')
144
-
145
- self.video_ids += channel_video_ids
146
-
147
-
148
  @dataclass
149
  class PredictArguments(InferenceArguments):
150
  video_id: str = field(
@@ -160,10 +34,6 @@ class PredictArguments(InferenceArguments):
160
  super().__post_init__()
161
 
162
 
163
- _SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
164
- _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
165
- SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
166
-
167
  MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
168
  MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
169
 
@@ -171,70 +41,35 @@ MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
171
  START_TIME_ZERO_THRESHOLD = 0.08
172
 
173
 
174
- @dataclass(frozen=True, eq=True)
175
- class ClassifierArguments:
176
- classifier_model: Optional[str] = field(
177
- default='Xenova/sponsorblock-classifier',
178
- metadata={
179
- 'help': 'Use a pretrained classifier'
180
- }
181
- )
182
-
183
- classifier_dir: Optional[str] = field(
184
- default='classifiers',
185
- metadata={
186
- 'help': 'The directory that contains the classifier and vectorizer.'
187
- }
188
- )
189
-
190
- classifier_file: Optional[str] = field(
191
- default='classifier.pickle',
192
- metadata={
193
- 'help': 'The name of the classifier'
194
- }
195
- )
196
-
197
- vectorizer_file: Optional[str] = field(
198
- default='vectorizer.pickle',
199
- metadata={
200
- 'help': 'The name of the vectorizer'
201
- }
202
- )
203
-
204
- min_probability: float = field(
205
- default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
206
-
207
-
208
- def filter_and_add_probabilities(predictions, classifier_args):
209
  """Use classifier to filter predictions"""
210
  if not predictions:
211
  return predictions
212
 
213
- classifier, vectorizer = get_classifier_vectorizer(classifier_args)
 
214
 
215
- transformed_segments = vectorizer.transform([
216
  preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
217
  for pred in predictions
218
- ])
219
- probabilities = classifier.predict_proba(transformed_segments)
220
 
221
- # Transformer sometimes says segment is of another category, so we
222
- # update category and probabilities if classifier is confident it is another category
223
  filtered_predictions = []
224
- for prediction, probabilities in zip(predictions, probabilities):
225
- predicted_probabilities = {k: v for k,
226
- v in zip(CATEGORIES, probabilities)}
227
 
228
  # Get best category + probability
229
  classifier_category = max(
230
  predicted_probabilities, key=predicted_probabilities.get)
231
  classifier_probability = predicted_probabilities[classifier_category]
232
 
233
- if classifier_category is None and classifier_probability > classifier_args.min_probability:
234
  continue # Ignore
235
 
236
  if (prediction['category'] not in predicted_probabilities) \
237
- or (classifier_category is not None and classifier_probability > 0.5): # TODO make param
238
  # Unknown category or we are confident enough to overrule,
239
  # so change category to what was predicted by classifier
240
  prediction['category'] = classifier_category
@@ -252,7 +87,7 @@ def filter_and_add_probabilities(predictions, classifier_args):
252
  return filtered_predictions
253
 
254
 
255
- def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier_args=None):
256
  # Allow words to be passed in so that we don't have to get the words if we already have them
257
  if words is None:
258
  words = preprocess.get_words(video_id)
@@ -272,13 +107,9 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
272
  prediction['words'] = extract_segment(
273
  words, prediction['start'], prediction['end'])
274
 
275
- # TODO add back
276
- if classifier_args is not None:
277
- try:
278
- predictions = filter_and_add_probabilities(
279
- predictions, classifier_args)
280
- except ClassifierLoadError:
281
- print('Unable to load classifer')
282
 
283
  return predictions
284
 
@@ -300,9 +131,6 @@ def greedy_match(list, sublist):
300
  return best_i, best_j, best_k
301
 
302
 
303
- CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
304
-
305
-
306
  def predict_sponsor_text(text, model, tokenizer):
307
  """Given a body of text, predict the words which are part of the sponsor"""
308
  model_device = next(model.parameters()).device
@@ -322,11 +150,7 @@ def predict_sponsor_text(text, model, tokenizer):
322
 
323
  def predict_sponsor_matches(text, model, tokenizer):
324
  sponsorship_text = predict_sponsor_text(text, model, tokenizer)
325
-
326
- if CustomTokens.NO_SEGMENT.value in sponsorship_text:
327
- return []
328
-
329
- return re_findall(SEGMENT_MATCH_RE, sponsorship_text)
330
 
331
 
332
  def segments_to_predictions(segments, model, tokenizer):
@@ -400,24 +224,23 @@ def main():
400
  hf_parser = HfArgumentParser((
401
  PredictArguments,
402
  SegmentationArguments,
403
- ClassifierArguments,
404
  GeneralArguments
405
  ))
406
- predict_args, segmentation_args, classifier_args, general_args = hf_parser.parse_args_into_dataclasses()
407
 
408
  if not predict_args.video_ids:
409
  logger.error(
410
  'No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
411
  return
412
 
413
- model, tokenizer = get_model_tokenizer(
414
- predict_args.model_path, predict_args.cache_dir, general_args.no_cuda)
415
 
416
  for video_id in predict_args.video_ids:
417
- video_id = video_id.strip()
418
  try:
419
- predictions = predict(video_id, model, tokenizer,
420
- segmentation_args, classifier_args=classifier_args)
 
421
  except TranscriptError:
422
  logger.warning(f'No transcript available for {video_id}')
423
  continue
 
1
+
 
 
 
 
2
  from transformers import HfArgumentParser
 
3
  from dataclasses import dataclass, field
4
  import logging
5
+ from shared import CustomTokens, extract_sponsor_matches, GeneralArguments, seconds_to_time
6
  from segment import (
7
  generate_segments,
8
  extract_segment,
 
13
  SegmentationArguments
14
  )
15
  import preprocess
16
+ from errors import TranscriptError
17
+ from model import get_model_tokenizer_classifier, InferenceArguments
18
 
19
  logging.basicConfig()
20
  logger = logging.getLogger(__name__)
21
 
22
  @dataclass
23
  class PredictArguments(InferenceArguments):
24
  video_id: str = field(
 
34
  super().__post_init__()
35
 
36
 
 
 
 
 
37
  MATCH_WINDOW = 25 # Increase for accuracy, but takes longer: O(n^3)
38
  MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
39
 
 
41
  START_TIME_ZERO_THRESHOLD = 0.08
42
 
43
 
44
+ def filter_and_add_probabilities(predictions, classifier, min_probability):
45
  """Use classifier to filter predictions"""
46
  if not predictions:
47
  return predictions
48
 
49
+ # We update the predicted category from the extractive transformer
50
+ # if the classifier is confident enough it is another category
51
 
52
+ texts = [
53
  preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
54
  for pred in predictions
55
+ ]
56
+ classifications = classifier(texts)
57
 
 
 
58
  filtered_predictions = []
59
+ for prediction, probabilities in zip(predictions, classifications):
60
+ predicted_probabilities = {
61
+ p['label'].lower(): p['score'] for p in probabilities}
62
 
63
  # Get best category + probability
64
  classifier_category = max(
65
  predicted_probabilities, key=predicted_probabilities.get)
66
  classifier_probability = predicted_probabilities[classifier_category]
67
 
68
+ if classifier_category == 'none' and classifier_probability > min_probability:
69
  continue # Ignore
70
 
71
  if (prediction['category'] not in predicted_probabilities) \
72
+ or (classifier_category != 'none' and classifier_probability > 0.5): # TODO make param
73
  # Unknown category or we are confident enough to overrule,
74
  # so change category to what was predicted by classifier
75
  prediction['category'] = classifier_category
 
87
  return filtered_predictions
88
 
89
 
90
+ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier=None, min_probability=None):
91
  # Allow words to be passed in so that we don't have to get the words if we already have them
92
  if words is None:
93
  words = preprocess.get_words(video_id)
 
107
  prediction['words'] = extract_segment(
108
  words, prediction['start'], prediction['end'])
109
 
110
+ if classifier is not None:
111
+ predictions = filter_and_add_probabilities(
112
+ predictions, classifier, min_probability)
113
 
114
  return predictions
115
 
 
131
  return best_i, best_j, best_k
132
 
133
 
 
 
 
134
  def predict_sponsor_text(text, model, tokenizer):
135
  """Given a body of text, predict the words which are part of the sponsor"""
136
  model_device = next(model.parameters()).device
 
150
 
151
  def predict_sponsor_matches(text, model, tokenizer):
152
  sponsorship_text = predict_sponsor_text(text, model, tokenizer)
153
+ return extract_sponsor_matches(sponsorship_text)
 
 
 
 
154
 
155
 
156
  def segments_to_predictions(segments, model, tokenizer):
 
224
  hf_parser = HfArgumentParser((
225
  PredictArguments,
226
  SegmentationArguments,
 
227
  GeneralArguments
228
  ))
229
+ predict_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
230
 
231
  if not predict_args.video_ids:
232
  logger.error(
233
  'No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
234
  return
235
 
236
+ model, tokenizer, classifier = get_model_tokenizer_classifier(
237
+ predict_args, general_args)
238
 
239
  for video_id in predict_args.video_ids:
 
240
  try:
241
+ predictions = predict(video_id, model, tokenizer, segmentation_args,
242
+ classifier=classifier,
243
+ min_probability=predict_args.min_probability)
244
  except TranscriptError:
245
  logger.warning(f'No transcript available for {video_id}')
246
  continue
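Usage sketch (illustrative, not part of this commit): the new prediction flow, run from src/ with default arguments; the video ID is a placeholder and, as in main() above, the classifier re-scores and filters the seq2seq predictions.

from transformers import HfArgumentParser

from model import get_model_tokenizer_classifier, InferenceArguments
from segment import SegmentationArguments
from shared import GeneralArguments
from predict import predict

inference_args, segmentation_args, general_args = HfArgumentParser(
    (InferenceArguments, SegmentationArguments, GeneralArguments)
).parse_args_into_dataclasses(args=[])

model, tokenizer, classifier = get_model_tokenizer_classifier(inference_args, general_args)

predictions = predict('dQw4w9WgXcQ',  # placeholder video ID
                      model, tokenizer, segmentation_args,
                      classifier=classifier,
                      min_probability=inference_args.min_probability)
for prediction in predictions:
    print(prediction['category'], prediction['start'], prediction['end'])

Equivalently, the script itself can be run as python src/predict.py --video_ids <video_id>.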
src/preprocess.py CHANGED
@@ -1,14 +1,15 @@
 
1
  from utils import jaccard
2
  from functools import lru_cache
3
  from datetime import datetime
4
  import itertools
5
- from typing import Optional, List
6
- from model import ModelArguments
7
  import segment
8
  from tqdm import tqdm
9
  from dataclasses import dataclass, field
10
  from transformers import HfArgumentParser
11
- from shared import ACTION_OPTIONS, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
12
  import csv
13
  import re
14
  import random
@@ -213,9 +214,10 @@ def get_words(video_id, process=True, transcript_type='auto', fallback='manual',
213
  else:
214
  ts = transcript_list.find_generated_transcript(
215
  LANGUAGE_PREFERENCE_LIST)
216
-
217
- raw_transcript_json = ts._http_client.get(
218
- f'{ts._url}&fmt=json3').json()
 
219
 
220
  except (TooManyRequests, YouTubeRequestFailed):
221
  raise # Cannot recover from these errors and do not mark as empty transcript
@@ -386,9 +388,14 @@ class PreprocessArguments:
386
 
387
  max_date: str = field(
388
  # default='01/01/9999', # Include all
389
- default='02/02/2022',
390
  metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})
391
 
 
 
 
 
 
392
  keep_duplicate_segments: bool = field(
393
  default=False, metadata={'help': 'Keep duplicate segments'}
394
  )
@@ -482,25 +489,7 @@ def download_file(url, filename):
482
 
483
 
484
  @dataclass
485
- class DatasetArguments:
486
- data_dir: Optional[str] = field(
487
- default='data',
488
- metadata={
489
- 'help': 'The directory which stores train, test and/or validation data.'
490
- },
491
- )
492
- processed_file: Optional[str] = field(
493
- default='segments.json',
494
- metadata={
495
- 'help': 'Processed data file'
496
- },
497
- )
498
- processed_database: Optional[str] = field(
499
- default='processed_database.json',
500
- metadata={
501
- 'help': 'Processed database file'
502
- },
503
- )
504
 
505
  train_file: Optional[str] = field(
506
  default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
@@ -508,21 +497,38 @@ class DatasetArguments:
508
  validation_file: Optional[str] = field(
509
  default='valid.json',
510
  metadata={
511
- 'help': 'An optional input evaluation data file to evaluate the metrics (rouge) on (a jsonlines file).'
512
  },
513
  )
514
  test_file: Optional[str] = field(
515
  default='test.json',
516
  metadata={
517
- 'help': 'An optional input test data file to evaluate the metrics (rouge) on (a jsonlines file).'
518
  },
519
  )
520
- excess_file: Optional[str] = field(
521
- default='excess.json',
 
 
 
 
522
  metadata={
523
- 'help': 'The excess segments left after the split'
524
  },
525
  )
 
526
  dataset_cache_dir: Optional[str] = field(
527
  default=None,
528
  metadata={
@@ -555,9 +561,9 @@ def main():
555
  # Generate final.json from sponsorTimes.csv
556
  hf_parser = HfArgumentParser((
557
  PreprocessArguments,
558
- DatasetArguments,
559
  segment.SegmentationArguments,
560
- ModelArguments,
561
  GeneralArguments
562
  ))
563
  preprocess_args, dataset_args, segmentation_args, model_args, general_args = hf_parser.parse_args_into_dataclasses()
@@ -821,8 +827,7 @@ def main():
821
  # , max_videos, max_segments
822
 
823
  from model import get_model_tokenizer
824
- model, tokenizer = get_model_tokenizer(
825
- model_args.model_name_or_path, model_args.cache_dir, general_args.no_cuda)
826
 
827
  # TODO
828
  # count_videos = 0
@@ -871,8 +876,9 @@ def main():
871
  continue
872
 
873
  d = {
874
- 'video_index': offset + start_index,
875
  'video_id': video_id,
 
876
  'text': ' '.join(x['cleaned'] for x in seg),
877
  'start': seg_start,
878
  'end': seg_end,
@@ -919,7 +925,7 @@ def main():
919
  z = int(preprocess_args.percentage_positive /
920
  percentage_negative * len(non_sponsors))
921
 
922
- excess = sponsors[z:]
923
  sponsors = sponsors[:z]
924
 
925
  else:
@@ -927,7 +933,7 @@ def main():
927
  z = int(percentage_negative /
928
  preprocess_args.percentage_positive * len(sponsors))
929
 
930
- excess = non_sponsors[z:]
931
  non_sponsors = non_sponsors[:z]
932
 
933
  logger.info('Join')
@@ -935,6 +941,7 @@ def main():
935
 
936
  random.shuffle(all_labelled_segments)
937
 
 
938
  logger.info('Split')
939
  ratios = [preprocess_args.train_split,
940
  preprocess_args.test_split,
@@ -958,15 +965,52 @@ def main():
958
  else:
959
  logger.info(f'Skipping {name}')
960
961
  logger.info('Write')
962
  # Save excess items
963
- excess_path = os.path.join(
964
- dataset_args.data_dir, dataset_args.excess_file)
965
- if not os.path.exists(excess_path) or preprocess_args.overwrite:
966
- with open(excess_path, 'w', encoding='utf-8') as fp:
967
- fp.writelines(excess)
968
- else:
969
- logger.info(f'Skipping {dataset_args.excess_file}')
970
 
971
  logger.info(
972
  f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')
 
1
+ from shared import DatasetArguments
2
  from utils import jaccard
3
  from functools import lru_cache
4
  from datetime import datetime
5
  import itertools
6
+ from typing import Optional
7
+ import model as model_module
8
  import segment
9
  from tqdm import tqdm
10
  from dataclasses import dataclass, field
11
  from transformers import HfArgumentParser
12
+ from shared import extract_sponsor_matches, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
13
  import csv
14
  import re
15
  import random
 
214
  else:
215
  ts = transcript_list.find_generated_transcript(
216
  LANGUAGE_PREFERENCE_LIST)
217
+ raw_transcript = ts._http_client.get(
218
+ f'{ts._url}&fmt=json3').content
219
+ if raw_transcript:
220
+ raw_transcript_json = json.loads(raw_transcript)
221
 
222
  except (TooManyRequests, YouTubeRequestFailed):
223
  raise # Cannot recover from these errors and do not mark as empty transcript
 
388
 
389
  max_date: str = field(
390
  # default='01/01/9999', # Include all
391
+ default='01/03/2022',
392
  metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})
393
 
394
+ # max_unseen_date: str = field( # TODO
395
+ # default='02/03/2022',
396
+ # metadata={'help': 'Generate test and validation data from `max_date` to `max_unseen_date`'})
397
+ # Specify min/max video id for splitting (seen vs. unseen)
398
+
399
  keep_duplicate_segments: bool = field(
400
  default=False, metadata={'help': 'Keep duplicate segments'}
401
  )
 
489
 
490
 
491
  @dataclass
492
+ class PreprocessingDatasetArguments(DatasetArguments):
493
 
494
  train_file: Optional[str] = field(
495
  default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
 
497
  validation_file: Optional[str] = field(
498
  default='valid.json',
499
  metadata={
500
+ 'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
501
  },
502
  )
503
  test_file: Optional[str] = field(
504
  default='test.json',
505
  metadata={
506
+ 'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
507
  },
508
  )
509
+
510
+ c_train_file: Optional[str] = field(
511
+ default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
512
+ )
513
+ c_validation_file: Optional[str] = field(
514
+ default='c_valid.json',
515
  metadata={
516
+ 'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
517
  },
518
  )
519
+ c_test_file: Optional[str] = field(
520
+ default='c_test.json',
521
+ metadata={
522
+ 'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
523
+ },
524
+ )
525
+
526
+ # excess_file: Optional[str] = field(
527
+ # default='excess.json',
528
+ # metadata={
529
+ # 'help': 'The excess segments left after the split'
530
+ # },
531
+ # )
532
  dataset_cache_dir: Optional[str] = field(
533
  default=None,
534
  metadata={
 
561
  # Generate final.json from sponsorTimes.csv
562
  hf_parser = HfArgumentParser((
563
  PreprocessArguments,
564
+ PreprocessingDatasetArguments,
565
  segment.SegmentationArguments,
566
+ model_module.ModelArguments,
567
  GeneralArguments
568
  ))
569
  preprocess_args, dataset_args, segmentation_args, model_args, general_args = hf_parser.parse_args_into_dataclasses()
 
827
  # , max_videos, max_segments
828
 
829
  from model import get_model_tokenizer
830
+ model, tokenizer = get_model_tokenizer(model_args, general_args)
 
831
 
832
  # TODO
833
  # count_videos = 0
 
876
  continue
877
 
878
  d = {
879
+ # 'video_index': offset + start_index,
880
  'video_id': video_id,
881
+ # 'uuid': video_id, # TODO add uuid
882
  'text': ' '.join(x['cleaned'] for x in seg),
883
  'start': seg_start,
884
  'end': seg_end,
 
925
  z = int(preprocess_args.percentage_positive /
926
  percentage_negative * len(non_sponsors))
927
 
928
+ # excess = sponsors[z:]
929
  sponsors = sponsors[:z]
930
 
931
  else:
 
933
  z = int(percentage_negative /
934
  preprocess_args.percentage_positive * len(sponsors))
935
 
936
+ # excess = non_sponsors[z:]
937
  non_sponsors = non_sponsors[:z]
938
 
939
  logger.info('Join')
 
941
 
942
  random.shuffle(all_labelled_segments)
943
 
944
+ # TODO split based on video ids
945
  logger.info('Split')
946
  ratios = [preprocess_args.train_split,
947
  preprocess_args.test_split,
 
965
  else:
966
  logger.info(f'Skipping {name}')
967
 
968
+ classifier_splits = {
969
+ dataset_args.c_train_file: train_data,
970
+ dataset_args.c_test_file: test_data,
971
+ dataset_args.c_validation_file: valid_data
972
+ }
973
+
974
+ none_category = CATEGORIES.index(None)
975
+
976
+ # Output training, testing and validation data
977
+ for name, items in classifier_splits.items():
978
+ outfile = os.path.join(dataset_args.data_dir, name)
979
+ if not os.path.exists(outfile) or preprocess_args.overwrite:
980
+ with open(outfile, 'w', encoding='utf-8') as fp:
981
+ for i in items:
982
+ x = json.loads(i) # TODO add uuid
983
+ labelled_items = []
984
+
985
+ matches = extract_sponsor_matches(x['extracted'])
986
+
987
+ if x['extracted'] == CustomTokens.NO_SEGMENT.value:
988
+ labelled_items.append({
989
+ 'text': x['text'],
990
+ 'label': none_category
991
+ })
992
+ else:
993
+ for match in matches:
994
+ labelled_items.append({
995
+ 'text': match['text'],
996
+ 'label': CATEGORIES.index(match['category'])
997
+ })
998
+
999
+ for labelled_item in labelled_items:
1000
+ print(json.dumps(labelled_item), file=fp)
1001
+
1002
+ else:
1003
+ logger.info(f'Skipping {name}')
1004
+
1005
  logger.info('Write')
1006
  # Save excess items
1007
+ # excess_path = os.path.join(
1008
+ # dataset_args.data_dir, dataset_args.excess_file)
1009
+ # if not os.path.exists(excess_path) or preprocess_args.overwrite:
1010
+ # with open(excess_path, 'w', encoding='utf-8') as fp:
1011
+ # fp.writelines(excess)
1012
+ # else:
1013
+ # logger.info(f'Skipping {dataset_args.excess_file}')
1014
 
1015
  logger.info(
1016
  f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')
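For illustration only (not part of this commit): the new c_train.json / c_valid.json / c_test.json splits are JSON-lines files of {'text', 'label'} records, with label indexing into shared.CATEGORIES; a sketch of reading one back, assuming the default data/ directory.

import json

CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']  # as defined in src/shared.py

with open('data/c_train.json', encoding='utf-8') as fp:
    for line in fp:
        item = json.loads(line)
        print(CATEGORIES[item['label']], item['text'][:60])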
src/segment.py CHANGED
@@ -121,6 +121,8 @@ def generate_segments(words, tokenizer, segmentation_args):
 
  def extract_segment(words, start, end, map_function=None):
      """Extracts all words with time in [start, end]"""
+     if words is None:
+         words = []
 
      a = max(binary_search_below(words, 0, len(words), start), 0)
      b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
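For illustration only (not part of this commit): the added guard means callers no longer have to special-case a missing transcript before slicing it.

from segment import extract_segment

# words=None is now treated like an empty transcript instead of raising on len(None)
print(extract_segment(None, 0.0, 10.0))
print(extract_segment([], 0.0, 10.0))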
src/shared.py CHANGED
@@ -1,3 +1,10 @@
1
  import re
2
  import gc
3
  from time import time_ns
@@ -8,6 +15,8 @@ from typing import Optional
8
  from dataclasses import dataclass, field
9
  from enum import Enum
10
 
 
 
11
  ACTION_OPTIONS = ['skip', 'mute', 'full']
12
 
13
  CATGEGORY_OPTIONS = {
@@ -62,6 +71,47 @@ class CustomTokens(Enum):
62
  tokenizer.add_tokens(cls.custom_tokens())
63
 
64
 
65
  @dataclass
66
  class OutputArguments:
67
 
@@ -126,3 +176,106 @@ def reset():
126
  torch.cuda.empty_cache()
127
  gc.collect()
128
  print(torch.cuda.memory_summary(device=None, abbreviated=False))
1
+ from transformers.trainer_utils import get_last_checkpoint as glc
2
+ from transformers import TrainingArguments
3
+ import os
4
+ from utils import re_findall
5
+ import logging
6
+ import sys
7
+ from datasets import load_dataset
8
  import re
9
  import gc
10
  from time import time_ns
 
15
  from dataclasses import dataclass, field
16
  from enum import Enum
17
 
18
+ CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
19
+
20
  ACTION_OPTIONS = ['skip', 'mute', 'full']
21
 
22
  CATGEGORY_OPTIONS = {
 
71
  tokenizer.add_tokens(cls.custom_tokens())
72
 
73
 
74
+ _SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
75
+ _SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
76
+ SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
77
+
78
+
79
+ def extract_sponsor_matches(text):
80
+ if CustomTokens.NO_SEGMENT.value in text:
81
+ return []
82
+
83
+ return re_findall(SEGMENT_MATCH_RE, text)
84
+
85
+
86
+ @dataclass
87
+ class DatasetArguments:
88
+ data_dir: Optional[str] = field(
89
+ default='data',
90
+ metadata={
91
+ 'help': 'The directory which stores train, test and/or validation data.'
92
+ },
93
+ )
94
+ processed_file: Optional[str] = field(
95
+ default='segments.json',
96
+ metadata={
97
+ 'help': 'Processed data file'
98
+ },
99
+ )
100
+ processed_database: Optional[str] = field(
101
+ default='processed_database.json',
102
+ metadata={
103
+ 'help': 'Processed database file'
104
+ },
105
+ )
106
+
107
+ dataset_cache_dir: Optional[str] = field(
108
+ default=None,
109
+ metadata={
110
+ 'help': 'Where to store the cached datasets'
111
+ },
112
+ )
113
+
114
+
115
  @dataclass
116
  class OutputArguments:
117
 
 
176
  torch.cuda.empty_cache()
177
  gc.collect()
178
  print(torch.cuda.memory_summary(device=None, abbreviated=False))
179
+
180
+
181
+ def load_datasets(dataset_args):
182
+
183
+ print('Reading datasets')
184
+ data_files = {}
185
+
186
+ if dataset_args.train_file is not None:
187
+ data_files['train'] = os.path.join(
188
+ dataset_args.data_dir, dataset_args.train_file)
189
+ if dataset_args.validation_file is not None:
190
+ data_files['validation'] = os.path.join(
191
+ dataset_args.data_dir, dataset_args.validation_file)
192
+ if dataset_args.test_file is not None:
193
+ data_files['test'] = os.path.join(
194
+ dataset_args.data_dir, dataset_args.test_file)
195
+
196
+ return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
197
+
198
+
199
+ @dataclass
200
+ class CustomTrainingArguments(OutputArguments, TrainingArguments):
201
+ seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
202
+
203
+ num_train_epochs: float = field(
204
+ default=1, metadata={'help': 'Total number of training epochs to perform.'})
205
+
206
+ save_steps: int = field(default=5000, metadata={
207
+ 'help': 'Save checkpoint every X updates steps.'})
208
+ eval_steps: int = field(default=5000, metadata={
209
+ 'help': 'Run an evaluation every X steps.'})
210
+ logging_steps: int = field(default=5000, metadata={
211
+ 'help': 'Log every X updates steps.'})
212
+
213
+ # do_eval: bool = field(default=False, metadata={
214
+ # 'help': 'Whether to run eval on the dev set.'})
215
+ # do_predict: bool = field(default=False, metadata={
216
+ # 'help': 'Whether to run predictions on the test set.'})
217
+
218
+ per_device_train_batch_size: int = field(
219
+ default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
220
+ )
221
+ per_device_eval_batch_size: int = field(
222
+ default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
223
+ )
224
+
225
+ # report_to: Optional[List[str]] = field(
226
+ # default=None, metadata={"help": "The list of integrations to report the results and logs to."}
227
+ # )
228
+ evaluation_strategy: str = field(
229
+ default='steps',
230
+ metadata={
231
+ 'help': 'The evaluation strategy to use.',
232
+ 'choices': ['no', 'steps', 'epoch']
233
+ },
234
+ )
235
+
236
+ # evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
237
+ # The evaluation strategy to adopt during training. Possible values are:
238
+
239
+ # * :obj:`"no"`: No evaluation is done during training.
240
+ # * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
241
+ # * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
242
+
243
+
244
+ logging.basicConfig()
245
+ logger = logging.getLogger(__name__)
246
+
247
+ # Setup logging
248
+ logging.basicConfig(
249
+ format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
250
+ datefmt='%m/%d/%Y %H:%M:%S',
251
+ handlers=[logging.StreamHandler(sys.stdout)],
252
+ )
253
+
254
+
255
+ def get_last_checkpoint(training_args):
256
+ last_checkpoint = None
257
+ if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
258
+ last_checkpoint = glc(training_args.output_dir)
259
+ if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
260
+ raise ValueError(
261
+ f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
262
+ )
263
+ elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
264
+ logger.info(
265
+ f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
266
+ )
267
+ return last_checkpoint
268
+
269
+
270
+ def train_from_checkpoint(trainer, last_checkpoint, training_args):
271
+ checkpoint = None
272
+ if training_args.resume_from_checkpoint is not None:
273
+ checkpoint = training_args.resume_from_checkpoint
274
+ elif last_checkpoint is not None:
275
+ checkpoint = last_checkpoint
276
+
277
+ train_result = trainer.train(resume_from_checkpoint=checkpoint)
278
+
279
+ trainer.save_model() # Saves the tokenizer too for easy upload
280
+
281
+ return train_result
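Usage sketch (illustrative, not part of this commit): load_datasets simply maps the dataset-argument file fields onto datasets.load_dataset; the example below uses the preprocessing defaults (train.json / valid.json / test.json under data/) via PreprocessingDatasetArguments from src/preprocess.py.

from preprocess import PreprocessingDatasetArguments
from shared import load_datasets

dataset_args = PreprocessingDatasetArguments()  # default file names and data/ directory
raw_datasets = load_datasets(dataset_args)      # DatasetDict with train/validation/test splits
print({split: len(ds) for split, ds in raw_datasets.items()})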
src/train.py CHANGED
@@ -1,7 +1,5 @@
1
- from datasets import load_dataset
2
- from preprocess import DatasetArguments
3
- from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
4
- from shared import CustomTokens, GeneralArguments, OutputArguments
5
  from model import ModelArguments
6
  import transformers
7
  import logging
@@ -9,21 +7,15 @@ import os
9
  import sys
10
  from dataclasses import dataclass, field
11
  from typing import Optional
12
- import datasets
13
- import pickle
14
  from transformers import (
15
  DataCollatorForSeq2Seq,
16
  HfArgumentParser,
17
  Seq2SeqTrainer,
18
- Seq2SeqTrainingArguments
19
  )
20
 
21
- from transformers.trainer_utils import get_last_checkpoint
22
  from transformers.utils import check_min_version
23
  from transformers.utils.versions import require_version
24
- from sklearn.linear_model import LogisticRegression
25
- from sklearn.feature_extraction.text import TfidfVectorizer
26
- from utils import re_findall
27
 
28
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
29
  check_min_version('4.13.0.dev0')
@@ -43,23 +35,6 @@ logging.basicConfig(
43
  )
44
 
45
 
46
- def load_datasets(dataset_args):
47
-
48
- print('Reading datasets')
49
- data_files = {}
50
-
51
- if dataset_args.train_file is not None:
52
- data_files['train'] = os.path.join(
53
- dataset_args.data_dir, dataset_args.train_file)
54
- if dataset_args.validation_file is not None:
55
- data_files['validation'] = os.path.join(
56
- dataset_args.data_dir, dataset_args.validation_file)
57
- if dataset_args.test_file is not None:
58
- data_files['test'] = os.path.join(
59
- dataset_args.data_dir, dataset_args.test_file)
60
-
61
- return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
62
-
63
 
64
  @dataclass
65
  class DataTrainingArguments:
@@ -92,58 +67,7 @@ class DataTrainingArguments:
92
  )
93
 
94
 
95
- @dataclass
96
- class SequenceTrainingArguments(OutputArguments, Seq2SeqTrainingArguments):
97
- seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
98
-
99
- num_train_epochs: float = field(
100
- default=1, metadata={'help': 'Total number of training epochs to perform.'})
101
-
102
- save_steps: int = field(default=5000, metadata={
103
- 'help': 'Save checkpoint every X updates steps.'})
104
- eval_steps: int = field(default=5000, metadata={
105
- 'help': 'Run an evaluation every X steps.'})
106
- logging_steps: int = field(default=5000, metadata={
107
- 'help': 'Log every X updates steps.'})
108
-
109
- skip_train_transformer: bool = field(default=False, metadata={
110
- 'help': 'Whether to skip training the transformer.'})
111
- train_classifier: bool = field(default=False, metadata={
112
- 'help': 'Whether to run training on the 2nd phase (classifier).'})
113
-
114
- # do_eval: bool = field(default=False, metadata={
115
- # 'help': 'Whether to run eval on the dev set.'})
116
- do_predict: bool = field(default=False, metadata={
117
- 'help': 'Whether to run predictions on the test set.'})
118
-
119
- per_device_train_batch_size: int = field(
120
- default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
121
- )
122
- per_device_eval_batch_size: int = field(
123
- default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
124
- )
125
-
126
- # report_to: Optional[List[str]] = field(
127
- # default=None, metadata={"help": "The list of integrations to report the results and logs to."}
128
- # )
129
- evaluation_strategy: str = field(
130
- default='steps',
131
- metadata={
132
- 'help': 'The evaluation strategy to use.',
133
- 'choices': ['no', 'steps', 'epoch']
134
- },
135
- )
136
-
137
- # evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
138
- # The evaluation strategy to adopt during training. Possible values are:
139
-
140
- # * :obj:`"no"`: No evaluation is done during training.
141
- # * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
142
- # * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
143
-
144
-
145
  def main():
146
- # reset()
147
 
148
  # See all possible arguments in src/transformers/training_args.py
149
  # or by passing the --help flag to this script.
@@ -151,16 +75,15 @@ def main():
151
 
152
  hf_parser = HfArgumentParser((
153
  ModelArguments,
154
- DatasetArguments,
155
  DataTrainingArguments,
156
- SequenceTrainingArguments,
157
- ClassifierArguments
158
  ))
159
- model_args, dataset_args, data_training_args, training_args, classifier_args = hf_parser.parse_args_into_dataclasses()
160
 
161
  log_level = training_args.get_process_log_level()
162
  logger.setLevel(log_level)
163
- datasets.utils.logging.set_verbosity(log_level)
164
  transformers.utils.logging.set_verbosity(log_level)
165
  transformers.utils.logging.enable_default_handler()
166
  transformers.utils.logging.enable_explicit_format()
@@ -199,231 +122,130 @@ def main():
199
 
200
  # In distributed training, the load_dataset function guarantees that only one local process can concurrently
201
  # download the dataset.
202
- if training_args.skip_train_transformer and not training_args.train_classifier:
203
- print('Nothing to do. Exiting')
204
- return
205
-
206
  raw_datasets = load_datasets(dataset_args)
207
  # , cache_dir=model_args.cache_dir
208
 
209
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
210
  # https://huggingface.co/docs/datasets/loading_datasets.html.
211
 
212
- if training_args.train_classifier:
213
- print('Train classifier')
214
- # 1. Vectorize raw data to pass into classifier
215
- # CountVectorizer TfidfVectorizer
216
- # TfidfVectorizer - better (comb of CountVectorizer)
217
- vectorizer = TfidfVectorizer( # CountVectorizer
218
- # lowercase=False,
219
- # stop_words='english', # TODO optimise stop words?
220
- # stop_words=stop_words,
221
-
222
- ngram_range=(1, 2), # best so far
223
- # max_features=8000 # remove for higher accuracy?
224
- max_features=20000
225
- # max_features=10000
226
- # max_features=1000
227
- )
228
 
229
- train_test_data = {
230
- 'train': {
231
- 'X': [],
232
- 'y': []
233
- },
234
- 'test': {
235
- 'X': [],
236
- 'y': []
237
- }
238
- }
239
-
240
- print('Splitting')
241
- for ds_type in train_test_data:
242
- dataset = raw_datasets[ds_type]
243
-
244
- for row in dataset:
245
- matches = re_findall(SEGMENT_MATCH_RE, row['extracted'])
246
- if matches:
247
- for match in matches:
248
- train_test_data[ds_type]['X'].append(match['text'])
249
-
250
- class_index = CATEGORIES.index(match['category'])
251
- train_test_data[ds_type]['y'].append(class_index)
252
-
253
- else:
254
- train_test_data[ds_type]['X'].append(row['text'])
255
- train_test_data[ds_type]['y'].append(0)
256
-
257
- print('Fitting')
258
- _X_train = vectorizer.fit_transform(train_test_data['train']['X'])
259
- _X_test = vectorizer.transform(train_test_data['test']['X'])
260
-
261
- y_train = train_test_data['train']['y']
262
- y_test = train_test_data['test']['y']
263
-
264
- # 2. Create classifier
265
- classifier = LogisticRegression(max_iter=2000, class_weight='balanced')
266
-
267
- # 3. Fit data
268
- print('Fit classifier')
269
- classifier.fit(_X_train, y_train)
270
-
271
- # 4. Measure accuracy
272
- accuracy = classifier.score(_X_test, y_test)
273
-
274
- print(f'[LogisticRegression] Accuracy percent:',
275
- round(accuracy*100, 3))
276
-
277
- # 5. Save classifier and vectorizer
278
- with open(os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file), 'wb') as fp:
279
- pickle.dump(classifier, fp)
280
-
281
- with open(os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file), 'wb') as fp:
282
- pickle.dump(vectorizer, fp)
283
-
284
- if not training_args.skip_train_transformer:
285
- # Detecting last checkpoint.
286
- last_checkpoint = None
287
- if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
288
- last_checkpoint = get_last_checkpoint(training_args.output_dir)
289
- if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
290
- raise ValueError(
291
- f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
292
- )
293
- elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
294
- logger.info(
295
- f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
296
- )
297
-
298
- from model import get_model_tokenizer
299
- model, tokenizer = get_model_tokenizer(
300
- model_args.model_name_or_path, model_args.cache_dir, training_args.no_cuda)
301
-
302
- # Preprocessing the datasets.
303
- # We need to tokenize inputs and targets.
304
- column_names = raw_datasets['train'].column_names
305
-
306
- prefix = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value
307
-
308
- PAD_TOKEN_REPLACE_ID = -100
309
-
310
- # https://github.com/huggingface/transformers/issues/5204
311
- def preprocess_function(examples):
312
- inputs = examples['text']
313
- targets = examples['extracted']
314
- inputs = [prefix + inp for inp in inputs]
315
- model_inputs = tokenizer(inputs, truncation=True)
316
-
317
- # Setup the tokenizer for targets
318
- with tokenizer.as_target_tokenizer():
319
- labels = tokenizer(targets, truncation=True)
320
-
321
- # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
322
- # when we want to ignore padding in the loss.
323
-
324
- model_inputs['labels'] = [
325
- [(l if l != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID)
326
- for l in label]
327
- for label in labels['input_ids']
328
- ]
329
-
330
- return model_inputs
331
-
332
- def prepare_dataset(dataset, desc):
333
- return dataset.map(
334
- preprocess_function,
335
- batched=True,
336
- num_proc=data_training_args.preprocessing_num_workers,
337
- remove_columns=column_names,
338
- load_from_cache_file=not dataset_args.overwrite_cache,
339
- desc=desc, # tokenizing train dataset
340
- )
341
- # train_dataset # TODO shuffle?
342
-
343
- # if training_args.do_train:
344
- if 'train' not in raw_datasets: # TODO do checks above?
345
- raise ValueError('Train dataset missing')
346
- train_dataset = raw_datasets['train']
347
- if data_training_args.max_train_samples is not None:
348
- train_dataset = train_dataset.select(
349
- range(data_training_args.max_train_samples))
350
- with training_args.main_process_first(desc='train dataset map pre-processing'):
351
- train_dataset = prepare_dataset(
352
- train_dataset, desc='Running tokenizer on train dataset')
353
-
354
- if 'validation' not in raw_datasets:
355
- raise ValueError('Validation dataset missing')
356
- eval_dataset = raw_datasets['validation']
357
- if data_training_args.max_eval_samples is not None:
358
- eval_dataset = eval_dataset.select(
359
- range(data_training_args.max_eval_samples))
360
- with training_args.main_process_first(desc='validation dataset map pre-processing'):
361
- eval_dataset = prepare_dataset(
362
- eval_dataset, desc='Running tokenizer on validation dataset')
363
-
364
- if 'test' not in raw_datasets:
365
- raise ValueError('Test dataset missing')
366
- predict_dataset = raw_datasets['test']
367
- if data_training_args.max_predict_samples is not None:
368
- predict_dataset = predict_dataset.select(
369
- range(data_training_args.max_predict_samples))
370
- with training_args.main_process_first(desc='prediction dataset map pre-processing'):
371
- predict_dataset = prepare_dataset(
372
- predict_dataset, desc='Running tokenizer on prediction dataset')
373
-
374
- # Data collator
375
- data_collator = DataCollatorForSeq2Seq(
376
- tokenizer,
377
- model=model,
378
- label_pad_token_id=PAD_TOKEN_REPLACE_ID,
379
- pad_to_multiple_of=8 if training_args.fp16 else None,
380
- )
381
 
382
- # Done processing datasets
383
 
384
- # Initialize our Trainer
385
- trainer = Seq2SeqTrainer(
386
- model=model,
387
- args=training_args,
388
- train_dataset=train_dataset,
389
- eval_dataset=eval_dataset,
390
- tokenizer=tokenizer,
391
- data_collator=data_collator,
392
  )
 
 
 
 
 
 
 
 
393
 
394
- # Training
395
- checkpoint = None
396
- if training_args.resume_from_checkpoint is not None:
397
- checkpoint = training_args.resume_from_checkpoint
398
- elif last_checkpoint is not None:
399
- checkpoint = last_checkpoint
400
-
401
- try:
402
- train_result = trainer.train(resume_from_checkpoint=checkpoint)
403
- trainer.save_model() # Saves the tokenizer too for easy upload
404
- except KeyboardInterrupt:
405
- # TODO add option to save model on interrupt?
406
- # print('Saving model')
407
- # trainer.save_model(os.path.join(
408
- # training_args.output_dir, 'checkpoint-latest')) # TODO use dir
409
- raise
410
-
411
- metrics = train_result.metrics
412
- max_train_samples = data_training_args.max_train_samples or len(
413
- train_dataset)
414
- metrics['train_samples'] = min(max_train_samples, len(train_dataset))
415
-
416
- trainer.log_metrics('train', metrics)
417
- trainer.save_metrics('train', metrics)
418
- trainer.save_state()
419
-
420
- kwargs = {'finetuned_from': model_args.model_name_or_path,
421
- 'tasks': 'summarization'}
422
-
423
- if training_args.push_to_hub:
424
- trainer.push_to_hub(**kwargs)
425
- else:
426
- trainer.create_model_card(**kwargs)
427
 
428
 
429
  if __name__ == '__main__':
 
1
+ from preprocess import PreprocessingDatasetArguments
2
+ from shared import CustomTokens, load_datasets, CustomTrainingArguments, get_last_checkpoint, train_from_checkpoint
 
 
3
  from model import ModelArguments
4
  import transformers
5
  import logging
 
7
  import sys
8
  from dataclasses import dataclass, field
9
  from typing import Optional
10
+ from datasets import utils as d_utils
 
11
  from transformers import (
12
  DataCollatorForSeq2Seq,
13
  HfArgumentParser,
14
  Seq2SeqTrainer,
 
15
  )
16
 
 
17
  from transformers.utils import check_min_version
18
  from transformers.utils.versions import require_version
 
 
 
19
 
20
  # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
21
  check_min_version('4.13.0.dev0')
 
35
  )
36
 
37
 
 
 
 
 
 
 
 
38
 
39
  @dataclass
40
  class DataTrainingArguments:
 
67
  )
68
 
69
 
 
 
 
 
 
 
 
 
70
  def main():
 
71
 
72
  # See all possible arguments in src/transformers/training_args.py
73
  # or by passing the --help flag to this script.
 
75
 
76
  hf_parser = HfArgumentParser((
77
  ModelArguments,
78
+ PreprocessingDatasetArguments,
79
  DataTrainingArguments,
80
+ CustomTrainingArguments
 
81
  ))
82
+ model_args, dataset_args, data_training_args, training_args = hf_parser.parse_args_into_dataclasses()
83
 
84
  log_level = training_args.get_process_log_level()
85
  logger.setLevel(log_level)
86
+ d_utils.logging.set_verbosity(log_level)
87
  transformers.utils.logging.set_verbosity(log_level)
88
  transformers.utils.logging.enable_default_handler()
89
  transformers.utils.logging.enable_explicit_format()
 
122
 
123
  # In distributed training, the load_dataset function guarantees that only one local process can concurrently
124
  # download the dataset.
 
 
 
 
125
  raw_datasets = load_datasets(dataset_args)
126
  # , cache_dir=model_args.cache_dir
127
 
128
  # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
129
  # https://huggingface.co/docs/datasets/loading_datasets.html.
130
 
 
 
 
 
 
 
131
 
132
+ # Detecting last checkpoint.
133
+ last_checkpoint = get_last_checkpoint(training_args)
134
+
135
+ from model import get_model_tokenizer
136
+ model, tokenizer = get_model_tokenizer(model_args, training_args)
137
+
138
+ # Preprocessing the datasets.
139
+ # We need to tokenize inputs and targets.
140
+ column_names = raw_datasets['train'].column_names
141
+
142
+ prefix = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value
143
+
144
+ PAD_TOKEN_REPLACE_ID = -100
145
+
146
+ # https://github.com/huggingface/transformers/issues/5204
147
+ def preprocess_function(examples):
148
+ inputs = examples['text']
149
+ targets = examples['extracted']
150
+ inputs = [prefix + inp for inp in inputs]
151
+ model_inputs = tokenizer(inputs, truncation=True)
152
+
153
+ # Setup the tokenizer for targets
154
+ with tokenizer.as_target_tokenizer():
155
+ labels = tokenizer(targets, truncation=True)
156
+
157
+ # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
158
+ # when we want to ignore padding in the loss.
159
+
160
+ model_inputs['labels'] = [
161
+ [(l if l != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID)
162
+ for l in label]
163
+ for label in labels['input_ids']
164
+ ]
 
 
 
 
 
 
 
 
165
 
166
+ return model_inputs
167
 
168
+ def prepare_dataset(dataset, desc):
169
+ return dataset.map(
170
+ preprocess_function,
171
+ batched=True,
172
+ num_proc=data_training_args.preprocessing_num_workers,
173
+ remove_columns=column_names,
174
+ load_from_cache_file=not dataset_args.overwrite_cache,
175
+ desc=desc, # tokenizing train dataset
176
  )
177
+ # train_dataset # TODO shuffle?
178
+
179
+ # if training_args.do_train:
180
+ if 'train' not in raw_datasets: # TODO do checks above?
181
+ raise ValueError('Train dataset missing')
182
+ train_dataset = raw_datasets['train']
183
+ if data_training_args.max_train_samples is not None:
184
+ train_dataset = train_dataset.select(
185
+ range(data_training_args.max_train_samples))
186
+ with training_args.main_process_first(desc='train dataset map pre-processing'):
187
+ train_dataset = prepare_dataset(
188
+ train_dataset, desc='Running tokenizer on train dataset')
189
+
190
+ if 'validation' not in raw_datasets:
191
+ raise ValueError('Validation dataset missing')
192
+ eval_dataset = raw_datasets['validation']
193
+ if data_training_args.max_eval_samples is not None:
194
+ eval_dataset = eval_dataset.select(
195
+ range(data_training_args.max_eval_samples))
196
+ with training_args.main_process_first(desc='validation dataset map pre-processing'):
197
+ eval_dataset = prepare_dataset(
198
+ eval_dataset, desc='Running tokenizer on validation dataset')
199
+
200
+ if 'test' not in raw_datasets:
201
+ raise ValueError('Test dataset missing')
202
+ predict_dataset = raw_datasets['test']
203
+ if data_training_args.max_predict_samples is not None:
204
+ predict_dataset = predict_dataset.select(
205
+ range(data_training_args.max_predict_samples))
206
+ with training_args.main_process_first(desc='prediction dataset map pre-processing'):
207
+ predict_dataset = prepare_dataset(
208
+ predict_dataset, desc='Running tokenizer on prediction dataset')
209
+
210
+ # Data collator
211
+ data_collator = DataCollatorForSeq2Seq(
212
+ tokenizer,
213
+ model=model,
214
+ label_pad_token_id=PAD_TOKEN_REPLACE_ID,
215
+ pad_to_multiple_of=8 if training_args.fp16 else None,
216
+ )
217
+
218
+ # Done processing datasets
219
+
220
+ # Initialize our Trainer
221
+ trainer = Seq2SeqTrainer(
222
+ model=model,
223
+ args=training_args,
224
+ train_dataset=train_dataset,
225
+ eval_dataset=eval_dataset,
226
+ tokenizer=tokenizer,
227
+ data_collator=data_collator,
228
+ )
229
+
230
+ # Training
231
+ train_result = train_from_checkpoint(trainer, last_checkpoint, training_args)
232
+
233
+ metrics = train_result.metrics
234
+ max_train_samples = data_training_args.max_train_samples or len(
235
+ train_dataset)
236
+ metrics['train_samples'] = min(max_train_samples, len(train_dataset))
237
+
238
+ trainer.log_metrics('train', metrics)
239
+ trainer.save_metrics('train', metrics)
240
+ trainer.save_state()
241
+
242
+ kwargs = {'finetuned_from': model_args.model_name_or_path,
243
+ 'tasks': 'summarization'}
244
 
245
+ if training_args.push_to_hub:
246
+ trainer.push_to_hub(**kwargs)
247
+ else:
248
+ trainer.create_model_card(**kwargs)
 
 
 
 
 
 
 
249
 
250
 
251
  if __name__ == '__main__':
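A note on the label handling kept in preprocess_function above: any target token id equal to tokenizer.pad_token_id is replaced with PAD_TOKEN_REPLACE_ID (-100), the value the cross-entropy loss ignores, and DataCollatorForSeq2Seq is given the same label_pad_token_id so padding added at batch time is skipped as well. A small illustration with made-up token ids (0 standing in for the pad id):

# Illustration only: the token ids are invented; 0 plays the role of pad_token_id.
PAD_TOKEN_REPLACE_ID = -100
pad_token_id = 0

label_batch = [[51, 892, 17, 0, 0],
               [7, 23, 0, 0, 0]]

masked = [
    [(l if l != pad_token_id else PAD_TOKEN_REPLACE_ID) for l in label]
    for label in label_batch
]
print(masked)  # [[51, 892, 17, -100, -100], [7, 23, -100, -100, -100]]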
src/train_classifier.py ADDED
@@ -0,0 +1,287 @@
1
+
2
+ """ Finetuning the library models for sequence classification."""
3
+
4
+ import logging
5
+ import os
6
+ import random
7
+ import sys
8
+ from dataclasses import dataclass, field
9
+ from typing import Optional
10
+
11
+ import datasets
12
+ import numpy as np
13
+ from datasets import load_metric
14
+
15
+ import transformers
16
+ from transformers import (
17
+ DataCollatorWithPadding,
18
+ EvalPrediction,
19
+ HfArgumentParser,
20
+ Trainer,
21
+ default_data_collator,
22
+ set_seed,
23
+ )
24
+ from transformers.utils import check_min_version
25
+ from transformers.utils.versions import require_version
26
+ from shared import CATEGORIES, load_datasets, CustomTrainingArguments, train_from_checkpoint, get_last_checkpoint
27
+ from preprocess import PreprocessingDatasetArguments
28
+ from model import get_model_tokenizer, ModelArguments
29
+
30
+ # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
31
+ check_min_version("4.17.0")
32
+ require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")
33
+
34
+ os.environ["WANDB_DISABLED"] = "true"
35
+
36
+ logger = logging.getLogger(__name__)
37
+
38
+
39
+ @dataclass
40
+ class DataArguments:
41
+ """
42
+ Arguments pertaining to what data we are going to input our model for training and eval.
43
+
44
+ Using `HfArgumentParser` we can turn this class
45
+ into argparse arguments to be able to specify them on
46
+ the command line.
47
+ """
48
+
49
+ max_seq_length: int = field(
50
+ default=512,
51
+ metadata={
52
+ "help": "The maximum total input sequence length after tokenization. Sequences longer "
53
+ "than this will be truncated, sequences shorter will be padded."
54
+ },
55
+ )
56
+ overwrite_cache: bool = field(
57
+ default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
58
+ )
59
+ pad_to_max_length: bool = field(
60
+ default=True,
61
+ metadata={
62
+ "help": "Whether to pad all samples to `max_seq_length`. "
63
+ "If False, will pad the samples dynamically when batching to the maximum length in the batch."
64
+ },
65
+ )
66
+ max_train_samples: Optional[int] = field(
67
+ default=None,
68
+ metadata={
69
+ "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
70
+ "value if set."
71
+ },
72
+ )
73
+ max_eval_samples: Optional[int] = field(
74
+ default=None,
75
+ metadata={
76
+ "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
77
+ "value if set."
78
+ },
79
+ )
80
+ max_predict_samples: Optional[int] = field(
81
+ default=None,
82
+ metadata={
83
+ "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
84
+ "value if set."
85
+ },
86
+ )
87
+
88
+ dataset_cache_dir: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
89
+ 'dataset_cache_dir']
90
+ data_dir: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
91
+ 'data_dir']
92
+ train_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
93
+ 'c_train_file']
94
+ validation_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
95
+ 'c_validation_file']
96
+ test_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
97
+ 'c_test_file']
98
+
99
+ def __post_init__(self):
100
+ if self.train_file is None or self.validation_file is None:
101
+ raise ValueError(
102
+ "Need both a training file and a validation file.")
103
+ else:
104
+ train_extension = self.train_file.split(".")[-1]
105
+ assert train_extension in [
106
+ "csv", "json"], "`train_file` should be a csv or a json file."
107
+ validation_extension = self.validation_file.split(".")[-1]
108
+ assert (
109
+ validation_extension == train_extension
110
+ ), "`validation_file` should have the same extension (csv or json) as `train_file`."
111
+
112
+
113
+ def main():
114
+ # See all possible arguments in src/transformers/training_args.py
115
+ # or by passing the --help flag to this script.
116
+ # We now keep distinct sets of args, for a cleaner separation of concerns.
117
+
118
+ parser = HfArgumentParser(
119
+ (ModelArguments, DataArguments, CustomTrainingArguments))
120
+ model_args, data_args, training_args = parser.parse_args_into_dataclasses()
121
+
122
+ # Setup logging
123
+ logging.basicConfig(
124
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
125
+ datefmt="%m/%d/%Y %H:%M:%S",
126
+ handlers=[logging.StreamHandler(sys.stdout)],
127
+ )
128
+
129
+ log_level = training_args.get_process_log_level()
130
+ logger.setLevel(log_level)
131
+ datasets.utils.logging.set_verbosity(log_level)
132
+ transformers.utils.logging.set_verbosity(log_level)
133
+ transformers.utils.logging.enable_default_handler()
134
+ transformers.utils.logging.enable_explicit_format()
135
+
136
+ # Log on each process the small summary:
137
+ logger.warning(
138
+ f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
139
+ + f", distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
140
+ )
141
+ logger.info(f"Training/evaluation parameters {training_args}")
142
+
143
+ # Detecting last checkpoint.
144
+ last_checkpoint = get_last_checkpoint(training_args)
145
+
146
+ # Set seed before initializing model.
147
+ set_seed(training_args.seed)
148
+
149
+ # Loading a dataset from your local files.
150
+ # CSV/JSON training and evaluation files are needed.
151
+ raw_datasets = load_datasets(data_args)
152
+
153
+ # See more about loading any type of standard or custom dataset at
154
+ # https://huggingface.co/docs/datasets/loading_datasets.html.
155
+
156
+ config_args = {
157
+ 'num_labels': len(CATEGORIES),
158
+ 'id2label': {k: str(v).upper() for k, v in enumerate(CATEGORIES)},
159
+ 'label2id': {str(v).upper(): k for k, v in enumerate(CATEGORIES)}
160
+ }
161
+ model, tokenizer = get_model_tokenizer(model_args, training_args, config_args=config_args, model_type='classifier')
162
+
163
+
164
+ # Padding strategy
165
+ if data_args.pad_to_max_length:
166
+ padding = "max_length"
167
+ else:
168
+ # We will pad later, dynamically at batch creation, to the max sequence length in each batch
169
+ padding = False
170
+
171
+ if data_args.max_seq_length > tokenizer.model_max_length:
172
+ logger.warning(
173
+ f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
174
+ f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
175
+ )
176
+ max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)
177
+
178
+ def preprocess_function(examples):
179
+ # Tokenize the texts
180
+ result = tokenizer(
181
+ examples['text'], padding=padding, max_length=max_seq_length, truncation=True)
182
+ result['label'] = examples['label']
183
+ return result
184
+
185
+ with training_args.main_process_first(desc="dataset map pre-processing"):
186
+ raw_datasets = raw_datasets.map(
187
+ preprocess_function,
188
+ batched=True,
189
+ load_from_cache_file=not data_args.overwrite_cache,
190
+ desc="Running tokenizer on dataset",
191
+ )
192
+ if training_args.do_train:
193
+ if "train" not in raw_datasets:
194
+ raise ValueError("--do_train requires a train dataset")
195
+ train_dataset = raw_datasets["train"]
196
+ if data_args.max_train_samples is not None:
197
+ train_dataset = train_dataset.select(
198
+ range(data_args.max_train_samples))
199
+
200
+ if training_args.do_eval:
201
+ if "validation" not in raw_datasets:
202
+ raise ValueError("--do_eval requires a validation dataset")
203
+ eval_dataset = raw_datasets["validation"]
204
+ if data_args.max_eval_samples is not None:
205
+ eval_dataset = eval_dataset.select(
206
+ range(data_args.max_eval_samples))
207
+
208
+ if training_args.do_predict or data_args.test_file is not None:
209
+ if "test" not in raw_datasets:
210
+ raise ValueError("--do_predict requires a test dataset")
211
+ predict_dataset = raw_datasets["test"]
212
+ if data_args.max_predict_samples is not None:
213
+ predict_dataset = predict_dataset.select(
214
+ range(data_args.max_predict_samples))
215
+
216
+ # Log a few random samples from the training set:
217
+ if training_args.do_train:
218
+ for index in random.sample(range(len(train_dataset)), 3):
219
+ logger.info(
220
+ f"Sample {index} of the training set: {train_dataset[index]}.")
221
+
222
+ # Get the metric function
223
+ metric = load_metric("accuracy")
224
+
225
+ # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
226
+ # predictions and label_ids field) and has to return a dictionary string to float.
227
+ def compute_metrics(p: EvalPrediction):
228
+ preds = p.predictions[0] if isinstance(
229
+ p.predictions, tuple) else p.predictions
230
+ preds = np.argmax(preds, axis=1)
231
+ if metric is not None:  # DataArguments has no task_name field; always score with the accuracy metric
232
+ result = metric.compute(predictions=preds, references=p.label_ids)
233
+ if len(result) > 1:
234
+ result["combined_score"] = np.mean(
235
+ list(result.values())).item()
236
+ return result
237
+ else:
238
+ return {"accuracy": (preds == p.label_ids).astype(np.float32).mean().item()}
239
+
240
+ # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
241
+ # we already did the padding.
242
+ if data_args.pad_to_max_length:
243
+ data_collator = default_data_collator
244
+ elif training_args.fp16:
245
+ data_collator = DataCollatorWithPadding(
246
+ tokenizer, pad_to_multiple_of=8)
247
+ else:
248
+ data_collator = None
249
+
250
+ # Initialize our Trainer
251
+ trainer = Trainer(
252
+ model=model,
253
+ args=training_args,
254
+ train_dataset=train_dataset,
255
+ eval_dataset=eval_dataset,
256
+ compute_metrics=compute_metrics,
257
+ tokenizer=tokenizer,
258
+ data_collator=data_collator,
259
+ )
260
+
261
+ # Training
262
+ train_result = train_from_checkpoint(
263
+ trainer, last_checkpoint, training_args)
264
+
265
+ metrics = train_result.metrics
266
+ max_train_samples = (
267
+ data_args.max_train_samples if data_args.max_train_samples is not None else len(
268
+ train_dataset)
269
+ )
270
+ metrics["train_samples"] = min(max_train_samples, len(train_dataset))
271
+
272
+ trainer.save_model() # Saves the tokenizer too for easy upload
273
+
274
+ trainer.log_metrics("train", metrics)
275
+ trainer.save_metrics("train", metrics)
276
+ trainer.save_state()
277
+
278
+ kwargs = {"finetuned_from": model_args.model_name_or_path,
279
+ "tasks": "text-classification"}
280
+ if training_args.push_to_hub:
281
+ trainer.push_to_hub(**kwargs)
282
+ else:
283
+ trainer.create_model_card(**kwargs)
284
+
285
+
286
+ if __name__ == "__main__":
287
+ main()
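For reference, the config_args block in train_classifier.py is what turns the shared CATEGORIES list into the label metadata stored in the classifier's config. The category names below are purely illustrative; the real values come from CATEGORIES in shared.py, which is outside this diff:

# Illustrative values only; the real list is shared.CATEGORIES (not shown here).
CATEGORIES = ['none', 'sponsor', 'selfpromo']

config_args = {
    'num_labels': len(CATEGORIES),
    'id2label': {k: str(v).upper() for k, v in enumerate(CATEGORIES)},
    'label2id': {str(v).upper(): k for k, v in enumerate(CATEGORIES)},
}
print(config_args['id2label'])  # {0: 'NONE', 1: 'SPONSOR', 2: 'SELFPROMO'}
print(config_args['label2id'])  # {'NONE': 0, 'SPONSOR': 1, 'SELFPROMO': 2}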
src/utils.py CHANGED
@@ -1,8 +1,4 @@
1
  import re
2
- import logging
3
-
4
- logging.basicConfig()
5
- logger = logging.getLogger(__name__)
6
 
7
 
8
  def re_findall(pattern, string):
 
1
  import re
 
 
 
 
2
 
3
 
4
  def re_findall(pattern, string):
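The body of re_findall is outside this hunk. Judging from how the removed classifier code consumed it above (indexing each match by named group, e.g. match['text'] and match['category']), an implementation along these lines would fit; this is a guess, not the committed code:

import re


def re_findall(pattern, string):
    # Assumed behaviour: one dict of named groups per match.
    return [m.groupdict() for m in re.finditer(pattern, string)]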