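"""Evaluate sponsor-segment predictions against the processed SponsorBlock database.

For each video, this script can (a) run the model to find segments that are
missing from the database and (b) re-classify existing segments to flag likely
incorrect categories. It accumulates duration-weighted accuracy, precision,
recall and F-score, prints missed or misclassified segments for manual review,
and writes per-video metrics to a CSV file.
"""
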
from model import get_model_tokenizer_classifier, InferenceArguments
from utils import jaccard, safe_print
from transformers import HfArgumentParser
from preprocess import get_words, clean_text
from shared import GeneralArguments, DatasetArguments
from predict import predict
from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
import pandas as pd
from dataclasses import dataclass, field
from typing import Optional
from tqdm import tqdm
import json
import os
import random
from shared import seconds_to_time
from urllib.parse import quote
import logging

logging.basicConfig()
logger = logging.getLogger(__name__)


@dataclass
class EvaluationArguments(InferenceArguments):
    """Arguments pertaining to how evaluation will occur."""
    output_file: Optional[str] = field(
        default='metrics.csv',
        metadata={
            'help': 'Save metrics to output file'
        }
    )

    skip_missing: bool = field(
        default=False,
        metadata={
            'help': 'Whether to skip checking for missed segments. If False, the model makes predictions to find segments missing from the database.'
        }
    )
    skip_incorrect: bool = field(
        default=False,
        metadata={
            'help': 'Whether to skip checking for incorrectly categorised segments. If False, the classifier is re-run on existing segments.'
        }
    )
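
# Example invocation (illustrative; most flags are defined in InferenceArguments,
# DatasetArguments and GeneralArguments, so consult those dataclasses for the
# exact names):
#   python src/evaluate.py --max_videos 100 --output_file metrics.csv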


def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
    """Attach each prediction to the labelled sponsor segment it overlaps most (by Jaccard index)."""
    for prediction in predictions:
        prediction['best_overlap'] = 0
        prediction['best_sponsorship'] = None

        # Assign predictions to actual (labelled) sponsored segments
        for sponsor_segment in sponsor_segments:
            j = jaccard(prediction['start'], prediction['end'],
                        sponsor_segment['start'], sponsor_segment['end'])
            if prediction['best_overlap'] < j:
                prediction['best_overlap'] = j
                prediction['best_sponsorship'] = sponsor_segment

    return sponsor_segments
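

# Illustrative example of the matching above: a prediction spanning 10-20s and a
# labelled segment spanning 12-22s share 8s of a 12s union, a Jaccard index of
# 8/12 ≈ 0.67 (assuming `jaccard` computes intersection-over-union of the two
# time ranges), so that segment would become the prediction's 'best_sponsorship'.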


def calculate_metrics(labelled_words, predictions):
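    """Compute duration-weighted confusion totals (in seconds) for one video,
    together with accuracy, precision, recall and F-score."""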

    metrics = {
        'true_positive': 0,  # Is sponsor, predicted sponsor
        # Is sponsor, predicted not sponsor (i.e., missed it - bad)
        'false_negative': 0,
        # Is not sponsor, predicted sponsor (classified incorrectly; not that bad since we do manual checking afterwards)
        'false_positive': 0,
        'true_negative': 0,  # Is not sponsor, predicted not sponsor
    }
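    # NOTE: the four totals above accumulate seconds of video time (word
    # durations), not counts of words or segments.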

    metrics['video_duration'] = word_end(
        labelled_words[-1]) - word_start(labelled_words[0])

    for word in labelled_words[:-1]:  # The last word is excluded from the totals

        duration = word_end(word) - word_start(word)

        predicted_sponsor = False
        for p in predictions:
            # Is in some prediction
            if p['start'] <= word['start'] <= p['end']:
                predicted_sponsor = True
                break

        if predicted_sponsor:
            if word.get('category') is not None:  # Is an actual sponsor
                metrics['true_positive'] += duration
            else:
                metrics['false_positive'] += duration
        else:
            if word.get('category') is not None:  # Is an actual sponsor
                metrics['false_negative'] += duration
            else:
                metrics['true_negative'] += duration

    # NOTE In cases where we encounter division by 0, we say that the value is 1
    # https://stats.stackexchange.com/a/1775
    # (Precision) TP+FP=0: means that all instances were predicted as negative
    # (Recall)    TP+FN=0: means that there were no positive cases in the input data

    # The fraction of predictions our model got right
    # (kept as the full formula rather than a simplified equivalent)
    z = metrics['true_positive'] + metrics['true_negative'] + \
        metrics['false_positive'] + metrics['false_negative']
    metrics['accuracy'] = (
        (metrics['true_positive'] + metrics['true_negative']) / z) if z > 0 else 1

    # What proportion of positive identifications was actually correct?
    z = metrics['true_positive'] + metrics['false_positive']
    metrics['precision'] = (metrics['true_positive'] / z) if z > 0 else 1

    # What proportion of actual positives was identified correctly?
    z = metrics['true_positive'] + metrics['false_negative']
    metrics['recall'] = (metrics['true_positive'] / z) if z > 0 else 1

    # https://deepai.org/machine-learning-glossary-and-terms/f-score

    s = metrics['precision'] + metrics['recall']
    metrics['f-score'] = (2 * (metrics['precision'] *
                               metrics['recall']) / s) if s > 0 else 0

    return metrics
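

# Worked example with hypothetical totals: true_positive=8s, false_positive=2s,
# false_negative=4s and true_negative=86s over a 100s video give:
#   accuracy  = (8 + 86) / 100 = 0.94
#   precision = 8 / (8 + 2) = 0.8
#   recall    = 8 / (8 + 4) ≈ 0.667
#   f-score   = 2 * (0.8 * 0.667) / (0.8 + 0.667) ≈ 0.727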


def main():
    logger.setLevel(logging.DEBUG)

    hf_parser = HfArgumentParser((
        EvaluationArguments,
        DatasetArguments,
        SegmentationArguments,
        GeneralArguments
    ))

    evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()

    if evaluation_args.skip_missing and evaluation_args.skip_incorrect:
        logger.error('Nothing to do: both --skip_missing and --skip_incorrect are set.')
        return

    # Load labelled data:
    final_path = os.path.join(
        dataset_args.data_dir, dataset_args.processed_file)

    if not os.path.exists(final_path):
        logger.error('Processed database not found.\n'
                     f'Run `python src/preprocess.py --update_database --do_create` to generate "{final_path}".')
        return

    model, tokenizer, classifier = get_model_tokenizer_classifier(
        evaluation_args, general_args)

    with open(final_path) as fp:
        final_data = json.load(fp)

    if evaluation_args.video_ids:  # Use specified
        video_ids = evaluation_args.video_ids

    else:  # Use items found in preprocessed database
        video_ids = list(final_data.keys())
        random.shuffle(video_ids)

        if evaluation_args.start_index is not None:
            video_ids = video_ids[evaluation_args.start_index:]

        if evaluation_args.max_videos is not None:
            video_ids = video_ids[:evaluation_args.max_videos]

    out_metrics = []

    all_metrics = {}
    if not evaluation_args.skip_missing:
        all_metrics['total_prediction_accuracy'] = 0
        all_metrics['total_prediction_precision'] = 0
        all_metrics['total_prediction_recall'] = 0
        all_metrics['total_prediction_fscore'] = 0

    if not evaluation_args.skip_incorrect:
        all_metrics['classifier_segment_correct'] = 0
        all_metrics['classifier_segment_count'] = 0

    metric_count = 0

    postfix_info = {}

    try:
        with tqdm(video_ids) as progress:
            for video_index, video_id in enumerate(progress):
                progress.set_description(f'Processing {video_id}')

                words = get_words(video_id)
                if not words:
                    continue

                # Get labels
                sponsor_segments = final_data.get(video_id)

                # Reset previous
                missed_segments = []
                incorrect_segments = []

                current_metrics = {
                    'video_id': video_id
                }

                if not evaluation_args.skip_missing:  # Make predictions
                    predictions = predict(video_id, model, tokenizer, segmentation_args,
                                          classifier=classifier,
                                          min_probability=evaluation_args.min_probability)

                    if sponsor_segments:
                        labelled_words = add_labels_to_words(
                            words, sponsor_segments)

                        current_metrics.update(
                            calculate_metrics(labelled_words, predictions))
                        # Only count videos that actually contribute prediction
                        # metrics, so the running averages are not diluted
                        metric_count += 1

                        all_metrics['total_prediction_accuracy'] += current_metrics['accuracy']
                        all_metrics['total_prediction_precision'] += current_metrics['precision']
                        all_metrics['total_prediction_recall'] += current_metrics['recall']
                        all_metrics['total_prediction_fscore'] += current_metrics['f-score']

                        # Just for display purposes
                        postfix_info.update({
                            'accuracy': all_metrics['total_prediction_accuracy'] / metric_count,
                            'precision': all_metrics['total_prediction_precision'] / metric_count,
                            'recall': all_metrics['total_prediction_recall'] / metric_count,
                            'f-score': all_metrics['total_prediction_fscore'] / metric_count,
                        })

                        sponsor_segments = attach_predictions_to_sponsor_segments(
                            predictions, sponsor_segments)

                        # Identify possible issues:
                        for prediction in predictions:
                            if prediction['best_sponsorship'] is not None:
                                continue

                            prediction_words = prediction.pop('words', [])

                            # Attach original text to missed segments
                            prediction['text'] = ' '.join(
                                x['text'] for x in prediction_words)
                            missed_segments.append(prediction)

                    else:
                        # Not in database (all segments missed)
                        missed_segments = predictions

                if not evaluation_args.skip_incorrect and sponsor_segments:
                    # Check for incorrect segments using the classifier

                    segments_to_check = []
                    cleaned_texts = []  # Texts to send through tokenizer
                    for sponsor_segment in sponsor_segments:
                        segment_words = extract_segment(
                            words, sponsor_segment['start'], sponsor_segment['end'])
                        sponsor_segment['text'] = ' '.join(
                            x['text'] for x in segment_words)

                        duration = sponsor_segment['end'] - \
                            sponsor_segment['start']
                        wps = (len(segment_words) /
                               duration) if duration > 0 else 0
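                        # Skip segments with a low word rate; this usually means
                        # the transcript does not fully cover the segment, so
                        # there is too little text to classify reliably.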
                        if wps < 1.5:
                            continue

                        # Do not check segments that are locked (a vote threshold,
                        # e.g. segment['votes'] > 5, could also be applied here)
                        if sponsor_segment['locked']:
                            continue

                        cleaned_texts.append(
                            clean_text(sponsor_segment['text']))
                        segments_to_check.append(sponsor_segment)

                    if segments_to_check:  # Some segments to check
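                        # NOTE: `classifier` is assumed to behave like a
                        # transformers text-classification pipeline, returning a
                        # list of {'label': ..., 'score': ...} dicts per input.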

                        segments_scores = classifier(cleaned_texts)

                        num_correct = 0
                        for segment, scores in zip(segments_to_check, segments_scores):

                            fixed_scores = {
                                score['label']: score['score']
                                for score in scores
                            }

                            all_metrics['classifier_segment_count'] += 1

                            prediction = max(scores, key=lambda x: x['score'])
                            predicted_category = prediction['label'].lower()

                            if predicted_category == segment['category']:
                                num_correct += 1
                                continue  # Ignore correct segments

                            segment.update({
                                'predicted': predicted_category,
                                'scores': fixed_scores
                            })

                            incorrect_segments.append(segment)

                        current_metrics['num_segments'] = len(
                            segments_to_check)
                        current_metrics['classified_correct'] = num_correct

                        all_metrics['classifier_segment_correct'] += num_correct

                    if all_metrics['classifier_segment_count'] > 0:
                        postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
                            all_metrics['classifier_segment_count']

                out_metrics.append(current_metrics)
                progress.set_postfix(postfix_info)

                if missed_segments or incorrect_segments:

                    if evaluation_args.output_as_json:
                        to_print = {'video_id': video_id}

                        if missed_segments:
                            to_print['missed'] = missed_segments

                        if incorrect_segments:
                            to_print['incorrect'] = incorrect_segments

                        safe_print(json.dumps(to_print))

                    else:
                        safe_print(
                            f'Issues identified for {video_id} (#{video_index})')
                        # Potentially missed segments (model predicted, but not in database)
                        if missed_segments:
                            safe_print(' - Missed segments:')
                            segments_to_submit = []
                            for i, missed_segment in enumerate(missed_segments, start=1):
                                safe_print(f'\t#{i}:', seconds_to_time(
                                    missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
                                safe_print('\t\tText: "',
                                           missed_segment['text'], '"', sep='')
                                safe_print('\t\tCategory:',
                                           missed_segment.get('category'))
                                if 'probability' in missed_segment:
                                    safe_print('\t\tProbability:',
                                               missed_segment['probability'])

                                segments_to_submit.append({
                                    'segment': [missed_segment['start'], missed_segment['end']],
                                    'category': missed_segment['category'].lower(),
                                    'actionType': 'skip'
                                })

                            json_data = quote(json.dumps(segments_to_submit))
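                            # SponsorBlock's browser extension can read the
                            # `#segments=` URL fragment to pre-fill a submission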
                            safe_print(
                                f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')

                        # Incorrect segments (in database, but incorrectly classified)
                        if incorrect_segments:
                            safe_print(' - Incorrect segments:')
                            for i, incorrect_segment in enumerate(incorrect_segments, start=1):
                                safe_print(f'\t#{i}:', seconds_to_time(
                                    incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))

                                safe_print(
                                    '\t\tText: "', incorrect_segment['text'], '"', sep='')
                                safe_print(
                                    '\t\tUUID:', incorrect_segment['uuid'])
                                safe_print(
                                    '\t\tVotes:', incorrect_segment['votes'])
                                safe_print(
                                    '\t\tViews:', incorrect_segment['views'])
                                safe_print('\t\tLocked:',
                                           incorrect_segment['locked'])

                                safe_print('\t\tCurrent Category:',
                                           incorrect_segment['category'])
                                safe_print('\t\tPredicted Category:',
                                           incorrect_segment['predicted'])
                                safe_print('\t\tProbabilities:')
                                for label, score in incorrect_segment['scores'].items():
                                    safe_print(
                                        f"\t\t\t{label}: {score}")

                        safe_print()

    except KeyboardInterrupt:
        pass

    df = pd.DataFrame(out_metrics)

    df.to_csv(evaluation_args.output_file)
    logger.info(df.mean(numeric_only=True))  # Average only numeric columns (e.g., not video_id)


if __name__ == '__main__':
    main()