|
import json
import logging
import os
import pathlib
import re
from dataclasses import dataclass
from typing import Dict, List, Optional, Sequence, Tuple

import torch
import torch.nn.functional as F
import transformers
from torch.utils.data import DataLoader

from utils.constants import DEFAULT_BOS_TOKEN, DEFAULT_EOS_TOKEN, IGNORE_INDEX
|
|
|
|
|
|
|
def read_jsonl(path: str):
    """Read a JSONL file; fall back to plain JSON if the file is a single JSON document."""
    try:
        with open(path, 'r', encoding='utf-8') as fh:
            return [json.loads(line) for line in fh if line.strip()]
    except json.JSONDecodeError:
        with open(path, 'r', encoding='utf-8') as fh:
            return json.load(fh)
|
|
|
|
|
|
|
def get_few_shot_prompt(data_dir, prompt_file):
    with open(os.path.join(data_dir, prompt_file), 'r') as f:
        prompt = f.read()
    # Escape every literal brace for str.format, then restore the {test_question} placeholder.
    return prompt.replace('{', '{{').replace('}', '}}').replace('{{test_question}}', '{test_question}')
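# Example (file contents illustrative): a prompt file containing
# "Q: {test_question}\nProof: {sorry}" becomes "Q: {test_question}\nProof: {{sorry}}",
# so prompt.format(test_question=...) fills only the intended placeholder and leaves
# all other braces literal.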
|
|
|
|
|
|
|
def handle_jsonfile_psv(json_list):
    """Convert raw prover results into example dicts with positional labels.

    Label semantics: -1 = the proof passes; 0 = unusable (resource limit,
    system error, or unknown error position); >0 = character position of the
    first error in the response.
    """
    def get_label(result, output):
        if result['status'] == 'pass':
            return -1
        elif result['status'] in ['nopass_limit', 'nopass_error']:
            return 0
        else:
            if result['string_pos'] == -1:
                return 0
            return result['string_pos']

    all_results = []
    for idx, item in enumerate(json_list):
        dict_item = {
            "idx": idx,
            "question": item['question'],
            "input": item['question'],
            "ground_truth_cot": item['answer'],
            "ground_truth": item['answer'],
            "outputs": [
                {
                    "response": item['total output'][i].split("#align")[0],
                    "label": get_label(item['results'][i], item['total output'][i].split("#align")[0]),
                }
                for i in range(len(item['results']))
            ],
        }
        all_results.append(dict_item)

    # Report pass@1 and pass@k, where k is the number of sampled solutions.
    pass1 = 0
    passk = 0
    for item in all_results:
        pass1 += item['outputs'][0]['label'] == -1
        for output in item['outputs']:
            if output['label'] == -1:
                passk += 1
                break
    print(f"pass@1: {pass1 / len(all_results)}")
    print(f"pass@k: {passk / len(all_results)}")
    return all_results
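# Example label outcomes (statuses other than 'pass'/'nopass_limit'/'nopass_error'
# fall through to string_pos; 'nopass' here is illustrative):
#   {'status': 'pass'} -> -1
#   {'status': 'nopass_limit'} -> 0
#   {'status': 'nopass', 'string_pos': 57} -> 57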
|
def handle_jsonfile(json_list):
    """Convert raw results into example dicts with boolean pass/fail labels."""
    def get_label(result, output):
        # Note: unused by the dict construction below, which labels on status directly.
        if result['status'] == 'pass':
            return True
        else:
            for me in json.loads(result['stdout'])['messages']:
                if me['severity'] == 'error':
                    return me['endPos']
            raise AssertionError('no error message found in stdout')

    all_results = []
    for idx, item in enumerate(json_list):
        dict_item = {
            "idx": idx,
            "question": item['question'],
            "input": item['question'],
            "ground_truth_cot": item['answer'],
            "ground_truth": item['answer'],
            "outputs": [
                {
                    "response": item['total output'][i].split("#align")[0],
                    "label": item['results'][i]['status'] == 'pass',
                }
                for i in range(len(item['results']))
            ],
        }
        all_results.append(dict_item)
    return all_results
|
|
|
def get_model_solutions_easy(data_dir, generator_id, target_set, process: bool = False):
    """Load results from one or more comma-separated files and label them pass/fail."""
    examples = []
    for dd in data_dir.split(","):
        examples += read_jsonl(dd)['results']
    examples = handle_jsonfile(examples)
    print(f"{len(examples)} {target_set} examples, each with {len(examples[0]['outputs'])} solutions")
    return examples
|
|
|
def get_model_solutions_psv(data_dir, generator_id, target_set, process: bool = False):
    """Load results from one or more comma-separated files and label them with error positions."""
    examples = []
    for dd in data_dir.split(","):
        examples += read_jsonl(dd)['results']
    examples = handle_jsonfile_psv(examples)
    print(f"{len(examples)} {target_set} examples, each with {len(examples[0]['outputs'])} solutions")
    return examples
|
|
|
def get_model_solutions(data_dir, generator_id, target_set, process: bool = False):
    """Load the response file with the largest sample count n for a generator/target-set pair."""
|
data_dir = os.path.join(data_dir, target_set) |
|
|
|
if process: |
|
files_pattern = f'responses_n*_{generator_id}_process.jsonl' |
|
else: |
|
files_pattern = f'responses_n*_{generator_id}.jsonl' |
|
|
|
response_files = [str(x) for x in pathlib.Path(data_dir).glob(files_pattern)] |
|
if not response_files: |
|
        raise ValueError(f'Failed to find {files_pattern} in {data_dir}')
|
|
|
ordering_and_response_path = [] |
|
for response_file in response_files: |
|
regex_match = re.match(r".*responses_n([0-9]+)", response_file) |
|
if regex_match is not None: |
|
ordering_and_response_path.append((int(regex_match.group(1)), response_file)) |
|
    # Keep the file with the largest sample count n.
    responses_sorted = sorted(ordering_and_response_path)
    responses_sorted = [response[1] for response in responses_sorted]
    read_file = responses_sorted[-1]
|
examples = read_jsonl(read_file) |
|
print(f"{len(examples)} {target_set} examples, each with {len(examples[0]['outputs'])} solutions") |
|
return examples |
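# Example (generator_id "gen" is illustrative): both responses_n1_gen.jsonl and
# responses_n16_gen.jsonl match the glob, and the n=16 file is the one read.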
|
|
|
|
|
def get_model_solutions_self(data_dir, data_id, verifier_id, generator_id, process: bool = False):
    """Load process-supervision responses produced by a (data, verifier, generator) triple."""
    if not len(data_id):
        files_pattern = f"responses_v({verifier_id})_g({generator_id})_process_supervision.jsonl"
    else:
        files_pattern = f"responses_d({data_id})_v({verifier_id})_g({generator_id})_process_supervision.jsonl"

    response_files = [str(x) for x in pathlib.Path(data_dir).glob(files_pattern)]
    if not response_files:
        raise FileNotFoundError(f"found no files under {data_dir} for the pattern {files_pattern}")
    read_file = response_files[-1]
    print(read_file)

    examples = read_jsonl(read_file)
    print(f"{len(examples)} examples, each with {len(examples[0]['outputs'])} solutions")
    return examples
|
|
|
|
|
def make_training_dataloaders(
    data_module: Dict[str, torch.utils.data.Dataset],
    training_args: dataclass = None,
) -> Tuple[DataLoader, Optional[DataLoader]]:
|
train_dataloader = DataLoader( |
|
data_module['train_dataset'], |
|
batch_size=training_args.per_device_train_batch_size, |
|
shuffle=True, |
|
drop_last=False, |
|
collate_fn=data_module['train_dataset'].collate_fn, |
|
) |
|
if data_module['val_dataset'] is not None: |
|
val_dataloader = DataLoader( |
|
data_module['val_dataset'], |
|
batch_size=training_args.per_device_eval_batch_size, |
|
shuffle=False, |
|
drop_last=False, |
|
collate_fn=data_module['val_dataset'].collate_fn, |
|
) |
|
else: |
|
val_dataloader = None |
|
return train_dataloader, val_dataloader |
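# Usage sketch: training_args is assumed to be a HuggingFace-style TrainingArguments
# object exposing per_device_train_batch_size / per_device_eval_batch_size, e.g.
#   train_loader, val_loader = make_training_dataloaders(data_module, training_args)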
|
|
|
|
|
|
|
def make_testing_dataloader( |
|
dataset: torch.utils.data.Dataset, |
|
batch_size: int, |
|
): |
|
return DataLoader(dataset, batch_size=batch_size, shuffle=False, drop_last=False, collate_fn=dataset.collate_fn) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def make_training_verifier_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass) -> Dict: |
|
    if data_args.process:
|
dataset_class = VerifierDataset_test |
|
else: |
|
dataset_class = VerifierDataset |
|
|
|
train_dataset = dataset_class( |
|
tokenizer=tokenizer, |
|
data_dir=data_args.data_dir, |
|
target_set=data_args.target_set, |
|
verifier_id=data_args.verifier_id, |
|
data_id=data_args.data_id, |
|
generator_id=data_args.generator_id, |
|
per_problem_sampling_solution=data_args.per_problem_sampling_solution, |
|
loss_level=data_args.loss_level, |
|
loss_on_llm=data_args.loss_on_llm, |
|
dedup=data_args.dedup, |
|
easy=data_args.easy, |
|
) |
|
|
|
val_dataset = None |
|
if data_args.val_target_set is not None: |
|
val_dataset = dataset_class( |
|
tokenizer=tokenizer, |
|
data_dir=data_args.data_dir, |
|
target_set=data_args.val_target_set, |
|
generator_id=data_args.generator_id, |
|
per_problem_sampling_solution=-1, |
|
loss_level=data_args.loss_level, |
|
loss_on_llm=data_args.loss_on_llm, |
|
) |
|
return dict(train_dataset=train_dataset, val_dataset=val_dataset) |
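# Usage sketch (data_args fields mirror the dataset constructor arguments; names illustrative):
#   data_module = make_training_verifier_data_module(tokenizer, data_args)
#   train_loader, val_loader = make_training_dataloaders(data_module, training_args)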
|
|
|
|
|
def make_test_verifier_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args: dataclass) -> torch.utils.data.Dataset:
|
test_dataset = VerifierDataset( |
|
tokenizer=tokenizer, |
|
data_dir=data_args.data_dir, |
|
target_set=data_args.target_set, |
|
generator_id=data_args.generator_id, |
|
per_problem_sampling_solution=-1, |
|
) |
|
return test_dataset |
|
|
|
|
|
class ProcessVerifierDataset(torch.utils.data.Dataset): |
|
"""Right Padding""" |
|
|
|
def __init__( |
|
self, |
|
tokenizer: transformers.PreTrainedTokenizer = None, |
|
data_dir: str = 'data/gsm8k/model_generation', |
|
target_set: str = None, |
|
generator_id: str = None, |
|
            per_problem_sampling_solution: int = None,
|
loss_level: str = 'token', |
|
loss_on_llm: bool = False, |
|
dedup: bool = False |
|
): |
|
self.tokenizer = tokenizer |
|
self.data_dir = data_dir |
|
self.target_set = target_set |
|
self.generator_id = generator_id |
|
self.loss_level = loss_level |
|
self.loss_on_llm = loss_on_llm |
|
assert loss_level in ('token', 'step') |
|
|
|
self.pad_token_id = tokenizer.pad_token_id |
|
self.eos_token_id = tokenizer.eos_token_id |
|
|
|
self.examples = get_model_solutions(data_dir, generator_id, target_set, process=True) |
|
assert len(self.examples[0]['outputs']) >= per_problem_sampling_solution |
|
|
|
if per_problem_sampling_solution != -1: |
|
for example in self.examples: |
|
example['outputs'] = example['outputs'][:per_problem_sampling_solution] |
|
else: |
|
per_problem_sampling_solution = len(self.examples[0]['outputs']) |
|
|
|
if dedup: |
|
for ex in self.examples: |
|
dedup_outputs = [] |
|
responses = set() |
|
for output in ex['outputs']: |
|
if output['response'] in responses: |
|
continue |
|
responses.add(output['response']) |
|
dedup_outputs.append(output) |
|
ex['outputs'] = dedup_outputs |
|
|
|
indices1 = [[i] * len(ex["outputs"]) for i, ex in enumerate(self.examples)] |
|
        indices2 = [list(range(len(ex["outputs"]))) for ex in self.examples]
|
qns_str = [[ex["input"]] * len(ex["outputs"]) for ex in self.examples] |
|
solutions_str = [[outputs["response"] for outputs in ex["outputs"]] for ex in self.examples] |
|
step_labels = [[outputs["step_labels"] for outputs in ex["outputs"]] for ex in self.examples] |
|
v_classes = [[outputs["label"] == True for outputs in ex["outputs"]] for ex in self.examples] |
|
|
|
indices1 = self._flatten(indices1) |
|
indices2 = self._flatten(indices2) |
|
qns_str = self._flatten(qns_str) |
|
solutions_str = self._flatten(solutions_str) |
|
step_labels = self._flatten(step_labels) |
|
v_classes = self._flatten(v_classes) |
|
|
|
qns_tokens = tokenizer(qns_str, padding=False).input_ids |
|
|
|
        # Split each solution into newline-terminated steps (the last step keeps no trailing '\n').
        steps_str = [
            list(map(lambda x: x + '\n', solution_str.split('\n')[:-1])) + [solution_str.split('\n')[-1]]
            for solution_str in solutions_str
        ]
        # Tokenize the first step normally; later steps go through the project-specific
        # tokenizer.get_continued_input_ids helper so mid-sequence tokenization stays consistent.
        solutions_tokens = [
            [tokenizer.encode(step_str[0], add_special_tokens=False)]
            + [tokenizer.get_continued_input_ids(step) for step in step_str[1:]]
            for step_str in steps_str
        ]
|
step_tokens_lens = [ |
|
[len(step) for step in tokens] |
|
for tokens in solutions_tokens |
|
] |
|
solutions_tokens = [self._flatten(tokens) for tokens in solutions_tokens] |
|
|
|
self.indices1 = indices1 |
|
self.indices2 = indices2 |
|
self.qns_str = qns_str |
|
self.qns_tokens = qns_tokens |
|
self.solutions_str = solutions_str |
|
self.solutions_tokens = solutions_tokens |
|
self.step_tokens_lens = step_tokens_lens |
|
self.step_labels = step_labels |
|
self.v_classes = v_classes |
|
|
|
self.n_question = len(self.examples) |
|
self.per_problem_sampling_solution = per_problem_sampling_solution |
|
|
|
print( |
|
f'Number of examples = {len(qns_str)} with #deduplication = {self.n_question * self.per_problem_sampling_solution - len(qns_str)}') |
|
|
|
self.max_len = max([ |
|
len(self.qns_tokens[i]) + len(self.solutions_tokens[i]) + 1 |
|
for i in range(len(self.solutions_tokens)) |
|
] |
|
) |
|
print(f"Max tokens: {self.max_len}") |
|
|
|
def __len__(self): |
|
return len(self.solutions_tokens) |
|
|
|
def _flatten(self, ls): |
|
return [item for sublist in ls for item in sublist] |
|
|
|
def __getitem__(self, idx): |
|
qn_tokens = self.qns_tokens[idx] |
|
sol_tokens = self.solutions_tokens[idx] |
|
step_labels = self.step_labels[idx] |
|
step_tokens_lens = self.step_tokens_lens[idx] |
|
|
|
input_ids = qn_tokens + sol_tokens + [self.eos_token_id] |
|
masks = ( |
|
([0] * len(qn_tokens)) |
|
+ ([1] * len(sol_tokens)) |
|
+ ([1]) |
|
) |
|
|
|
|
|
if self.loss_on_llm: |
|
labels = input_ids |
|
labels = mask_labels(labels, masks) |
|
|
|
|
|
        if self.loss_level == 'token':
            # Broadcast each step's binary label across that step's tokens;
            # the eos token inherits the label of the final step.
            v_labels = (
                [0] * len(qn_tokens)
                + sum(
                    [
                        [1 if step_label else 0] * tokens_len
                        for tokens_len, step_label in zip(step_tokens_lens, step_labels)
                    ],
                    []
                )
                + [1 if step_labels[-1] else 0]
            )
            v_labels = mask_labels(v_labels, masks)

            assert len(v_labels) == len(input_ids)
|
else: |
|
raise NotImplementedError |
|
|
|
input_ids = torch.tensor(input_ids) |
|
labels = torch.tensor(labels) if self.loss_on_llm else None |
|
v_labels = torch.tensor(v_labels) |
|
return dict( |
|
idx1=self.indices1[idx], idx2=self.indices2[idx], |
|
input_ids=input_ids, labels=labels, v_labels=v_labels, |
|
qn_str=self.qns_str[idx], qn_tokens=self.qns_tokens[idx], sol_str=self.solutions_str[idx], |
|
sol_tokens=self.solutions_tokens[idx], v_class=self.v_classes[idx], |
|
) |
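    # Example layout, assuming a 3-token question and two 2-token steps with step labels (1, 0):
    #   input_ids: [q q q s1 s1 s2 s2 eos]
    #   v_labels : [IGN IGN IGN 1 1 0 0 0]   (question positions masked to IGNORE_INDEX;
    #                                         the eos token inherits the final step's label)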
|
|
|
def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: |
|
input_ids, labels, v_labels = tuple( |
|
[instance[key] for instance in instances] for key in ("input_ids", "labels", "v_labels")) |
|
idx1, idx2, qn_str, qn_tokens, sol_str, sol_tokens, v_class = tuple( |
|
[instance[key] for instance in instances] for key in |
|
("idx1", "idx2", "qn_str", "qn_tokens", "sol_str", "sol_tokens", "v_class")) |
|
|
|
input_ids, attention_mask = right_pad_sequences(input_ids, padding_value=self.pad_token_id, |
|
return_attention_mask=True) |
|
labels = right_pad_sequences(labels, padding_value=IGNORE_INDEX, |
|
return_attention_mask=False) if self.loss_on_llm else None |
|
v_labels = right_pad_sequences(v_labels, padding_value=IGNORE_INDEX, return_attention_mask=False) |
|
|
|
return dict( |
|
idx1=idx1, idx2=idx2, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
labels=labels, |
|
v_labels=v_labels, |
|
qn_str=qn_str, qn_tokens=qn_tokens, sol_str=sol_str, sol_tokens=sol_tokens, v_class=v_class, |
|
) |
|
|
|
|
|
class VerifierDataset(torch.utils.data.Dataset): |
|
"""Right Padding""" |
|
def __init__( |
|
self, |
|
tokenizer: transformers.PreTrainedTokenizer = None, |
|
data_dir: str = 'data/gsm8k/model_generation', |
|
target_set: str = None, |
|
            data_id: str = None,
            verifier_id: str = None,
            generator_id: str = None,
            per_problem_sampling_solution: int = None,
|
loss_level: str = 'token', |
|
loss_on_llm: bool = False, |
|
dedup: bool = False, |
|
easy: bool = True |
|
): |
|
self.examples = get_model_solutions_easy(data_dir, generator_id, target_set) |
|
assert len(self.examples[0]['outputs']) >= per_problem_sampling_solution |
|
|
|
self.tokenizer = tokenizer |
|
self.data_dir = data_dir |
|
self.target_set = target_set |
|
self.generator_id = generator_id |
|
self.loss_level = loss_level |
|
self.loss_on_llm = loss_on_llm |
|
assert loss_level in ('token', 'step') |
|
|
|
self.pad_token_id = tokenizer.pad_token_id |
|
self.eos_token_id = tokenizer.eos_token_id |
|
|
|
|
|
if per_problem_sampling_solution != -1: |
|
for example in self.examples: |
|
example['outputs'] = example['outputs'][:per_problem_sampling_solution] |
|
else: |
|
per_problem_sampling_solution = len(self.examples[0]['outputs']) |
|
|
|
|
|
if dedup: |
|
for ex in self.examples: |
|
dedup_outputs = [] |
|
responses = set() |
|
for output in ex['outputs']: |
|
if output['response'] in responses: |
|
continue |
|
responses.add(output['response']) |
|
dedup_outputs.append(output) |
|
ex['outputs'] = dedup_outputs |
|
|
|
indices1 = [[i] * len(ex["outputs"]) for i, ex in enumerate(self.examples)] |
|
        indices2 = [list(range(len(ex["outputs"]))) for ex in self.examples]
|
qns_str = [[ex["input"]] * len(ex["outputs"]) for ex in self.examples] |
|
solutions_str = [[outputs["response"] for outputs in ex["outputs"]] for ex in self.examples] |
|
v_classes = [[outputs["label"] == True for outputs in ex["outputs"]] for ex in self.examples] |
|
|
|
indices1 = self._flatten(indices1) |
|
indices2 = self._flatten(indices2) |
|
qns_str = self._flatten(qns_str) |
|
solutions_str = self._flatten(solutions_str) |
|
v_classes = self._flatten(v_classes) |
|
|
|
qns_tokens = tokenizer(qns_str, padding=False).input_ids |
|
solutions_tokens = tokenizer(solutions_str, padding=False, add_special_tokens=False).input_ids |
|
self.indices1 = indices1 |
|
self.indices2 = indices2 |
|
self.qns_str = qns_str |
|
self.qns_tokens = qns_tokens |
|
self.solutions_str = solutions_str |
|
self.solutions_tokens = solutions_tokens |
|
self.v_classes = v_classes |
|
|
|
|
|
self.max_len = max( |
|
[len(qns_tokens[i]) + len(solutions_tokens[i]) + 1 for i in range(len(solutions_tokens))]) |
|
|
|
print(f"Max tokens: {self.max_len}") |
|
self.per_problem_sampling_solution = per_problem_sampling_solution |
|
print(f'Number of examples = {len(self.qns_str)}') |
|
self.n_question = len(self.examples) |
|
|
|
def __len__(self): |
|
return len(self.solutions_tokens) |
|
|
|
def _flatten(self, ls): |
|
return [item for sublist in ls for item in sublist] |
|
|
|
def __getitem__(self, idx): |
|
qn_tokens = self.qns_tokens[idx] |
|
sol_tokens = self.solutions_tokens[idx] |
|
v_class = self.v_classes[idx] |
|
|
|
input_ids = qn_tokens + sol_tokens + [self.eos_token_id] |
|
masks = ( |
|
([0] * len(qn_tokens)) |
|
+ ([1] * len(sol_tokens)) |
|
+ ([1]) |
|
) |
|
|
|
|
|
if self.loss_on_llm: |
|
labels = input_ids |
|
labels = mask_labels(labels, masks) |
|
|
|
|
|
if self.loss_level == 'token': |
|
v_labels = [int(v_class)] * len(input_ids) |
|
v_labels = mask_labels(v_labels, masks) |
|
else: |
|
raise NotImplementedError |
|
|
|
input_ids = torch.tensor(input_ids) |
|
labels = torch.tensor(labels) if self.loss_on_llm else None |
|
v_labels = torch.tensor(v_labels) |
|
return dict( |
|
idx1=self.indices1[idx], idx2=self.indices2[idx], |
|
input_ids=input_ids, labels=labels, v_labels=v_labels, |
|
qn_str=self.qns_str[idx], qn_tokens=self.qns_tokens[idx], sol_str=self.solutions_str[idx], sol_tokens=self.solutions_tokens[idx], v_class=self.v_classes[idx], |
|
) |
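    # Example, assuming a 2-token question and a 3-token solution with v_class True:
    #   v_labels = [IGN, IGN, 1, 1, 1, 1]   (all solution tokens and eos share the outcome label)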
|
|
|
def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: |
|
input_ids, labels, v_labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "v_labels")) |
|
idx1, idx2, qn_str, qn_tokens, sol_str, sol_tokens, v_class = tuple([instance[key] for instance in instances] for key in ("idx1", "idx2", "qn_str", "qn_tokens", "sol_str", "sol_tokens", "v_class")) |
|
|
|
input_ids, attention_mask = right_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True) |
|
labels = right_pad_sequences(labels, padding_value=IGNORE_INDEX, return_attention_mask=False) if self.loss_on_llm else None |
|
v_labels = right_pad_sequences(v_labels, padding_value=IGNORE_INDEX, return_attention_mask=False) |
|
|
|
return dict( |
|
idx1=idx1, idx2=idx2, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
labels=labels, |
|
v_labels=v_labels, |
|
qn_str=qn_str, qn_tokens=qn_tokens, sol_str=sol_str, sol_tokens=sol_tokens, v_class=v_class, |
|
) |
|
|
|
|
|
class VerifierDataset_test(torch.utils.data.Dataset): |
|
"""Right Padding""" |
|
def __init__( |
|
self, |
|
tokenizer: transformers.PreTrainedTokenizer = None, |
|
data_dir: str = 'data/gsm8k/model_generation', |
|
target_set: str = None, |
|
            data_id: str = None,
            verifier_id: str = None,
            generator_id: str = None,
            per_problem_sampling_solution: int = None,
|
loss_level: str = 'token', |
|
loss_on_llm: bool = False, |
|
dedup: bool = False, |
|
easy: bool = True |
|
): |
|
self.examples = get_model_solutions_psv(data_dir, generator_id, target_set) |
|
assert len(self.examples[0]['outputs']) >= per_problem_sampling_solution |
|
print("VerifierDataset_test") |
|
|
|
self.tokenizer = tokenizer |
|
self.data_dir = data_dir |
|
self.target_set = target_set |
|
self.generator_id = generator_id |
|
self.loss_level = loss_level |
|
self.loss_on_llm = loss_on_llm |
|
assert loss_level in ('token', 'step') |
|
|
|
self.pad_token_id = tokenizer.pad_token_id |
|
self.eos_token_id = tokenizer.eos_token_id |
|
|
|
|
|
if per_problem_sampling_solution != -1: |
|
for example in self.examples: |
|
example['outputs'] = example['outputs'][:per_problem_sampling_solution] |
|
else: |
|
per_problem_sampling_solution = len(self.examples[0]['outputs']) |
|
|
|
|
|
if dedup: |
|
for ex in self.examples: |
|
dedup_outputs = [] |
|
responses = set() |
|
for output in ex['outputs']: |
|
if output['response'] in responses: |
|
continue |
|
responses.add(output['response']) |
|
dedup_outputs.append(output) |
|
ex['outputs'] = dedup_outputs |
|
|
|
indices1 = [[i] * len(ex["outputs"]) for i, ex in enumerate(self.examples)] |
|
        indices2 = [list(range(len(ex["outputs"]))) for ex in self.examples]
|
qns_str = [[ex["input"]] * len(ex["outputs"]) for ex in self.examples] |
|
        # Split each response at its labeled error position: part1 is the prefix up to the
        # first error (the whole response when the label is -1, i.e. pass), part2 the remainder.
        part1_solutions_str = []
        part2_solutions_str = []
        partition_list = []
        for ex in self.examples:
            part1_solutions_str_p = []
            part2_solutions_str_p = []
            partition_list_p = []
            for output_id in range(len(ex['outputs'])):
                output = ex['outputs'][output_id]["response"]
                partition_id = ex['outputs'][output_id]["label"]
                partition_list_p.append(partition_id)
                if partition_id == -1:
                    part1_solutions_str_p.append(output)
                    part2_solutions_str_p.append("")
                else:
                    part1_solutions_str_p.append(output[:partition_id])
                    part2_solutions_str_p.append(output[partition_id:])
            part1_solutions_str.append(part1_solutions_str_p)
            part2_solutions_str.append(part2_solutions_str_p)
            partition_list.append(partition_list_p)
|
indices1 = self._flatten(indices1) |
|
indices2 = self._flatten(indices2) |
|
qns_str = self._flatten(qns_str) |
|
part1_solutions_str = self._flatten(part1_solutions_str) |
|
part2_solutions_str = self._flatten(part2_solutions_str) |
|
        partition_list = self._flatten(partition_list)
|
|
|
qns_tokens = tokenizer(qns_str, padding=False).input_ids |
|
part1_solutions_tokens = tokenizer(part1_solutions_str, padding=False, add_special_tokens=False).input_ids |
|
part2_solutions_tokens = tokenizer(part2_solutions_str, padding=False, add_special_tokens=False).input_ids |
|
|
|
|
|
        # v_classes stores the token length of the correct prefix, or -1 when the whole solution passes.
        v_classes = [len(part1_solutions_tokens[i]) if partition_list[i] != -1 else -1
                     for i in range(len(part1_solutions_tokens))]

        solutions_tokens = [part1_solutions_tokens[i] + part2_solutions_tokens[i]
                            for i in range(len(part1_solutions_tokens))]
        solutions_str = [part1_solutions_str[i] + part2_solutions_str[i]
                         for i in range(len(part1_solutions_str))]
|
|
|
|
|
        # Report the fraction of solutions that fully pass (v_class == -1).
        n_pass = sum(1 for v in v_classes if v == -1)
        print("-1 in v_class ratio", n_pass / len(v_classes))

        # The correct prefix of each failed solution receives a soft label of 0.5 in
        # __getitem__; report the fraction of solution tokens labeled positive (1 or 0.5).
        print("change v_label to 0.5")
        one_len = 0
        total_len = 0
        for i in range(len(partition_list)):
            total_len += len(solutions_tokens[i])
            if partition_list[i] == -1:
                one_len += len(solutions_tokens[i])
            else:
                one_len += len(part1_solutions_tokens[i])
        print("1 label ratio", one_len / total_len)
|
|
|
|
|
|
|
self.indices1 = indices1 |
|
self.indices2 = indices2 |
|
self.qns_str = qns_str |
|
self.qns_tokens = qns_tokens |
|
self.solutions_str = solutions_str |
|
self.solutions_tokens = solutions_tokens |
|
self.v_classes = v_classes |
|
|
|
self.max_len = max( |
|
[len(qns_tokens[i]) + len(solutions_tokens[i]) + 1 for i in range(len(solutions_tokens))]) |
|
|
|
print(f"Max tokens: {self.max_len}") |
|
self.per_problem_sampling_solution = per_problem_sampling_solution |
|
print(f'Number of examples = {len(self.qns_str)}') |
|
self.n_question = len(self.examples) |
|
|
|
def __len__(self): |
|
return len(self.solutions_tokens) |
|
|
|
def _flatten(self, ls): |
|
return [item for sublist in ls for item in sublist] |
|
|
|
def __getitem__(self, idx): |
|
qn_tokens = self.qns_tokens[idx] |
|
sol_tokens = self.solutions_tokens[idx] |
|
v_class = self.v_classes[idx] |
|
|
|
input_ids = qn_tokens + sol_tokens + [self.eos_token_id] |
|
masks = ( |
|
([0] * len(qn_tokens)) |
|
+ ([1] * len(sol_tokens)) |
|
+ ([1]) |
|
) |
|
|
|
|
|
if self.loss_on_llm: |
|
labels = input_ids |
|
labels = mask_labels(labels, masks) |
|
|
|
|
|
        if self.loss_level == 'token':
            if v_class == -1:
                # Fully correct solution: every token is labeled 1.
                v_labels = [1] * len(input_ids)
            else:
                # Failed solution: tokens from the error position onward are labeled 0,
                # and the correct prefix receives a soft label of 0.5.
                v_labels = [0] * len(input_ids)
                v_labels[len(qn_tokens): len(qn_tokens) + v_class] = [0.5] * v_class

            v_labels = mask_labels(v_labels, masks)
|
else: |
|
raise NotImplementedError |
|
|
|
input_ids = torch.tensor(input_ids) |
|
labels = torch.tensor(labels) if self.loss_on_llm else None |
|
v_labels = torch.tensor(v_labels) |
|
return dict( |
|
idx1=self.indices1[idx], idx2=self.indices2[idx], |
|
input_ids=input_ids, labels=labels, v_labels=v_labels, |
|
qn_str=self.qns_str[idx], qn_tokens=self.qns_tokens[idx], sol_str=self.solutions_str[idx], sol_tokens=self.solutions_tokens[idx], v_class=self.v_classes[idx], |
|
) |
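    # Example, assuming a 2-token question and a 5-token solution with v_class == 2
    # (error begins at the third solution token):
    #   v_labels = [IGN, IGN, 0.5, 0.5, 0, 0, 0, 0]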
|
|
|
def collate_fn(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]: |
|
input_ids, labels, v_labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels", "v_labels")) |
|
idx1, idx2, qn_str, qn_tokens, sol_str, sol_tokens, v_class = tuple([instance[key] for instance in instances] for key in ("idx1", "idx2", "qn_str", "qn_tokens", "sol_str", "sol_tokens", "v_class")) |
|
|
|
input_ids, attention_mask = right_pad_sequences(input_ids, padding_value=self.pad_token_id, return_attention_mask=True) |
|
labels = right_pad_sequences(labels, padding_value=IGNORE_INDEX, return_attention_mask=False) if self.loss_on_llm else None |
|
v_labels = right_pad_sequences(v_labels, padding_value=IGNORE_INDEX, return_attention_mask=False) |
|
|
|
return dict( |
|
idx1=idx1, idx2=idx2, |
|
input_ids=input_ids, |
|
attention_mask=attention_mask, |
|
labels=labels, |
|
v_labels=v_labels, |
|
qn_str=qn_str, qn_tokens=qn_tokens, sol_str=sol_str, sol_tokens=sol_tokens, v_class=v_class, |
|
) |
|
|
|
|
|
|
|
|
|
class VerifierDataset_self(VerifierDataset): |
|
"""Right Padding""" |
|
def __init__( |
|
self, |
|
tokenizer: transformers.PreTrainedTokenizer = None, |
|
data_dir: str = 'data/gsm8k/model_generation', |
|
target_set: str = None, |
|
data_id: str = None, |
|
generator_id: str = None, |
|
verifier_id: str = None, |
|
            per_problem_sampling_solution: int = None,
|
loss_level: str = 'token', |
|
loss_on_llm: bool = False, |
|
dedup: bool = False, |
|
easy: bool = True |
|
): |
|
        if easy:
            self.examples = get_model_solutions_easy(data_dir, generator_id, target_set)
        else:
            self.examples = get_model_solutions_self(data_dir, data_id, verifier_id, generator_id)
|
assert len(self.examples[0]['outputs']) >= per_problem_sampling_solution |
|
|
|
self.tokenizer = tokenizer |
|
self.data_dir = data_dir |
|
self.target_set = target_set |
|
self.generator_id = generator_id |
|
self.loss_level = loss_level |
|
self.loss_on_llm = loss_on_llm |
|
assert loss_level in ('token', 'step') |
|
|
|
self.pad_token_id = tokenizer.pad_token_id |
|
self.eos_token_id = tokenizer.eos_token_id |
|
|
|
if per_problem_sampling_solution != -1: |
|
for example in self.examples: |
|
if "input" not in example: |
|
example['input'] = example['question'] |
|
example['outputs'] = example['outputs'][:per_problem_sampling_solution] |
|
else: |
|
per_problem_sampling_solution = len(self.examples[0]['outputs']) |
|
|
|
if dedup: |
|
for ex in self.examples: |
|
dedup_outputs = [] |
|
responses = set() |
|
for output in ex['outputs']: |
|
if output['response'] in responses: |
|
continue |
|
responses.add(output['response']) |
|
dedup_outputs.append(output) |
|
ex['outputs'] = dedup_outputs |
|
|
|
indices1 = [[i] * len(ex["outputs"]) for i, ex in enumerate(self.examples)] |
|
        indices2 = [list(range(len(ex["outputs"]))) for ex in self.examples]
|
qns_str = [[ex["input"]] * len(ex["outputs"]) for ex in self.examples] |
|
solutions_str = [[outputs["response"] for outputs in ex["outputs"]] for ex in self.examples] |
|
v_classes = [[outputs["process_vscores"] for outputs in ex["outputs"]] for ex in self.examples] |
|
|
|
indices1 = self._flatten(indices1) |
|
indices2 = self._flatten(indices2) |
|
qns_str = self._flatten(qns_str) |
|
solutions_str = self._flatten(solutions_str) |
|
v_classes = self._flatten(v_classes) |
|
|
|
qns_tokens = tokenizer(qns_str, padding=False).input_ids |
|
solutions_tokens = tokenizer(solutions_str, padding=False, add_special_tokens=False).input_ids |
|
|
|
|
|
self.indices1 = indices1 |
|
self.indices2 = indices2 |
|
self.qns_str = qns_str |
|
self.qns_tokens = qns_tokens |
|
self.solutions_str = solutions_str |
|
self.solutions_tokens = solutions_tokens |
|
self.v_classes = v_classes |
|
|
|
self.n_question = len(self.examples) |
|
self.per_problem_sampling_solution = per_problem_sampling_solution |
|
|
|
print( |
|
f'Number of examples = {len(qns_str)} with #deduplication = {self.n_question * self.per_problem_sampling_solution - len(qns_str)}') |
|
|
|
self.max_len = max([ |
|
len(self.qns_tokens[i]) + len(self.solutions_tokens[i]) + 1 |
|
for i in range(len(self.solutions_tokens)) |
|
] |
|
) |
|
print(f"Max tokens: {self.max_len}") |
|
|
|
|
|
def __getitem__(self, idx): |
|
qn_tokens = self.qns_tokens[idx] |
|
sol_tokens = self.solutions_tokens[idx] |
|
v_class = self.v_classes[idx] |
|
|
|
input_ids = qn_tokens + sol_tokens + [self.eos_token_id] |
|
masks = ( |
|
([0] * len(qn_tokens)) |
|
+ ([1] * len(sol_tokens)) |
|
+ ([1]) |
|
) |
|
|
|
|
|
if self.loss_on_llm: |
|
labels = input_ids |
|
labels = mask_labels(labels, masks) |
|
|
|
|
|
        if self.loss_level == 'token':
            # Question tokens get placeholder labels (masked out below); solution tokens
            # use the per-token process verifier scores.
            v_class = [1] * len(qn_tokens) + v_class
            v_labels = mask_labels(v_class, masks)
|
else: |
|
raise NotImplementedError |
|
|
|
input_ids = torch.tensor(input_ids) |
|
labels = torch.tensor(labels) if self.loss_on_llm else None |
|
v_labels = torch.tensor(v_labels) |
|
return dict( |
|
idx1=self.indices1[idx], idx2=self.indices2[idx], |
|
input_ids=input_ids, labels=labels, v_labels=v_labels, |
|
qn_str=self.qns_str[idx], qn_tokens=self.qns_tokens[idx], sol_str=self.solutions_str[idx], |
|
sol_tokens=self.solutions_tokens[idx], v_class=self.v_classes[idx], |
|
) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def left_pad_sequences(sequences: List[torch.LongTensor], padding_value: int, return_attention_mask: bool = False): |
|
max_length = max(len(x) for x in sequences) |
|
padded_sequences = torch.stack([F.pad(seq, (max_length - seq.shape[-1], 0), value=padding_value) for seq in sequences], dim=0) |
|
if return_attention_mask: |
|
attention_mask = padded_sequences.ne(padding_value) |
|
return padded_sequences, attention_mask |
|
return padded_sequences |
|
|
|
def right_pad_sequences(sequences: List[torch.LongTensor], padding_value: int, return_attention_mask: bool = False): |
|
padded_sequences = torch.nn.utils.rnn.pad_sequence( |
|
sequences, |
|
batch_first=True, |
|
padding_value=padding_value, |
|
) |
|
if return_attention_mask: |
|
attention_mask = padded_sequences.ne(padding_value) |
|
return padded_sequences, attention_mask |
|
return padded_sequences |
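# Minimal sketch of the padding helpers (tensor values illustrative):
#   seqs = [torch.tensor([5, 6, 7]), torch.tensor([8])]
#   padded, mask = right_pad_sequences(seqs, padding_value=0, return_attention_mask=True)
#   # padded -> [[5, 6, 7], [8, 0, 0]], mask -> [[T, T, T], [T, F, F]]
# Note that the mask marks every position equal to padding_value as padding, so
# padding_value must not collide with real token ids.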
|
|
|
|
|
def mask_labels(labels: List[int], masks: List[bool]): |
|
"""Mask the corresponding label into IGNORE_INDEX""" |
|
assert len(labels) == len(masks) |
|
return [ |
|
token if mask |
|
else IGNORE_INDEX |
|
for token, mask in zip(labels, masks) |
|
] |
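# Example: mask_labels([7, 8, 9], [0, 1, 1]) -> [IGNORE_INDEX, 8, 9]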
|
|
|
|
|
|
|
|