Joshua Lochner committed
Commit · 36f7534
1 Parent(s): 15626e5

Upgrade classifier to transformer-based model

Files changed:
- src/classify.py +41 -0
- src/errors.py +2 -6
- src/evaluate.py +10 -10
- src/model.py +179 -65
- src/moderate.py +104 -0
- src/predict.py +26 -203
- src/preprocess.py +89 -45
- src/segment.py +2 -0
- src/shared.py +153 -0
- src/train.py +120 -298
- src/train_classifier.py +287 -0
- src/utils.py +0 -4
src/classify.py ADDED
@@ -0,0 +1,41 @@
+from transformers import TextClassificationPipeline
+import preprocess
+import segment
+
+
+class SponsorBlockClassificationPipeline(TextClassificationPipeline):
+    def __init__(self, model, tokenizer):
+        device = next(model.parameters()).device.index
+        super().__init__(model=model, tokenizer=tokenizer,
+                         return_all_scores=True, truncation=True, device=device)
+
+    def preprocess(self, data, **tokenizer_kwargs):
+        # TODO add support for lists
+        texts = []
+
+        if not isinstance(data, list):
+            data = [data]
+
+        for d in data:
+            if isinstance(d, dict):  # Otherwise, get data from transcript
+                words = preprocess.get_words(d['video_id'])
+                segment_words = segment.extract_segment(
+                    words, d['start'], d['end'])
+                text = preprocess.clean_text(
+                    ' '.join(x['text'] for x in segment_words))
+                texts.append(text)
+            elif isinstance(d, str):  # If string, assume this is what user wants to classify
+                texts.append(d)
+            else:
+                raise ValueError(f'Invalid input type: "{type(d)}"')
+
+        return self.tokenizer(
+            texts, return_tensors=self.framework, **tokenizer_kwargs)
+
+
+def main():
+    pass
+
+
+if __name__ == '__main__':
+    main()
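Note: the pipeline accepts either raw strings or dicts describing a transcript span. A minimal usage sketch; the checkpoint name is the default declared later in model.py, while the video ID and timestamps are made-up values:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer
    from classify import SponsorBlockClassificationPipeline

    name = 'Xenova/sponsorblock-classifier-v2'
    model = AutoModelForSequenceClassification.from_pretrained(name)
    tokenizer = AutoTokenizer.from_pretrained(name)
    pipeline = SponsorBlockClassificationPipeline(model, tokenizer)

    # Classify raw text directly...
    print(pipeline('thanks to our sponsor for supporting this video'))
    # ...or let the pipeline fetch and clean a transcript span itself.
    # (Note: device.index is None on CPU, so this sketch assumes a CUDA device.)
    print(pipeline({'video_id': 'dQw4w9WgXcQ', 'start': 12.3, 'end': 45.6}))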
src/errors.py CHANGED
@@ -1,9 +1,10 @@
+
 class SponsorBlockException(Exception):
     """Base class for all sponsor block exceptions"""
     pass
 
 
-class PredictionException(SponsorBlockException):
+class InferenceException(SponsorBlockException):
     """An exception occurred while predicting sponsor segments"""
     pass
 
@@ -21,8 +22,3 @@ class ModelError(SponsorBlockException):
 class ModelLoadError(ModelError):
     """An exception occurred while loading the model"""
     pass
-
-
-class ClassifierLoadError(ModelError):
-    """An exception occurred while loading the classifier"""
-    pass
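Note: `PredictionException` becomes `InferenceException`, and the now-unused `ClassifierLoadError` is dropped along with the pickle-based classifier. Callers catch it the same way; a small sketch:

    from errors import InferenceException
    from model import InferenceArguments

    try:
        InferenceArguments(video_ids=['too_short'])  # 9 characters, not 11
    except InferenceException as e:
        print(f'Bad input: {e}')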
src/evaluate.py CHANGED
@@ -1,10 +1,10 @@
 
-from model import …
+from model import get_model_tokenizer_classifier, InferenceArguments
 from utils import jaccard
 from transformers import HfArgumentParser
-from preprocess import …
-from shared import GeneralArguments
-from predict import …
+from preprocess import get_words
+from shared import GeneralArguments, DatasetArguments
+from predict import predict
 from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
 import pandas as pd
 from dataclasses import dataclass, field
@@ -134,11 +134,10 @@ def main():
         EvaluationArguments,
         DatasetArguments,
         SegmentationArguments,
-        ClassifierArguments,
         GeneralArguments
     ))
 
-    evaluation_args, dataset_args, segmentation_args, …
+    evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
 
     # Load labelled data:
     final_path = os.path.join(
@@ -149,8 +148,8 @@ def main():
             f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
         return
 
-    model, tokenizer = …
-        evaluation_args …
+    model, tokenizer, classifier = get_model_tokenizer_classifier(
+        evaluation_args, general_args)
 
     with open(final_path) as fp:
         final_data = json.load(fp)
@@ -187,8 +186,9 @@ def main():
             continue
 
         # Make predictions
-        predictions = predict(video_id, model, tokenizer,
-            …
+        predictions = predict(video_id, model, tokenizer, segmentation_args,
+                              classifier=classifier,
+                              min_probability=evaluation_args.min_probability)
 
         # Get labels
         sponsor_segments = final_data.get(video_id)
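Note: `HfArgumentParser` generates CLI flags from the dataclass fields, so `--video_ids`, `--channel_id`, `--max_videos` and `--min_probability` exist by construction (assuming `EvaluationArguments` extends `InferenceArguments` the same way `PredictArguments` does). A hedged invocation sketch with placeholder values:

    python src/evaluate.py --channel_id UC_SOME_CHANNEL_ID --max_videos 10 --min_probability 0.6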
src/model.py CHANGED
@@ -1,13 +1,68 @@
-from …
-from …
-from shared import CustomTokens
-from errors import ClassifierLoadError, ModelLoadError
+from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, AutoConfig, AutoModelForSequenceClassification, TrainingArguments
+from shared import CustomTokens, GeneralArguments
 from functools import lru_cache
-import pickle
-import os
 from dataclasses import dataclass, field
-from typing import Optional
+from typing import Optional, Union
 import torch
+import classify
+import base64
+import re
+import requests
+import json
+import logging
+
+logging.basicConfig()
+logger = logging.getLogger(__name__)
+
+# Public innertube key (b64 encoded so that it is not incorrectly flagged)
+INNERTUBE_KEY = base64.b64decode(
+    b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
+
+YT_CONTEXT = {
+    'client': {
+        'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
+        'clientName': 'WEB',
+        'clientVersion': '2.20211221.00.00',
+    }
+}
+_YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
+
+
+def get_all_channel_vids(channel_id):
+    continuation = None
+    while True:
+        if continuation is None:
+            params = {'list': channel_id.replace('UC', 'UU', 1)}
+            response = requests.get(
+                'https://www.youtube.com/playlist', params=params)
+            items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
+                'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
+        else:
+            params = {'key': INNERTUBE_KEY}
+            data = {
+                'context': YT_CONTEXT,
+                'continuation': continuation
+            }
+            response = requests.post(
+                'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
+            items = response.json()[
+                'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
+
+        new_token = None
+        for vid in items:
+            info = vid.get('playlistVideoRenderer')
+            if info:
+                yield info['videoId']
+                continue
+
+            info = vid.get('continuationItemRenderer')
+            if info:
+                new_token = info['continuationEndpoint']['continuationCommand']['token']
+
+        if new_token is None:
+            break
+        continuation = new_token
 
 
 @dataclass
@@ -17,34 +72,24 @@ class ModelArguments:
     """
 
     model_name_or_path: str = field(
-        default=None,
-        # default='google/t5-v1_1-small',  # t5-small
         metadata={
            'help': 'Path to pretrained model or model identifier from huggingface.co/models'
        }
    )

-    # config_name: Optional[str] = field(  # TODO remove?
-    #     default=None, metadata={'help': 'Pretrained config name or path if not the same as model_name'}
-    # )
-    # tokenizer_name: Optional[str] = field(
-    #     default=None, metadata={
-    #         'help': 'Pretrained tokenizer name or path if not the same as model_name'
-    #     }
-    # )
     cache_dir: Optional[str] = field(
         default='models',
         metadata={
             'help': 'Where to store the pretrained models downloaded from huggingface.co'
         },
     )
     use_fast_tokenizer: bool = field(
         default=True,
         metadata={
             'help': 'Whether to use one of the fast tokenizer (backed by the tokenizers library) or not.'
         },
     )
     model_revision: str = field(
         default='main',
         metadata={
             'help': 'The specific model version to use (can be a branch name, tag name or commit id).'
@@ -57,62 +102,131 @@ class ModelArguments:
             'with private models).'
         },
     )
-    <removed field; its name and help text are not legible in the diff view>
-        default=None,
-        metadata={
-            'help': …
-        }
-    )
-
-    <~39 lines removed: the old pickle/sklearn classifier and vectorizer loaders; not legible in the diff view>
+
+import itertools
+from errors import InferenceException
+
+
+@dataclass
+class InferenceArguments(ModelArguments):
+
+    model_name_or_path: str = field(
+        default='Xenova/sponsorblock-small',
+        metadata={
+            'help': 'Path to pretrained model used for prediction'
+        }
+    )
+    classifier_model_name_or_path: str = field(
+        default='Xenova/sponsorblock-classifier-v2',
+        metadata={
+            'help': 'Use a pretrained classifier'
+        }
+    )
+
+    max_videos: Optional[int] = field(
+        default=None,
+        metadata={
+            'help': 'The number of videos to test on'
+        }
+    )
+    start_index: int = field(default=None, metadata={
+        'help': 'Video to start the evaluation at.'})
+    channel_id: Optional[str] = field(
+        default=None,
+        metadata={
+            'help': 'Used to evaluate a channel'
+        }
+    )
+    video_ids: str = field(
+        default_factory=lambda: [],
+        metadata={
+            'nargs': '+'
+        }
+    )
+
+    output_as_json: bool = field(default=False, metadata={
+        'help': 'Output evaluations as JSON'})
+
+    min_probability: float = field(
+        default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
+
+    def __post_init__(self):
+
+        self.video_ids = list(map(str.strip, self.video_ids))
+
+        if any(len(video_id) != 11 for video_id in self.video_ids):
+            raise InferenceException('Invalid video IDs (length not 11)')
+
+        if self.channel_id is not None:
+            start = self.start_index or 0
+            end = None if self.max_videos is None else start + self.max_videos
+
+            channel_video_ids = list(itertools.islice(get_all_channel_vids(
+                self.channel_id), start, end))
+            logger.info(
+                f'Found {len(channel_video_ids)} for channel {self.channel_id}')
+
+            self.video_ids += channel_video_ids
+
+
+def get_model_tokenizer_classifier(inference_args: InferenceArguments, general_args: GeneralArguments):
+
+    original_path = inference_args.model_name_or_path
+
+    # Load main model and tokenizer
+    model, tokenizer = get_model_tokenizer(inference_args, general_args)
+
+    # Load classifier
+    inference_args.model_name_or_path = inference_args.classifier_model_name_or_path
+    classifier_model, classifier_tokenizer = get_model_tokenizer(
+        inference_args, general_args, model_type='classifier')
+
+    classifier = classify.SponsorBlockClassificationPipeline(
+        classifier_model, classifier_tokenizer)
+
+    # Reset to original model_name_or_path
+    inference_args.model_name_or_path = original_path
+
+    return model, tokenizer, classifier
+
+
+def get_model_tokenizer(model_args: ModelArguments, general_args: Union[GeneralArguments, TrainingArguments] = None, config_args=None, model_type='seq2seq'):
+    if config_args is None:
+        config_args = {}
+
+    use_auth_token = True if model_args.use_auth_token else None
+
+    config = AutoConfig.from_pretrained(
+        model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=use_auth_token,
+        **config_args
+    )
 
     tokenizer = AutoTokenizer.from_pretrained(
-        model_name_or_path,
+        model_args.model_name_or_path,
+        cache_dir=model_args.cache_dir,
+        use_fast=model_args.use_fast_tokenizer,
+        revision=model_args.model_revision,
+        use_auth_token=use_auth_token,
     )
 
+    model_type = AutoModelForSeq2SeqLM if model_type == 'seq2seq' else AutoModelForSequenceClassification
+    model = model_type.from_pretrained(
+        model_args.model_name_or_path,
+        config=config,
+        cache_dir=model_args.cache_dir,
+        revision=model_args.model_revision,
+        use_auth_token=use_auth_token,
+    )
+
+    # Add custom tokens
     CustomTokens.add_custom_tokens(tokenizer)
     model.resize_token_embeddings(len(tokenizer))
 
-    # …
+    # Potentially move model to gpu
+    if general_args is not None and not general_args.no_cuda:
+        model.to('cuda' if torch.cuda.is_available() else 'cpu')
 
     return model, tokenizer
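Note: `get_model_tokenizer_classifier` loads both checkpoints with the same helper by temporarily swapping `model_name_or_path` and restoring it afterwards. A usage sketch with the documented defaults (anything beyond them is an assumption):

    from model import get_model_tokenizer_classifier, InferenceArguments
    from shared import GeneralArguments

    # Defaults: Xenova/sponsorblock-small (seq2seq extractor) and
    # Xenova/sponsorblock-classifier-v2 (sequence classifier).
    model, tokenizer, classifier = get_model_tokenizer_classifier(
        InferenceArguments(), GeneralArguments())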
src/moderate.py ADDED
@@ -0,0 +1,104 @@
+import torch
+from transformers import (
+    AutoModelForSequenceClassification,
+    AutoTokenizer,
+    HfArgumentParser
+)
+
+from train_classifier import ClassifierModelArguments
+from shared import CATEGORIES, DatasetArguments
+from tqdm import tqdm
+
+from preprocess import get_words, clean_text
+from segment import extract_segment
+import os
+import json
+import numpy as np
+
+
+def softmax(_outputs):
+    maxes = np.max(_outputs, axis=-1, keepdims=True)
+    shifted_exp = np.exp(_outputs - maxes)
+    return shifted_exp / shifted_exp.sum(axis=-1, keepdims=True)
+
+
+def main():
+    # See all possible arguments in src/transformers/training_args.py
+    # or by passing the --help flag to this script.
+    # We now keep distinct sets of args, for a cleaner separation of concerns.
+
+    parser = HfArgumentParser((ClassifierModelArguments, DatasetArguments))
+    model_args, dataset_args = parser.parse_args_into_dataclasses()
+
+    model = AutoModelForSequenceClassification.from_pretrained(
+        model_args.model_name_or_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
+
+    processed_db_path = os.path.join(
+        dataset_args.data_dir, dataset_args.processed_database)
+    with open(processed_db_path) as fp:
+        data = json.load(fp)
+
+    mapped_categories = {
+        str(v).lower(): k for k, v in enumerate(CATEGORIES)
+    }
+
+    for video_id, segments in tqdm(data.items()):
+
+        words = get_words(video_id)
+
+        if not words:
+            continue  # No/empty transcript for video_id
+
+        valid_segments = []
+        texts = []
+        for segment in segments:
+            segment_words = extract_segment(
+                words, segment['start'], segment['end'])
+            text = clean_text(' '.join(x['text'] for x in segment_words))
+
+            duration = segment['end'] - segment['start']
+            wps = len(segment_words)/duration if duration > 0 else 0
+            if wps < 1.5:
+                continue
+
+            # Do not worry about those that are locked or have enough votes
+            if segment['locked']:  # or segment['votes'] > 5:
+                continue
+
+            texts.append(text)
+            valid_segments.append(segment)
+
+        if not texts:
+            continue  # No valid segments
+
+        model_inputs = tokenizer(
+            texts, return_tensors='pt', padding=True, truncation=True)
+
+        with torch.no_grad():
+            model_outputs = model(**model_inputs)
+            outputs = list(map(lambda x: x.numpy(), model_outputs['logits']))
+
+        scores = softmax(outputs)
+
+        for segment, text, score in zip(valid_segments, texts, scores):
+            predicted_index = score.argmax().item()
+
+            if predicted_index == mapped_categories[segment['category']]:
+                continue  # Ignore correct segments
+
+            a = {k: round(float(score[i]), 3)
+                 for i, k in enumerate(CATEGORIES)}
+
+            del segment['submission_time']
+            segment.update({
+                'predicted': str(CATEGORIES[predicted_index]).lower(),
+                'text': text,
+                'scores': a
+            })
+
+            print(json.dumps(segment))
+
+
+if __name__ == "__main__":
+    main()
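Note: the hand-rolled `softmax` subtracts the per-row maximum before exponentiating, the standard shift trick for numerical stability; the shift cancels in the ratio, so the probabilities are unchanged. A quick sanity check (illustrative values):

    import numpy as np
    from moderate import softmax

    logits = np.array([[1000.0, 1001.0, 1002.0]])  # naive np.exp would overflow here
    print(softmax(logits))         # ~[[0.090, 0.245, 0.665]]
    print(softmax(logits).sum())   # 1.0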
src/predict.py CHANGED
@@ -1,17 +1,8 @@
-
-import base64
-import re
-import requests
-import json
+
 from transformers import HfArgumentParser
-from transformers.trainer_utils import get_last_checkpoint
 from dataclasses import dataclass, field
 import logging
-import …
-import itertools
-from utils import re_findall
-from shared import CustomTokens, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, OutputArguments, seconds_to_time
-from typing import Optional
+from shared import CustomTokens, extract_sponsor_matches, GeneralArguments, seconds_to_time
 from segment import (
     generate_segments,
     extract_segment,
@@ -22,129 +13,12 @@ from segment import (
     SegmentationArguments
 )
 import preprocess
-from errors import …
-from model import …
+from errors import TranscriptError
+from model import get_model_tokenizer_classifier, InferenceArguments
 
 logging.basicConfig()
 logger = logging.getLogger(__name__)
 
-# Public innertube key (b64 encoded so that it is not incorrectly flagged)
-INNERTUBE_KEY = base64.b64decode(
-    b'QUl6YVN5QU9fRkoyU2xxVThRNFNURUhMR0NpbHdfWTlfMTFxY1c4').decode()
-
-YT_CONTEXT = {
-    'client': {
-        'userAgent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/96.0.4664.110 Safari/537.36,gzip(gfe)',
-        'clientName': 'WEB',
-        'clientVersion': '2.20211221.00.00',
-    }
-}
-_YT_INITIAL_DATA_RE = r'(?:window\s*\[\s*["\']ytInitialData["\']\s*\]|ytInitialData)\s*=\s*({.+?})\s*;\s*(?:var\s+meta|</script|\n)'
-
-
-def get_all_channel_vids(channel_id):
-    continuation = None
-    while True:
-        if continuation is None:
-            params = {'list': channel_id.replace('UC', 'UU', 1)}
-            response = requests.get(
-                'https://www.youtube.com/playlist', params=params)
-            items = json.loads(re.search(_YT_INITIAL_DATA_RE, response.text).group(1))['contents']['twoColumnBrowseResultsRenderer']['tabs'][0]['tabRenderer']['content'][
-                'sectionListRenderer']['contents'][0]['itemSectionRenderer']['contents'][0]['playlistVideoListRenderer']['contents']
-        else:
-            params = {'key': INNERTUBE_KEY}
-            data = {
-                'context': YT_CONTEXT,
-                'continuation': continuation
-            }
-            response = requests.post(
-                'https://www.youtube.com/youtubei/v1/browse', params=params, json=data)
-            items = response.json()[
-                'onResponseReceivedActions'][0]['appendContinuationItemsAction']['continuationItems']
-
-        new_token = None
-        for vid in items:
-            info = vid.get('playlistVideoRenderer')
-            if info:
-                yield info['videoId']
-                continue
-
-            info = vid.get('continuationItemRenderer')
-            if info:
-                new_token = info['continuationEndpoint']['continuationCommand']['token']
-
-        if new_token is None:
-            break
-        continuation = new_token
-
-
-@dataclass
-class InferenceArguments:
-
-    model_path: str = field(
-        default='Xenova/sponsorblock-small',
-        metadata={
-            'help': 'Path to pretrained model used for prediction'
-        }
-    )
-    cache_dir: Optional[str] = ModelArguments.__dataclass_fields__['cache_dir']
-
-    output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
-        'output_dir']
-
-    max_videos: Optional[int] = field(
-        default=None,
-        metadata={
-            'help': 'The number of videos to test on'
-        }
-    )
-    start_index: int = field(default=None, metadata={
-        'help': 'Video to start the evaluation at.'})
-    channel_id: Optional[str] = field(
-        default=None,
-        metadata={
-            'help': 'Used to evaluate a channel'
-        }
-    )
-    video_ids: str = field(
-        default_factory=lambda: [],
-        metadata={
-            'nargs': '+'
-        }
-    )
-
-    output_as_json: bool = field(default=False, metadata={
-        'help': 'Output evaluations as JSON'})
-
-    def __post_init__(self):
-        # Try to load model from latest checkpoint
-        if self.model_path is None:
-            if os.path.exists(self.output_dir):
-                last_checkpoint = get_last_checkpoint(self.output_dir)
-                if last_checkpoint is not None:
-                    self.model_path = last_checkpoint
-                else:
-                    raise ModelLoadError(
-                        'Unable to load model from checkpoint, explicitly set `--model_path`')
-            else:
-                raise ModelLoadError(
-                    f'Unable to find model in {self.output_dir}, explicitly set `--model_path`')
-
-        if any(len(video_id) != 11 for video_id in self.video_ids):
-            raise PredictionException('Invalid video IDs (length not 11)')
-
-        if self.channel_id is not None:
-            start = self.start_index or 0
-            end = None if self.max_videos is None else start + self.max_videos
-
-            channel_video_ids = list(itertools.islice(get_all_channel_vids(
-                self.channel_id), start, end))
-            logger.info(
-                f'Found {len(channel_video_ids)} for channel {self.channel_id}')
-
-            self.video_ids += channel_video_ids
-
-
 @dataclass
 class PredictArguments(InferenceArguments):
     video_id: str = field(
@@ -160,10 +34,6 @@ class PredictArguments(InferenceArguments):
         super().__post_init__()
 
 
-_SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
-_SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
-SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
-
 MATCH_WINDOW = 25       # Increase for accuracy, but takes longer: O(n^3)
 MERGE_TIME_WITHIN = 8   # Merge predictions if they are within x seconds
 
@@ -171,70 +41,35 @@ MERGE_TIME_WITHIN = 8  # Merge predictions if they are within x seconds
 START_TIME_ZERO_THRESHOLD = 0.08
 
 
-@dataclass
-class ClassifierArguments:
-    classifier_model: Optional[str] = field(
-        default='Xenova/sponsorblock-classifier',
-        metadata={
-            'help': 'Use a pretrained classifier'
-        }
-    )
-
-    classifier_dir: Optional[str] = field(
-        default='classifiers',
-        metadata={
-            'help': 'The directory that contains the classifier and vectorizer.'
-        }
-    )
-
-    classifier_file: Optional[str] = field(
-        default='classifier.pickle',
-        metadata={
-            'help': 'The name of the classifier'
-        }
-    )
-
-    vectorizer_file: Optional[str] = field(
-        default='vectorizer.pickle',
-        metadata={
-            'help': 'The name of the vectorizer'
-        }
-    )
-
-    min_probability: float = field(
-        default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
-
-
-def filter_and_add_probabilities(predictions, classifier_args):
+def filter_and_add_probabilities(predictions, classifier, min_probability):
     """Use classifier to filter predictions"""
     if not predictions:
         return predictions
 
+    # We update the predicted category from the extractive transformer
+    # if the classifier is confident enough it is another category
 
+    texts = [
         preprocess.clean_text(' '.join([x['text'] for x in pred['words']]))
         for pred in predictions
     ]
+    classifications = classifier(texts)
 
-    # Transformer sometimes says segment is of another category, so we
-    # update category and probabilities if classifier is confident it is another category
     filtered_predictions = []
-    for prediction, probabilities in zip(predictions, …
-        predicted_probabilities = { …
+    for prediction, probabilities in zip(predictions, classifications):
+        predicted_probabilities = {
+            p['label'].lower(): p['score'] for p in probabilities}
 
         # Get best category + probability
         classifier_category = max(
             predicted_probabilities, key=predicted_probabilities.get)
         classifier_probability = predicted_probabilities[classifier_category]
 
-        if classifier_category …
+        if classifier_category == 'none' and classifier_probability > min_probability:
             continue  # Ignore
 
         if (prediction['category'] not in predicted_probabilities) \
-                or (classifier_category …
+                or (classifier_category != 'none' and classifier_probability > 0.5):  # TODO make param
             # Unknown category or we are confident enough to overrule,
             # so change category to what was predicted by classifier
             prediction['category'] = classifier_category
@@ -252,7 +87,7 @@ def filter_and_add_probabilities(predictions, classifier_args):
     return filtered_predictions
 
 
-def predict(video_id, model, tokenizer, segmentation_args, words=None, …
+def predict(video_id, model, tokenizer, segmentation_args, words=None, classifier=None, min_probability=None):
     # Allow words to be passed in so that we don't have to get the words if we already have them
     if words is None:
         words = preprocess.get_words(video_id)
@@ -272,13 +107,9 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
         prediction['words'] = extract_segment(
             words, prediction['start'], prediction['end'])
 
-    # (try/except wrapper removed; its leading lines are not legible in the diff view)
-        predictions = filter_and_add_probabilities(
-            predictions, classifier_args)
-    except ClassifierLoadError:
-        print('Unable to load classifer')
+    if classifier is not None:
+        predictions = filter_and_add_probabilities(
+            predictions, classifier, min_probability)
 
     return predictions
 
@@ -300,9 +131,6 @@ def greedy_match(list, sublist):
     return best_i, best_j, best_k
 
 
-CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
-
-
 def predict_sponsor_text(text, model, tokenizer):
     """Given a body of text, predict the words which are part of the sponsor"""
     model_device = next(model.parameters()).device
@@ -322,11 +150,7 @@ def predict_sponsor_text(text, model, tokenizer):
 
 def predict_sponsor_matches(text, model, tokenizer):
     sponsorship_text = predict_sponsor_text(text, model, tokenizer)
-
-    if CustomTokens.NO_SEGMENT.value in sponsorship_text:
-        return []
-
-    return re_findall(SEGMENT_MATCH_RE, sponsorship_text)
+    return extract_sponsor_matches(sponsorship_text)
 
 
 def segments_to_predictions(segments, model, tokenizer):
@@ -400,24 +224,23 @@ def main():
     hf_parser = HfArgumentParser((
         PredictArguments,
         SegmentationArguments,
-        ClassifierArguments,
         GeneralArguments
     ))
-    predict_args, segmentation_args, …
+    predict_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
 
     if not predict_args.video_ids:
         logger.error(
             'No video IDs supplied. Use `--video_id`, `--video_ids`, or `--channel_id`.')
         return
 
-    model, tokenizer = …
-        predict_args …
+    model, tokenizer, classifier = get_model_tokenizer_classifier(
+        predict_args, general_args)
 
     for video_id in predict_args.video_ids:
-        video_id = video_id.strip()
         try:
-            predictions = predict(video_id, model, tokenizer,
-                …
+            predictions = predict(video_id, model, tokenizer, segmentation_args,
+                                  classifier=classifier,
+                                  min_probability=predict_args.min_probability)
         except TranscriptError:
             logger.warning(f'No transcript available for {video_id}')
             continue
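Note: prediction is now two-stage. The seq2seq model proposes segments, and when a classifier is supplied, `filter_and_add_probabilities` drops proposals the classifier labels 'none' with probability above `min_probability` and may overrule the proposed category when sufficiently confident. A hedged CLI sketch (flags come from the dataclass fields; IDs are placeholders):

    python src/predict.py --video_ids VIDEO_ID_1 VIDEO_ID_2 --min_probability 0.5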
src/preprocess.py CHANGED
@@ -1,14 +1,15 @@
+from shared import DatasetArguments
 from utils import jaccard
 from functools import lru_cache
 from datetime import datetime
 import itertools
-from typing import Optional
-
+from typing import Optional
+import model as model_module
 import segment
 from tqdm import tqdm
 from dataclasses import dataclass, field
 from transformers import HfArgumentParser
-from shared import ACTION_OPTIONS, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
+from shared import extract_sponsor_matches, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
 import csv
 import re
 import random
@@ -213,9 +214,10 @@ def get_words(video_id, process=True, transcript_type='auto', fallback='manual',
         else:
             ts = transcript_list.find_generated_transcript(
                 LANGUAGE_PREFERENCE_LIST)
-        <3 lines replaced; the old transcript fetch is not legible in the diff view>
+        raw_transcript = ts._http_client.get(
+            f'{ts._url}&fmt=json3').content
+        if raw_transcript:
+            raw_transcript_json = json.loads(raw_transcript)
 
     except (TooManyRequests, YouTubeRequestFailed):
         raise  # Cannot recover from these errors and do not mark as empty transcript
@@ -386,9 +388,14 @@ class PreprocessArguments:
 
     max_date: str = field(
         # default='01/01/9999',  # Include all
-        default='…',
+        default='01/03/2022',
         metadata={'help': 'Only use videos that have some segment from before this date (exclusive). This allows for videos to have segments be corrected, but ignores new videos (posted after this date) to enter the pool.'})
 
+    # max_unseen_date: str = field(  # TODO
+    #     default='02/03/2022',
+    #     metadata={'help': 'Generate test and validation data from `max_date` to `max_unseen_date`'})
+    # Specify min/max video id for splitting (seen vs. unseen)
+
     keep_duplicate_segments: bool = field(
         default=False, metadata={'help': 'Keep duplicate segments'}
     )
@@ -482,25 +489,7 @@ def download_file(url, filename):
 
 
 @dataclass
-class DatasetArguments:
-    data_dir: Optional[str] = field(
-        default='data',
-        metadata={
-            'help': 'The directory which stores train, test and/or validation data.'
-        },
-    )
-    processed_file: Optional[str] = field(
-        default='segments.json',
-        metadata={
-            'help': 'Processed data file'
-        },
-    )
-    processed_database: Optional[str] = field(
-        default='processed_database.json',
-        metadata={
-            'help': 'Processed database file'
-        },
-    )
+class PreprocessingDatasetArguments(DatasetArguments):
 
     train_file: Optional[str] = field(
         default='train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
@@ -508,21 +497,38 @@ class DatasetArguments:
     validation_file: Optional[str] = field(
         default='valid.json',
         metadata={
-            'help': 'An optional input evaluation data file to evaluate the metrics …
+            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
         },
     )
     test_file: Optional[str] = field(
         default='test.json',
         metadata={
-            'help': 'An optional input test data file to evaluate the metrics …
+            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
         },
     )
-
-    excess_file: Optional[str] = field(
-        default='excess.json',
-        metadata={
-            'help': 'The excess segments left after the split'
-        },
-    )
+
+    c_train_file: Optional[str] = field(
+        default='c_train.json', metadata={'help': 'The input training data file (a jsonlines file).'}
+    )
+    c_validation_file: Optional[str] = field(
+        default='c_valid.json',
+        metadata={
+            'help': 'An optional input evaluation data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
+    c_test_file: Optional[str] = field(
+        default='c_test.json',
+        metadata={
+            'help': 'An optional input test data file to evaluate the metrics on (a jsonlines file).'
+        },
+    )
+
+    # excess_file: Optional[str] = field(
+    #     default='excess.json',
+    #     metadata={
+    #         'help': 'The excess segments left after the split'
+    #     },
+    # )
     dataset_cache_dir: Optional[str] = field(
         default=None,
         metadata={
@@ -555,9 +561,9 @@ def main():
     # Generate final.json from sponsorTimes.csv
     hf_parser = HfArgumentParser((
         PreprocessArguments,
-        DatasetArguments,
+        PreprocessingDatasetArguments,
         segment.SegmentationArguments,
-        ModelArguments,
+        model_module.ModelArguments,
         GeneralArguments
     ))
     preprocess_args, dataset_args, segmentation_args, model_args, general_args = hf_parser.parse_args_into_dataclasses()
@@ -821,8 +827,7 @@ def main():
     # , max_videos, max_segments
 
     from model import get_model_tokenizer
-    model, tokenizer = get_model_tokenizer(
-        model_args.model_name_or_path, model_args.cache_dir, general_args.no_cuda)
+    model, tokenizer = get_model_tokenizer(model_args, general_args)
 
     # TODO
     # count_videos = 0
@@ -871,8 +876,9 @@ def main():
             continue
 
         d = {
-            'video_index': offset + start_index,
+            # 'video_index': offset + start_index,
             'video_id': video_id,
+            # 'uuid': video_id,  # TODO add uuid
             'text': ' '.join(x['cleaned'] for x in seg),
             'start': seg_start,
             'end': seg_end,
@@ -919,7 +925,7 @@ def main():
         z = int(preprocess_args.percentage_positive /
                 percentage_negative * len(non_sponsors))
 
-        excess = sponsors[z:]
+        # excess = sponsors[z:]
        sponsors = sponsors[:z]

     else:
@@ -927,7 +933,7 @@ def main():
         z = int(percentage_negative /
                 preprocess_args.percentage_positive * len(sponsors))
 
-        excess = non_sponsors[z:]
+        # excess = non_sponsors[z:]
         non_sponsors = non_sponsors[:z]
 
     logger.info('Join')
@@ -935,6 +941,7 @@ def main():
 
     random.shuffle(all_labelled_segments)
 
+    # TODO split based on video ids
     logger.info('Split')
     ratios = [preprocess_args.train_split,
               preprocess_args.test_split,
@@ -958,15 +965,52 @@ def main():
     else:
         logger.info(f'Skipping {name}')
 
+    classifier_splits = {
+        dataset_args.c_train_file: train_data,
+        dataset_args.c_test_file: test_data,
+        dataset_args.c_validation_file: valid_data
+    }
+
+    none_category = CATEGORIES.index(None)
+
+    # Output training, testing and validation data
+    for name, items in classifier_splits.items():
+        outfile = os.path.join(dataset_args.data_dir, name)
+        if not os.path.exists(outfile) or preprocess_args.overwrite:
+            with open(outfile, 'w', encoding='utf-8') as fp:
+                for i in items:
+                    x = json.loads(i)  # TODO add uuid
+                    labelled_items = []
+
+                    matches = extract_sponsor_matches(x['extracted'])
+
+                    if x['extracted'] == CustomTokens.NO_SEGMENT.value:
+                        labelled_items.append({
+                            'text': x['text'],
+                            'label': none_category
+                        })
+                    else:
+                        for match in matches:
+                            labelled_items.append({
+                                'text': match['text'],
+                                'label': CATEGORIES.index(match['category'])
+                            })
+
+                    for labelled_item in labelled_items:
+                        print(json.dumps(labelled_item), file=fp)
+
+        else:
+            logger.info(f'Skipping {name}')
+
     logger.info('Write')
     # Save excess items
-    excess_path = os.path.join(
-        dataset_args.data_dir, dataset_args.excess_file)
-    if not os.path.exists(excess_path) or preprocess_args.overwrite:
-        with open(excess_path, 'w', encoding='utf-8') as fp:
-            fp.writelines(excess)
-    else:
-        logger.info(f'Skipping {dataset_args.excess_file}')
+    # excess_path = os.path.join(
+    #     dataset_args.data_dir, dataset_args.excess_file)
+    # if not os.path.exists(excess_path) or preprocess_args.overwrite:
+    #     with open(excess_path, 'w', encoding='utf-8') as fp:
+    #         fp.writelines(excess)
+    # else:
+    #     logger.info(f'Skipping {dataset_args.excess_file}')
 
     logger.info(
         f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')
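Note: the new classifier splits (`c_train.json`, `c_valid.json`, `c_test.json`) are JSON-lines files pairing cleaned text with an integer index into `CATEGORIES` (`[None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']`). Illustrative records with made-up text:

    {"text": "this video is sponsored by example corp", "label": 1}
    {"text": "do not forget to like and subscribe", "label": 3}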
src/segment.py CHANGED
@@ -121,6 +121,8 @@ def generate_segments(words, tokenizer, segmentation_args):
 
 def extract_segment(words, start, end, map_function=None):
     """Extracts all words with time in [start, end]"""
+    if words is None:
+        words = []
 
     a = max(binary_search_below(words, 0, len(words), start), 0)
     b = min(binary_search_above(words, -1, len(words) - 1, end) + 1, len(words))
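Note: `preprocess.get_words` can return None when a video has no transcript, and `len(None)` in the binary searches would raise a TypeError; the new guard makes such calls degrade to an empty segment instead. Sketch (this assumes the binary-search helpers, not shown in this diff, behave sensibly on empty lists):

    from segment import extract_segment
    print(extract_segment(None, 0.0, 10.0))  # expected: an empty list of words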
src/shared.py
CHANGED
@@ -1,3 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import re
|
2 |
import gc
|
3 |
from time import time_ns
|
@@ -8,6 +15,8 @@ from typing import Optional
|
|
8 |
from dataclasses import dataclass, field
|
9 |
from enum import Enum
|
10 |
|
|
|
|
|
11 |
ACTION_OPTIONS = ['skip', 'mute', 'full']
|
12 |
|
13 |
CATGEGORY_OPTIONS = {
|
@@ -62,6 +71,47 @@ class CustomTokens(Enum):
|
|
62 |
tokenizer.add_tokens(cls.custom_tokens())
|
63 |
|
64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
65 |
@dataclass
|
66 |
class OutputArguments:
|
67 |
|
@@ -126,3 +176,106 @@ def reset():
|
|
126 |
torch.cuda.empty_cache()
|
127 |
gc.collect()
|
128 |
print(torch.cuda.memory_summary(device=None, abbreviated=False))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers.trainer_utils import get_last_checkpoint as glc
|
2 |
+
from transformers import TrainingArguments
|
3 |
+
import os
|
4 |
+
from utils import re_findall
|
5 |
+
import logging
|
6 |
+
import sys
|
7 |
+
from datasets import load_dataset
|
8 |
import re
|
9 |
import gc
|
10 |
from time import time_ns
|
|
|
15 |
from dataclasses import dataclass, field
|
16 |
from enum import Enum
|
17 |
|
18 |
+
CATEGORIES = [None, 'SPONSOR', 'SELFPROMO', 'INTERACTION']
|
19 |
+
|
20 |
ACTION_OPTIONS = ['skip', 'mute', 'full']
|
21 |
|
22 |
CATGEGORY_OPTIONS = {
|
|
|
71 |
tokenizer.add_tokens(cls.custom_tokens())
|
72 |
|
73 |
|
74 |
+
_SEGMENT_START = START_SEGMENT_TEMPLATE.format(r'(?P<category>\w+)')
|
75 |
+
_SEGMENT_END = END_SEGMENT_TEMPLATE.format(r'\w+')
|
76 |
+
SEGMENT_MATCH_RE = fr'{_SEGMENT_START}\s*(?P<text>.*?)\s*(?:{_SEGMENT_END}|$)'
|
77 |
+
|
78 |
+
|
79 |
+
def extract_sponsor_matches(text):
|
80 |
+
if CustomTokens.NO_SEGMENT.value in text:
|
81 |
+
return []
|
82 |
+
|
83 |
+
return re_findall(SEGMENT_MATCH_RE, text)
|
84 |
+
|
85 |
+
|
86 |
+
@dataclass
|
87 |
+
class DatasetArguments:
|
88 |
+
data_dir: Optional[str] = field(
|
89 |
+
default='data',
|
90 |
+
metadata={
|
91 |
+
'help': 'The directory which stores train, test and/or validation data.'
|
92 |
+
},
|
93 |
+
)
|
94 |
+
processed_file: Optional[str] = field(
|
95 |
+
default='segments.json',
|
96 |
+
metadata={
|
97 |
+
'help': 'Processed data file'
|
98 |
+
},
|
99 |
+
)
|
100 |
+
processed_database: Optional[str] = field(
|
101 |
+
default='processed_database.json',
|
102 |
+
metadata={
|
103 |
+
'help': 'Processed database file'
|
104 |
+
},
|
105 |
+
)
|
106 |
+
|
107 |
+
dataset_cache_dir: Optional[str] = field(
|
108 |
+
default=None,
|
109 |
+
metadata={
|
110 |
+
'help': 'Where to store the cached datasets'
|
111 |
+
},
|
112 |
+
)
|
113 |
+
|
114 |
+
|
115 |
@dataclass
|
116 |
class OutputArguments:
|
117 |
|
|
|
176 |
torch.cuda.empty_cache()
|
177 |
gc.collect()
|
178 |
print(torch.cuda.memory_summary(device=None, abbreviated=False))
|
179 |
+
|
180 |
+
|
181 |
+
def load_datasets(dataset_args):
|
182 |
+
|
183 |
+
print('Reading datasets')
|
184 |
+
data_files = {}
|
185 |
+
|
186 |
+
if dataset_args.train_file is not None:
|
187 |
+
data_files['train'] = os.path.join(
|
188 |
+
dataset_args.data_dir, dataset_args.train_file)
|
189 |
+
if dataset_args.validation_file is not None:
|
190 |
+
data_files['validation'] = os.path.join(
|
191 |
+
dataset_args.data_dir, dataset_args.validation_file)
|
192 |
+
if dataset_args.test_file is not None:
|
193 |
+
data_files['test'] = os.path.join(
|
194 |
+
dataset_args.data_dir, dataset_args.test_file)
|
195 |
+
|
196 |
+
return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
|
197 |
+
|
198 |
+
|
199 |
+
@dataclass
|
200 |
+
class CustomTrainingArguments(OutputArguments, TrainingArguments):
|
201 |
+
seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
|
202 |
+
|
203 |
+
num_train_epochs: float = field(
|
204 |
+
default=1, metadata={'help': 'Total number of training epochs to perform.'})
|
205 |
+
|
206 |
+
save_steps: int = field(default=5000, metadata={
|
207 |
+
'help': 'Save checkpoint every X updates steps.'})
|
208 |
+
eval_steps: int = field(default=5000, metadata={
|
209 |
+
'help': 'Run an evaluation every X steps.'})
|
210 |
+
logging_steps: int = field(default=5000, metadata={
|
211 |
+
'help': 'Log every X updates steps.'})
|
212 |
+
|
213 |
+
# do_eval: bool = field(default=False, metadata={
|
214 |
+
# 'help': 'Whether to run eval on the dev set.'})
|
215 |
+
# do_predict: bool = field(default=False, metadata={
|
216 |
+
# 'help': 'Whether to run predictions on the test set.'})
|
217 |
+
|
218 |
+
per_device_train_batch_size: int = field(
|
219 |
+
default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
|
220 |
+
)
|
221 |
+
per_device_eval_batch_size: int = field(
|
222 |
+
default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
|
223 |
+
)
|
224 |
+
|
225 |
+
# report_to: Optional[List[str]] = field(
|
226 |
+
# default=None, metadata={"help": "The list of integrations to report the results and logs to."}
|
227 |
+
# )
|
228 |
+
evaluation_strategy: str = field(
|
229 |
+
default='steps',
|
230 |
+
metadata={
|
231 |
+
'help': 'The evaluation strategy to use.',
|
232 |
+
'choices': ['no', 'steps', 'epoch']
|
233 |
+
},
|
234 |
+
)
|
235 |
+
|
236 |
+
# evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
|
237 |
+
# The evaluation strategy to adopt during training. Possible values are:
|
238 |
+
|
239 |
+
# * :obj:`"no"`: No evaluation is done during training.
|
240 |
+
# * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
|
241 |
+
# * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
|
242 |
+
|
243 |
+
|
244 |
+
logging.basicConfig()
|
245 |
+
logger = logging.getLogger(__name__)
|
246 |
+
|
247 |
+
# Setup logging
|
248 |
+
logging.basicConfig(
|
249 |
+
format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
|
250 |
+
datefmt='%m/%d/%Y %H:%M:%S',
|
251 |
+
handlers=[logging.StreamHandler(sys.stdout)],
|
252 |
+
)
|
253 |
+
|
254 |
+
|
255 |
+
+def get_last_checkpoint(training_args):
+    last_checkpoint = None
+    if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
+        # glc is presumably transformers.trainer_utils.get_last_checkpoint,
+        # imported under an alias (outside this hunk) so it does not shadow
+        # this wrapper.
+        last_checkpoint = glc(training_args.output_dir)
+        if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
+            raise ValueError(
+                f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
+            )
+        elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
+            logger.info(
+                f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
+            )
+    return last_checkpoint
+
+
+def train_from_checkpoint(trainer, last_checkpoint, training_args):
+    checkpoint = None
+    if training_args.resume_from_checkpoint is not None:
+        checkpoint = training_args.resume_from_checkpoint
+    elif last_checkpoint is not None:
+        checkpoint = last_checkpoint
+
+    train_result = trainer.train(resume_from_checkpoint=checkpoint)
+
+    trainer.save_model()  # Saves the tokenizer too for easy upload
+
+    return train_result
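Taken together, the two helpers give both training scripts one resume-or-fail-fast flow; the calling pattern is simply (here `trainer` stands for any already-configured `transformers.Trainer`, so this is a sketch, not repository code):

    # Sketch of the shared resume flow; trainer and training_args come from the caller.
    last_checkpoint = get_last_checkpoint(training_args)   # None on a fresh output_dir
    train_result = train_from_checkpoint(trainer, last_checkpoint, training_args)
    print(train_result.metrics)  # e.g. train_loss, train_runtime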
src/train.py
CHANGED
@@ -1,7 +1,5 @@
-from
-from
-from predict import ClassifierArguments, SEGMENT_MATCH_RE, CATEGORIES
-from shared import CustomTokens, GeneralArguments, OutputArguments
+from preprocess import PreprocessingDatasetArguments
+from shared import CustomTokens, load_datasets, CustomTrainingArguments, get_last_checkpoint, train_from_checkpoint
 from model import ModelArguments
 import transformers
 import logging
@@ -9,21 +7,15 @@ import os
 import sys
 from dataclasses import dataclass, field
 from typing import Optional
-import
-import pickle
+from datasets import utils as d_utils
 from transformers import (
     DataCollatorForSeq2Seq,
     HfArgumentParser,
     Seq2SeqTrainer,
-    Seq2SeqTrainingArguments
 )

-from transformers.trainer_utils import get_last_checkpoint
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
-from sklearn.linear_model import LogisticRegression
-from sklearn.feature_extraction.text import TfidfVectorizer
-from utils import re_findall

 # Will error if the minimal version of Transformers is not installed. Remove at your own risks.
 check_min_version('4.13.0.dev0')
@@ -43,23 +35,6 @@ logging.basicConfig(
 )


-def load_datasets(dataset_args):
-
-    print('Reading datasets')
-    data_files = {}
-
-    if dataset_args.train_file is not None:
-        data_files['train'] = os.path.join(
-            dataset_args.data_dir, dataset_args.train_file)
-    if dataset_args.validation_file is not None:
-        data_files['validation'] = os.path.join(
-            dataset_args.data_dir, dataset_args.validation_file)
-    if dataset_args.test_file is not None:
-        data_files['test'] = os.path.join(
-            dataset_args.data_dir, dataset_args.test_file)
-
-    return load_dataset('json', data_files=data_files, cache_dir=dataset_args.dataset_cache_dir)
-

 @dataclass
 class DataTrainingArguments:
@@ -92,58 +67,7 @@ class DataTrainingArguments:
 )


-@dataclass
-class SequenceTrainingArguments(OutputArguments, Seq2SeqTrainingArguments):
-    seed: Optional[int] = GeneralArguments.__dataclass_fields__['seed']
-
-    num_train_epochs: float = field(
-        default=1, metadata={'help': 'Total number of training epochs to perform.'})
-
-    save_steps: int = field(default=5000, metadata={
-        'help': 'Save checkpoint every X updates steps.'})
-    eval_steps: int = field(default=5000, metadata={
-        'help': 'Run an evaluation every X steps.'})
-    logging_steps: int = field(default=5000, metadata={
-        'help': 'Log every X updates steps.'})
-
-    skip_train_transformer: bool = field(default=False, metadata={
-        'help': 'Whether to skip training the transformer.'})
-    train_classifier: bool = field(default=False, metadata={
-        'help': 'Whether to run training on the 2nd phase (classifier).'})
-
-    # do_eval: bool = field(default=False, metadata={
-    #     'help': 'Whether to run eval on the dev set.'})
-    do_predict: bool = field(default=False, metadata={
-        'help': 'Whether to run predictions on the test set.'})
-
-    per_device_train_batch_size: int = field(
-        default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for training.'}
-    )
-    per_device_eval_batch_size: int = field(
-        default=4, metadata={'help': 'Batch size per GPU/TPU core/CPU for evaluation.'}
-    )
-
-    # report_to: Optional[List[str]] = field(
-    #     default=None, metadata={"help": "The list of integrations to report the results and logs to."}
-    # )
-    evaluation_strategy: str = field(
-        default='steps',
-        metadata={
-            'help': 'The evaluation strategy to use.',
-            'choices': ['no', 'steps', 'epoch']
-        },
-    )
-
-    # evaluation_strategy (:obj:`str` or :class:`~transformers.trainer_utils.IntervalStrategy`, `optional`, defaults to :obj:`"no"`):
-    #     The evaluation strategy to adopt during training. Possible values are:
-
-    #         * :obj:`"no"`: No evaluation is done during training.
-    #         * :obj:`"steps"`: Evaluation is done (and logged) every :obj:`eval_steps`.
-    #         * :obj:`"epoch"`: Evaluation is done at the end of each epoch.
-
-
 def main():
-    # reset()

     # See all possible arguments in src/transformers/training_args.py
     # or by passing the --help flag to this script.
@@ -151,16 +75,15 @@ def main():

     hf_parser = HfArgumentParser((
         ModelArguments,
-
+        PreprocessingDatasetArguments,
         DataTrainingArguments,
-
-        ClassifierArguments
+        CustomTrainingArguments
     ))
-    model_args, dataset_args, data_training_args, training_args
+    model_args, dataset_args, data_training_args, training_args = hf_parser.parse_args_into_dataclasses()

     log_level = training_args.get_process_log_level()
     logger.setLevel(log_level)
-
+    d_utils.logging.set_verbosity(log_level)
     transformers.utils.logging.set_verbosity(log_level)
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()
@@ -199,231 +122,130 @@ def main():

     # In distributed training, the load_dataset function guarantees that only one local process can concurrently
     # download the dataset.
-    if training_args.skip_train_transformer and not training_args.train_classifier:
-        print('Nothing to do. Exiting')
-        return
-
     raw_datasets = load_datasets(dataset_args)
     # , cache_dir=model_args.cache_dir

     # See more about loading any type of standard or custom dataset (from files, python dict, pandas DataFrame, etc) at
     # https://huggingface.co/docs/datasets/loading_datasets.html.

-    if training_args.train_classifier:
-        print('Train classifier')
-        # 1. Vectorize raw data to pass into classifier
-        # CountVectorizer TfidfVectorizer
-        # TfidfVectorizer - better (comb of CountVectorizer)
-        vectorizer = TfidfVectorizer(  # CountVectorizer
-            # lowercase=False,
-            # stop_words='english',  # TODO optimise stop words?
-            # stop_words=stop_words,
-
-            ngram_range=(1, 2),  # best so far
-            # max_features=8000  # remove for higher accuracy?
-            max_features=20000
-            # max_features=10000
-            # max_features=1000
-        )

-        ...
-        y_test = train_test_data['test']['y']
-
-        # 2. Create classifier
-        classifier = LogisticRegression(max_iter=2000, class_weight='balanced')
-
-        # 3. Fit data
-        print('Fit classifier')
-        classifier.fit(_X_train, y_train)
-
-        # 4. Measure accuracy
-        accuracy = classifier.score(_X_test, y_test)
-
-        print(f'[LogisticRegression] Accuracy percent:',
-              round(accuracy*100, 3))
-
-        # 5. Save classifier and vectorizer
-        with open(os.path.join(classifier_args.classifier_dir, classifier_args.classifier_file), 'wb') as fp:
-            pickle.dump(classifier, fp)
-
-        with open(os.path.join(classifier_args.classifier_dir, classifier_args.vectorizer_file), 'wb') as fp:
-            pickle.dump(vectorizer, fp)
-
-    if not training_args.skip_train_transformer:
-        # Detecting last checkpoint.
-        last_checkpoint = None
-        if os.path.isdir(training_args.output_dir) and not training_args.overwrite_output_dir:
-            last_checkpoint = get_last_checkpoint(training_args.output_dir)
-            if last_checkpoint is None and len(os.listdir(training_args.output_dir)) > 0:
-                raise ValueError(
-                    f'Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome.'
-                )
-            elif last_checkpoint is not None and training_args.resume_from_checkpoint is None:
-                logger.info(
-                    f'Checkpoint detected, resuming training at {last_checkpoint}. To avoid this behavior, change the `--output_dir` or add `--overwrite_output_dir` to train from scratch.'
-                )
-
-        from model import get_model_tokenizer
-        model, tokenizer = get_model_tokenizer(
-            model_args.model_name_or_path, model_args.cache_dir, training_args.no_cuda)
-
-        # Preprocessing the datasets.
-        # We need to tokenize inputs and targets.
-        column_names = raw_datasets['train'].column_names
-
-        prefix = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value
-
-        PAD_TOKEN_REPLACE_ID = -100
-
-        # https://github.com/huggingface/transformers/issues/5204
-        def preprocess_function(examples):
-            inputs = examples['text']
-            targets = examples['extracted']
-            inputs = [prefix + inp for inp in inputs]
-            model_inputs = tokenizer(inputs, truncation=True)
-
-            # Setup the tokenizer for targets
-            with tokenizer.as_target_tokenizer():
-                labels = tokenizer(targets, truncation=True)
-
-            # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
-            # when we want to ignore padding in the loss.
-
-            model_inputs['labels'] = [
-                [(l if l != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID)
-                 for l in label]
-                for label in labels['input_ids']
-            ]
-
-            return model_inputs
-
-        def prepare_dataset(dataset, desc):
-            return dataset.map(
-                preprocess_function,
-                batched=True,
-                num_proc=data_training_args.preprocessing_num_workers,
-                remove_columns=column_names,
-                load_from_cache_file=not dataset_args.overwrite_cache,
-                desc=desc,  # tokenizing train dataset
-            )
-        # train_dataset # TODO shuffle?
-
-        # if training_args.do_train:
-        if 'train' not in raw_datasets:  # TODO do checks above?
-            raise ValueError('Train dataset missing')
-        train_dataset = raw_datasets['train']
-        if data_training_args.max_train_samples is not None:
-            train_dataset = train_dataset.select(
-                range(data_training_args.max_train_samples))
-        with training_args.main_process_first(desc='train dataset map pre-processing'):
-            train_dataset = prepare_dataset(
-                train_dataset, desc='Running tokenizer on train dataset')
-
-        if 'validation' not in raw_datasets:
-            raise ValueError('Validation dataset missing')
-        eval_dataset = raw_datasets['validation']
-        if data_training_args.max_eval_samples is not None:
-            eval_dataset = eval_dataset.select(
-                range(data_training_args.max_eval_samples))
-        with training_args.main_process_first(desc='validation dataset map pre-processing'):
-            eval_dataset = prepare_dataset(
-                eval_dataset, desc='Running tokenizer on validation dataset')
-
-        if 'test' not in raw_datasets:
-            raise ValueError('Test dataset missing')
-        predict_dataset = raw_datasets['test']
-        if data_training_args.max_predict_samples is not None:
-            predict_dataset = predict_dataset.select(
-                range(data_training_args.max_predict_samples))
-        with training_args.main_process_first(desc='prediction dataset map pre-processing'):
-            predict_dataset = prepare_dataset(
-                predict_dataset, desc='Running tokenizer on prediction dataset')
-
-        # Data collator
-        data_collator = DataCollatorForSeq2Seq(
-            tokenizer,
-            model=model,
-            label_pad_token_id=PAD_TOKEN_REPLACE_ID,
-            pad_to_multiple_of=8 if training_args.fp16 else None,
-        )
-
-        ...
-        )
-
-        checkpoint = None
-        if training_args.resume_from_checkpoint is not None:
-            checkpoint = training_args.resume_from_checkpoint
-        elif last_checkpoint is not None:
-            checkpoint = last_checkpoint
-
-        try:
-            train_result = trainer.train(resume_from_checkpoint=checkpoint)
-            trainer.save_model()  # Saves the tokenizer too for easy upload
-        except KeyboardInterrupt:
-            # TODO add option to save model on interrupt?
-            # print('Saving model')
-            # trainer.save_model(os.path.join(
-            #     training_args.output_dir, 'checkpoint-latest'))  # TODO use dir
-            raise
-
-        metrics = train_result.metrics
-        max_train_samples = data_training_args.max_train_samples or len(
-            train_dataset)
-        metrics['train_samples'] = min(max_train_samples, len(train_dataset))
-
-        trainer.log_metrics('train', metrics)
-        trainer.save_metrics('train', metrics)
-        trainer.save_state()
-
-        kwargs = {'finetuned_from': model_args.model_name_or_path,
-                  'tasks': 'summarization'}
-
-        if training_args.push_to_hub:
-            trainer.push_to_hub(**kwargs)
-        else:
-            trainer.create_model_card(**kwargs)
+    # Detecting last checkpoint.
+    last_checkpoint = get_last_checkpoint(training_args)
+
+    from model import get_model_tokenizer
+    model, tokenizer = get_model_tokenizer(model_args, training_args)
+
+    # Preprocessing the datasets.
+    # We need to tokenize inputs and targets.
+    column_names = raw_datasets['train'].column_names
+
+    prefix = CustomTokens.EXTRACT_SEGMENTS_PREFIX.value
+
+    PAD_TOKEN_REPLACE_ID = -100
+
+    # https://github.com/huggingface/transformers/issues/5204
+    def preprocess_function(examples):
+        inputs = examples['text']
+        targets = examples['extracted']
+        inputs = [prefix + inp for inp in inputs]
+        model_inputs = tokenizer(inputs, truncation=True)
+
+        # Setup the tokenizer for targets
+        with tokenizer.as_target_tokenizer():
+            labels = tokenizer(targets, truncation=True)
+
+        # If we are padding here, replace all tokenizer.pad_token_id in the labels by -100
+        # when we want to ignore padding in the loss.
+
+        model_inputs['labels'] = [
+            [(l if l != tokenizer.pad_token_id else PAD_TOKEN_REPLACE_ID)
+             for l in label]
+            for label in labels['input_ids']
+        ]
+
+        return model_inputs
+
+    def prepare_dataset(dataset, desc):
+        return dataset.map(
+            preprocess_function,
+            batched=True,
+            num_proc=data_training_args.preprocessing_num_workers,
+            remove_columns=column_names,
+            load_from_cache_file=not dataset_args.overwrite_cache,
+            desc=desc,  # tokenizing train dataset
+        )
+    # train_dataset # TODO shuffle?
+
+    # if training_args.do_train:
+    if 'train' not in raw_datasets:  # TODO do checks above?
+        raise ValueError('Train dataset missing')
+    train_dataset = raw_datasets['train']
+    if data_training_args.max_train_samples is not None:
+        train_dataset = train_dataset.select(
+            range(data_training_args.max_train_samples))
+    with training_args.main_process_first(desc='train dataset map pre-processing'):
+        train_dataset = prepare_dataset(
+            train_dataset, desc='Running tokenizer on train dataset')
+
+    if 'validation' not in raw_datasets:
+        raise ValueError('Validation dataset missing')
+    eval_dataset = raw_datasets['validation']
+    if data_training_args.max_eval_samples is not None:
+        eval_dataset = eval_dataset.select(
+            range(data_training_args.max_eval_samples))
+    with training_args.main_process_first(desc='validation dataset map pre-processing'):
+        eval_dataset = prepare_dataset(
+            eval_dataset, desc='Running tokenizer on validation dataset')
+
+    if 'test' not in raw_datasets:
+        raise ValueError('Test dataset missing')
+    predict_dataset = raw_datasets['test']
+    if data_training_args.max_predict_samples is not None:
+        predict_dataset = predict_dataset.select(
+            range(data_training_args.max_predict_samples))
+    with training_args.main_process_first(desc='prediction dataset map pre-processing'):
+        predict_dataset = prepare_dataset(
+            predict_dataset, desc='Running tokenizer on prediction dataset')
+
+    # Data collator
+    data_collator = DataCollatorForSeq2Seq(
+        tokenizer,
+        model=model,
+        label_pad_token_id=PAD_TOKEN_REPLACE_ID,
+        pad_to_multiple_of=8 if training_args.fp16 else None,
+    )
+
+    # Done processing datasets
+
+    # Initialize our Trainer
+    trainer = Seq2SeqTrainer(
+        model=model,
+        args=training_args,
+        train_dataset=train_dataset,
+        eval_dataset=eval_dataset,
+        tokenizer=tokenizer,
+        data_collator=data_collator,
+    )
+
+    # Training
+    train_result = train_from_checkpoint(trainer, last_checkpoint, training_args)
+
+    metrics = train_result.metrics
+    max_train_samples = data_training_args.max_train_samples or len(
+        train_dataset)
+    metrics['train_samples'] = min(max_train_samples, len(train_dataset))
+
+    trainer.log_metrics('train', metrics)
+    trainer.save_metrics('train', metrics)
+    trainer.save_state()
+
+    kwargs = {'finetuned_from': model_args.model_name_or_path,
+              'tasks': 'summarization'}
+
+    if training_args.push_to_hub:
+        trainer.push_to_hub(**kwargs)
+    else:
+        trainer.create_model_card(**kwargs)


 if __name__ == '__main__':
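A note on the `-100` sentinel used in `preprocess_function` above: it is the default `ignore_index` of `torch.nn.CrossEntropyLoss`, so label positions that were padding contribute nothing to the seq2seq loss. A self-contained illustration (shapes and values are made up):

    # Illustrative only: -100 labels are excluded from the loss, which is exactly
    # what the PAD_TOKEN_REPLACE_ID substitution achieves for padded targets.
    import torch

    logits = torch.randn(1, 4, 10)               # (batch, seq_len, vocab_size)
    labels = torch.tensor([[3, 7, -100, -100]])  # last two positions are padding

    loss_fn = torch.nn.CrossEntropyLoss()        # ignore_index defaults to -100
    loss = loss_fn(logits.view(-1, 10), labels.view(-1))
    print(loss)  # averaged over the first two positions only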
src/train_classifier.py
ADDED
@@ -0,0 +1,287 @@

""" Finetuning the library models for sequence classification."""

import logging
import os
import random
import sys
from dataclasses import dataclass, field
from typing import Optional

import datasets
import numpy as np
from datasets import load_metric

import transformers
from transformers import (
    DataCollatorWithPadding,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    default_data_collator,
    set_seed,
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from shared import CATEGORIES, load_datasets, CustomTrainingArguments, train_from_checkpoint, get_last_checkpoint
from preprocess import PreprocessingDatasetArguments
from model import get_model_tokenizer, ModelArguments

# Will error if the minimal version of Transformers is not installed. Remove at your own risks.
check_min_version("4.17.0")
require_version("datasets>=1.8.0", "To fix: pip install -r requirements.txt")

os.environ["WANDB_DISABLED"] = "true"

logger = logging.getLogger(__name__)


@dataclass
class DataArguments:
    """
    Arguments pertaining to what data we are going to input our model for training and eval.

    Using `HfArgumentParser` we can turn this class
    into argparse arguments to be able to specify them on
    the command line.
    """

    max_seq_length: int = field(
        default=512,
        metadata={
            "help": "The maximum total input sequence length after tokenization. Sequences longer "
            "than this will be truncated, sequences shorter will be padded."
        },
    )
    overwrite_cache: bool = field(
        default=False, metadata={"help": "Overwrite the cached preprocessed datasets or not."}
    )
    pad_to_max_length: bool = field(
        default=True,
        metadata={
            "help": "Whether to pad all samples to `max_seq_length`. "
            "If False, will pad the samples dynamically when batching to the maximum length in the batch."
        },
    )
    max_train_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of training examples to this "
            "value if set."
        },
    )
    max_eval_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
            "value if set."
        },
    )
    max_predict_samples: Optional[int] = field(
        default=None,
        metadata={
            "help": "For debugging purposes or quicker training, truncate the number of prediction examples to this "
            "value if set."
        },
    )

    dataset_cache_dir: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
        'dataset_cache_dir']
    data_dir: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
        'data_dir']
    train_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
        'c_train_file']
    validation_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
        'c_validation_file']
    test_file: Optional[str] = PreprocessingDatasetArguments.__dataclass_fields__[
        'c_test_file']

    def __post_init__(self):
        if self.train_file is None or self.validation_file is None:
            raise ValueError(
                "Need either a GLUE task, a training/validation file or a dataset name.")
        else:
            train_extension = self.train_file.split(".")[-1]
            assert train_extension in [
                "csv", "json"], "`train_file` should be a csv or a json file."
            validation_extension = self.validation_file.split(".")[-1]
            assert (
                validation_extension == train_extension
            ), "`validation_file` should have the same extension (csv or json) as `train_file`."

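Since `__post_init__` runs when the dataclass is instantiated, mismatched data files fail fast, before any dataset is loaded; a quick illustration (the file names are made up):

    # Hypothetical: the extension check fires at construction time.
    args = DataArguments(train_file='train.json', validation_file='valid.json')  # fine

    # DataArguments(train_file='train.json', validation_file='valid.csv')
    # -> AssertionError: `validation_file` should have the same extension (csv or json) as `train_file`.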
def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    parser = HfArgumentParser(
        (ModelArguments, DataArguments, CustomTrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )

    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()

    # Log on each process the small summary:
    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}, "
        + f"distributed training: {bool(training_args.local_rank != -1)}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detecting last checkpoint.
    last_checkpoint = get_last_checkpoint(training_args)

    # Set seed before initializing model.
    set_seed(training_args.seed)

    # Loading a dataset from your local files.
    # CSV/JSON training and evaluation files are needed.
    raw_datasets = load_datasets(data_args)

    # See more about loading any type of standard or custom dataset at
    # https://huggingface.co/docs/datasets/loading_datasets.html.

    config_args = {
        'num_labels': len(CATEGORIES),
        'id2label': {k: str(v).upper() for k, v in enumerate(CATEGORIES)},
        'label2id': {str(v).upper(): k for k, v in enumerate(CATEGORIES)}
    }
    model, tokenizer = get_model_tokenizer(
        model_args, training_args, config_args=config_args, model_type='classifier')

    # Padding strategy
    if data_args.pad_to_max_length:
        padding = "max_length"
    else:
        # We will pad later, dynamically at batch creation, to the max sequence length in each batch
        padding = False

    if data_args.max_seq_length > tokenizer.model_max_length:
        logger.warning(
            f"The max_seq_length passed ({data_args.max_seq_length}) is larger than the maximum length for the "
            f"model ({tokenizer.model_max_length}). Using max_seq_length={tokenizer.model_max_length}."
        )
    max_seq_length = min(data_args.max_seq_length, tokenizer.model_max_length)

    def preprocess_function(examples):
        # Tokenize the texts
        result = tokenizer(
            examples['text'], padding=padding, max_length=max_seq_length, truncation=True)
        result['label'] = examples['label']
        return result

    with training_args.main_process_first(desc="dataset map pre-processing"):
        raw_datasets = raw_datasets.map(
            preprocess_function,
            batched=True,
            load_from_cache_file=not data_args.overwrite_cache,
            desc="Running tokenizer on dataset",
        )
    if training_args.do_train:
        if "train" not in raw_datasets:
            raise ValueError("--do_train requires a train dataset")
        train_dataset = raw_datasets["train"]
        if data_args.max_train_samples is not None:
            train_dataset = train_dataset.select(
                range(data_args.max_train_samples))

    if training_args.do_eval:
        if "validation" not in raw_datasets:
            raise ValueError("--do_eval requires a validation dataset")
        eval_dataset = raw_datasets["validation"]
        if data_args.max_eval_samples is not None:
            eval_dataset = eval_dataset.select(
                range(data_args.max_eval_samples))

    if training_args.do_predict or data_args.test_file is not None:
        if "test" not in raw_datasets:
            raise ValueError("--do_predict requires a test dataset")
        predict_dataset = raw_datasets["test"]
        if data_args.max_predict_samples is not None:
            predict_dataset = predict_dataset.select(
                range(data_args.max_predict_samples))

    # Log a few random samples from the training set:
    if training_args.do_train:
        for index in random.sample(range(len(train_dataset)), 3):
            logger.info(
                f"Sample {index} of the training set: {train_dataset[index]}.")

    # Get the metric function
    metric = load_metric("accuracy")

    # You can define your custom compute_metrics function. It takes an `EvalPrediction` object (a namedtuple with a
    # predictions and label_ids field) and has to return a dictionary string to float.
    def compute_metrics(p: EvalPrediction):
        preds = p.predictions[0] if isinstance(
            p.predictions, tuple) else p.predictions
        preds = np.argmax(preds, axis=1)
        # DataArguments defines no task_name, so the GLUE-style branch that read
        # data_args.task_name would raise AttributeError; compute plain accuracy.
        return metric.compute(predictions=preds, references=p.label_ids)

    # Data collator will default to DataCollatorWithPadding when the tokenizer is passed to Trainer, so we change it if
    # we already did the padding.
    if data_args.pad_to_max_length:
        data_collator = default_data_collator
    elif training_args.fp16:
        data_collator = DataCollatorWithPadding(
            tokenizer, pad_to_multiple_of=8)
    else:
        data_collator = None

    # Initialize our Trainer. The datasets are guarded so a run without
    # --do_train or --do_eval does not reference an unbound local variable.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset if training_args.do_train else None,
        eval_dataset=eval_dataset if training_args.do_eval else None,
        compute_metrics=compute_metrics,
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    # Training
    train_result = train_from_checkpoint(
        trainer, last_checkpoint, training_args)

    metrics = train_result.metrics
    max_train_samples = (
        data_args.max_train_samples if data_args.max_train_samples is not None else len(
            train_dataset)
    )
    metrics["train_samples"] = min(max_train_samples, len(train_dataset))

    trainer.save_model()  # Saves the tokenizer too for easy upload

    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    kwargs = {"finetuned_from": model_args.model_name_or_path,
              "tasks": "text-classification"}
    if training_args.push_to_hub:
        trainer.push_to_hub(**kwargs)
    else:
        trainer.create_model_card(**kwargs)


if __name__ == "__main__":
    main()
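Because the commit writes the `CATEGORIES` mapping into the model config as `id2label`/`label2id`, a trained checkpoint is self-describing; a hypothetical smoke test of a finished run (the path and the sample sentence are ours):

    # Hypothetical: classify one transcript snippet with a trained checkpoint.
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    model = AutoModelForSequenceClassification.from_pretrained('out')  # training_args.output_dir
    tokenizer = AutoTokenizer.from_pretrained('out')

    inputs = tokenizer('Use code EXAMPLE for 20% off your first order!',
                       return_tensors='pt', truncation=True)
    pred_id = model(**inputs).logits.argmax(-1).item()
    print(model.config.id2label[pred_id])  # one of the upper-cased CATEGORIES, e.g. 'SPONSOR'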
src/utils.py
CHANGED
@@ -1,8 +1,4 @@
 import re
-import logging
-
-logging.basicConfig()
-logger = logging.getLogger(__name__)


 def re_findall(pattern, string):