---
language:
- en
---
# nli-entailment-verifier-xxl
## Model description
**nli-entailment-verifier-xxl** is based on the [flan-t5-xxl model](https://huggingface.co/google/flan-t5-xxl) and fine-tuned with a ranking objective (given a premise and a pair of hypotheses, rank the hypothesis that is better supported by the premise). Please refer to our paper [Are Machines Better at Complex Reasoning? Unveiling Human-Machine Inference Gaps in Entailment Verification](https://arxiv.org/abs/2402.03686) for more details.
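As a rough illustration of a pairwise ranking objective of this kind (a minimal sketch, not the training code from the paper; the hinge-margin formulation and default `margin` below are assumptions), the idea is to push the entailment score of the supported hypothesis above that of the unsupported one:

```python
import torch

def pairwise_ranking_loss(score_supported, score_unsupported, margin=1.0):
    # Hinge-style margin ranking loss over two entailment scores:
    # zero once the supported hypothesis wins by at least `margin`.
    return torch.clamp(margin - (score_supported - score_unsupported), min=0.0).mean()
```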
It is built to verify whether a given premise supports a given hypothesis, and it works both for NLI-style datasets and for chain-of-thought (CoT) rationales. In particular, the model is trained to handle multi-sentence premises, as expected in CoT rationales and other modern LLM use cases.
**Note**: You can use 4-bit/8-bit [quantization](https://huggingface.co/docs/bitsandbytes/main/en/index) to reduce GPU memory usage.
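For example, here is a minimal sketch of loading the model in 4-bit precision via `bitsandbytes` (the specific `BitsAndBytesConfig` settings are illustrative assumptions, not recommendations from the paper):

```python
import torch
from transformers import AutoModelForSeq2SeqLM, BitsAndBytesConfig

# Illustrative 4-bit config; requires the `bitsandbytes` package.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForSeq2SeqLM.from_pretrained(
    'soumyasanyal/nli-entailment-verifier-xxl',
    quantization_config=bnb_config,
    device_map='auto',  # requires `accelerate`
)
```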
## Usage
```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch
def get_score(model, tokenizer, input_ids):
    # Token ids of the 'Yes' and 'No' answer words (first token of each).
    pos_ids = tokenizer('Yes').input_ids
    neg_ids = tokenizer('No').input_ids
    pos_id = pos_ids[0]
    neg_id = neg_ids[0]

    # Run a single decoding step; T5 uses the pad token (id 0) as the decoder start token.
    logits = model(input_ids, decoder_input_ids=torch.zeros((input_ids.size(0), 1), dtype=torch.long)).logits

    # Softmax over the 'Yes'/'No' logits at the first decoding position
    # gives the probability that the premise supports the hypothesis.
    pos_logits = logits[:, 0, pos_id]
    neg_logits = logits[:, 0, neg_id]
    posneg_logits = torch.cat([pos_logits.unsqueeze(-1), neg_logits.unsqueeze(-1)], dim=1)
    scores = torch.nn.functional.softmax(posneg_logits, dim=1)[:, 0]
    return scores
# The model reuses the flan-t5-xxl tokenizer.
tokenizer = AutoTokenizer.from_pretrained('google/flan-t5-xxl')
model = AutoModelForSeq2SeqLM.from_pretrained('soumyasanyal/nli-entailment-verifier-xxl')

premise = "A fossil fuel is a kind of natural resource. Coal is a kind of fossil fuel."
hypothesis = "Coal is a kind of natural resource."
prompt = f"Premise: {premise}\nHypothesis: {hypothesis}\nGiven the premise, is the hypothesis correct?\nAnswer:"

input_ids = tokenizer(prompt, return_tensors='pt').input_ids
scores = get_score(model, tokenizer, input_ids)
print(f'Premise entails the hypothesis: {bool(scores >= 0.5)}')
```
> `Premise entails the hypothesis: True`
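The same helper can be reused to score several premise-hypothesis pairs. A usage sketch (the example pairs are made up for illustration, and each prompt is scored separately to avoid dealing with padding):

```python
pairs = [
    ("All birds can fly. Penguins are birds.", "Penguins can fly."),
    ("Plants make their own food. A cactus is a plant.", "A cactus makes its own food."),
]
for premise, hypothesis in pairs:
    prompt = f"Premise: {premise}\nHypothesis: {hypothesis}\nGiven the premise, is the hypothesis correct?\nAnswer:"
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids
    score = get_score(model, tokenizer, input_ids).item()
    print(f'{hypothesis!r}: support probability = {score:.3f}')
```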