Shitao committed
Commit 5a48fd9 · verified · Parent(s): a15ce40

Upload folder using huggingface_hub

Files changed (1):
1. README.md  +33 -12
README.md CHANGED
@@ -1778,11 +1778,29 @@ def last_token_pool(last_hidden_states: Tensor,


def get_detailed_instruct(task_description: str, query: str) -> str:
-    return f'<instruct>{task_description}\n<query>{query}\n<response>'
+    return f'<instruct>{task_description}\n<query>{query}'

def get_detailed_example(task_description: str, query: str, response: str) -> str:
    return f'<instruct>{task_description}\n<query>{query}\n<response>{response}'

+def get_new_queries(queries, query_max_len, examples_prefix, tokenizer):
+    inputs = tokenizer(
+        queries,
+        max_length=query_max_len - len(tokenizer('<s>', add_special_tokens=False)['input_ids']) - len(
+            tokenizer('\n<response></s>', add_special_tokens=False)['input_ids']),
+        return_token_type_ids=False,
+        truncation=True,
+        return_tensors=None,
+        add_special_tokens=False
+    )
+    prefix_ids = tokenizer(examples_prefix)['input_ids']
+    suffix_ids = tokenizer('\n<response>')['input_ids']
+    new_max_length = (len(prefix_ids) + len(suffix_ids) + query_max_len) // 8 * 8 + 8
+    new_queries = tokenizer.batch_decode(inputs['input_ids'])
+    for i in range(len(new_queries)):
+        new_queries[i] = examples_prefix + new_queries[i] + '\n<response>'
+    return new_max_length, new_queries
+
task = 'Given a web search query, retrieve relevant passages that answer the query.'
examples = [
    {'instruct': 'Given a web search query, retrieve relevant passages that answer the query.',
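The added `get_new_queries` helper now owns the prompt assembly that the old example did by hand: `get_detailed_instruct` no longer appends `\n<response>`, because the helper first truncates each raw query to a token budget that reserves room for the `<s>` BOS token and the `\n<response></s>` suffix, then re-attaches the few-shot `examples_prefix` and the `\n<response>` marker as plain text, and finally bumps the combined length to a multiple of 8 (presumably so padded batches line up with kernel-friendly sizes). A minimal sketch of that budget arithmetic, using hypothetical token counts rather than values measured with the real BAAI/bge-en-icl tokenizer:

```python
# Hypothetical token counts, for illustration only (the real numbers come from
# the BAAI/bge-en-icl tokenizer at runtime).
query_max_len = 512   # per-query budget, as set later in the updated example
prefix_len = 130      # assumed length of tokenizer(examples_prefix)['input_ids']
suffix_len = 4        # assumed length of tokenizer('\n<response>')['input_ids']

# get_new_queries reserves space for the few-shot prefix, the truncated query,
# and the '\n<response>' suffix, then rounds to a multiple of 8 that is at
# least as large as the raw sum.
new_max_length = (prefix_len + suffix_len + query_max_len) // 8 * 8 + 8
print(new_max_length)  # (130 + 4 + 512) = 646 -> 648
```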
 
@@ -1795,31 +1813,34 @@ examples = [
examples = [get_detailed_example(e['instruct'], e['query'], e['response']) for e in examples]
examples_prefix = '\n\n'.join(examples) + '\n\n'
queries = [
-    examples_prefix + get_detailed_instruct(task, 'how much protein should a female eat'),
-    examples_prefix + get_detailed_instruct(task, 'summit define')
+    get_detailed_instruct(task, 'how much protein should a female eat'),
+    get_detailed_instruct(task, 'summit define')
]
-# No need to add instructions for documents
documents = [
    "As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.",
    "Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments."
]
-input_texts = queries + documents
+query_max_len, doc_max_len = 512, 512

tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-en-icl')
model = AutoModel.from_pretrained('BAAI/bge-en-icl')
model.eval()

-max_length = 4096
-# Tokenize the input texts
-batch_dict = tokenizer(input_texts, max_length=max_length, padding=True, truncation=True, return_tensors='pt')
+new_query_max_len, new_queries = get_new_queries(queries, query_max_len, examples_prefix, tokenizer)
+
+query_batch_dict = tokenizer(new_queries, max_length=new_query_max_len, padding=True, truncation=True, return_tensors='pt')
+doc_batch_dict = tokenizer(documents, max_length=doc_max_len, padding=True, truncation=True, return_tensors='pt')

with torch.no_grad():
-    outputs = model(**batch_dict)
-    embeddings = last_token_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
+    query_outputs = model(**query_batch_dict)
+    query_embeddings = last_token_pool(query_outputs.last_hidden_state, query_batch_dict['attention_mask'])
+    doc_outputs = model(**doc_batch_dict)
+    doc_embeddings = last_token_pool(doc_outputs.last_hidden_state, doc_batch_dict['attention_mask'])

# normalize embeddings
-embeddings = F.normalize(embeddings, p=2, dim=1)
-scores = (embeddings[:2] @ embeddings[2:].T) * 100
+query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
+doc_embeddings = F.normalize(doc_embeddings, p=2, dim=1)
+scores = (query_embeddings @ doc_embeddings.T) * 100
print(scores.tolist())
```
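Both the old and the new snippet lean on imports and a `last_token_pool` helper defined earlier in the README (the first hunk header shows its signature). For reading this diff in isolation, here is a minimal sketch of that preamble; the pooling body is the standard last-token pooling used in BGE/E5-style examples and is an assumption about, not a copy of, what the README defines:

```python
# Sketch of the preamble the hunks above assume is already present: imports
# plus the last_token_pool helper named in the first hunk header.
import torch
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoModel


def last_token_pool(last_hidden_states: Tensor,
                    attention_mask: Tensor) -> Tensor:
    # If every sequence attends at the final position, the batch is
    # left-padded and the last hidden state is already the last real token.
    left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
    if left_padding:
        return last_hidden_states[:, -1]
    # Otherwise (right padding), index the last non-padding position per row.
    sequence_lengths = attention_mask.sum(dim=1) - 1
    batch_size = last_hidden_states.shape[0]
    return last_hidden_states[torch.arange(batch_size, device=last_hidden_states.device), sequence_lengths]
```

With that preamble in place, the updated example produces a 2×2 `scores` matrix (two queries by two documents), so something like `scores.argmax(dim=1)` picks the best-scoring document for each query.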