import os
import argparse

import numpy as np
from tqdm import tqdm

from util import utils
from dsets import wikipedia
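# util.utils and dsets.wikipedia are project-local modules (assumed to provide
# the pickle/JSON helpers, tokenizer loader, and wikipedia dataset wrapper used below)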



def extract_wikipedia_context_cache(
        cache_path,
        models=['gpt-j-6b', 'llama-3-8b', 'mamba-1.4b'],
        max_token_len=100,
        max_len=25,
        min_len=7,
        total_to_sample=10000
    ):
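    """Build an augmented wikipedia context cache of unseen first sentences.

    Gathers the article indices already used by the cached feature pickles for
    the given models under cache_path, then randomly samples total_to_sample
    previously unused articles, keeping only first sentences whose token length
    lies between min_len and max_len, and saves them to a JSON file.
    """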

    # find paths to wikitrain and wikitest sets
    ps = [
        os.path.join(cache_path, 'wiki_train'),
        os.path.join(cache_path, 'wiki_test')
    ]
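
    # each directory is expected to hold per-model feature pickles whose
    # filenames contain the model name (naming convention assumed from usage)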

    # find all wikipedia feature pickles
    pickle_files = []
    for p in ps:
        for model in models:
            pickle_files += [
                os.path.join(p, f) for f in os.listdir(p)
                if f.endswith('.pickle') and model in f
            ]

    print(f'Found {len(pickle_files)} cached wikipedia feature pickles')

    # collect the indices of wikipedia samples that have already been used
    sampled_indices = []
    for f in tqdm(pickle_files):
        contents = utils.loadpickle(f)
        sampled_indices += list(contents['sampled_indices'])

    sampled_indices = np.unique(sampled_indices)
    print('Total number of sampled indices:', len(sampled_indices))

    # load a tokenizer
    tok = utils.load_tok('llama-3-8b')
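    # note: sentence token lengths are measured with the llama-3-8b tokenizer
    # regardless of which model's cache is being augmented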

    # load the raw wikipedia dataset
    raw_ds, _ = wikipedia.get_ds(tok, maxlen=max_token_len)

    # candidate indices: all article indices not already in sampled_indices
    # (utils.comp appears to return elements only in the first array, only in
    # the second, and in both; only the first set is needed here)
    only_new, _, _ = utils.comp(np.arange(len(raw_ds)), sampled_indices)
    potential_indices = np.array(list(only_new))

    new_sampled_indices = []
    new_sampled_texts = []
    number_sampled = 0

    # rejection-sample until total_to_sample sentences are collected: draw a
    # random unused index and keep its first sentence only if it is brace-free
    # and within the token-length bounds (assumes enough valid candidates exist)
    pbar = tqdm(total=total_to_sample)

    while number_sampled < total_to_sample:

        i = int(np.random.choice(potential_indices))

        if i not in new_sampled_indices:
            # keep only the first sentence of the article text
            first_sentence = raw_ds[i]['text'].split('. ')[0]

            # skip sentences containing braces (likely leftover markup/templates)
            if ('{' not in first_sentence) and ('}' not in first_sentence):

                token_length = len(tok.encode(first_sentence))

                if (token_length <= max_len) and (token_length >= min_len):

                    new_sampled_indices.append(i)
                    new_sampled_texts.append(first_sentence)

                    number_sampled += 1
                    pbar.update(1)

    pbar.close()

    # restore the sentence-ending punctuation removed by the split above
    new_sampled_texts = [t + '. ' for t in new_sampled_texts]

    augmented_cache_path = os.path.join(
        cache_path,
        f'augmented_wikipedia_context_first_sentence_max{max_len}_min{min_len}.json'
    )
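    # the saved JSON is a single dict, e.g. (illustrative contents):
    #   {"augmented_cache": ["First sentence of one article. ", ...]}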
    utils.savejson(augmented_cache_path, {'augmented_cache': new_sampled_texts})
    print('Saved to:', augmented_cache_path)


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    
    parser.add_argument(
        '--cache_path', type=str, default='./cache/', help='output directory')

    parser.add_argument(
        '--min_len', type=int, default=7, help='minimum length of sentences in tokens')
    parser.add_argument(
        '--max_len', type=int, default=25, help='maximum length of sentences in tokens')

    parser.add_argument(
        '--sample_size', type=int, default=10000, help='number of sentences to sample')

    args = parser.parse_args()

    # build the augmented wikipedia context cache
    extract_wikipedia_context_cache(
        cache_path=args.cache_path,
        models=['gpt-j-6b', 'llama-3-8b', 'mamba-1.4b'],
        max_token_len=100,
        max_len=args.max_len,
        min_len=args.min_len,
        total_to_sample=args.sample_size
    )
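
# Example invocation (script filename is assumed; flags mirror the argparse
# defaults above):
#   python extract_wikipedia_context_cache.py --cache_path ./cache/ \
#       --min_len 7 --max_len 25 --sample_size 10000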