File size: 1,356 Bytes
7fdb8e9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import os
from typing import Any

from huggingface_hub import InferenceClient

from rag_demo.rag.base.query import Query
from rag_demo.rag.base.template_factory import RAGStep
from rag_demo.rag.prompt_templates import QueryExpansionTemplate


class QueryExpansion(RAGStep):
    """RAG step that asks a hosted chat model for rephrasings of a query."""

    def generate(self, query: Query, expand_to_n: int) -> Any:
        """Expand ``query`` into a list of queries headed by the original.

        Args:
            query: The query to expand; its ``content`` is sent to the model.
            expand_to_n: Target size of the returned list, counting the
                original query. With a value of 1 or less there is nothing
                to generate, so only the original query is returned and no
                request is made.

        Returns:
            A list whose first element is ``query`` itself, followed by one
            ``query.replace_content(...)`` result for every non-empty piece
            of the model's reply. The model is not forced to honor the
            requested count, so the list may be shorter or longer than
            ``expand_to_n``.
        """
        # The original query occupies one slot, so only the remainder has to
        # be generated by the model.
        n_alternatives = expand_to_n - 1
        if n_alternatives < 1:
            return [query]

        api = InferenceClient(
            model="Qwen/Qwen2.5-72B-Instruct",
            token=os.getenv("HF_API_TOKEN"),
        )
        query_expansion_template = QueryExpansionTemplate()
        separator = query_expansion_template.separator
        prompt = query_expansion_template.create_template(n_alternatives)

        # The raw template string is formatted directly, so every placeholder
        # is supplied here; the count must match the one given to
        # ``create_template`` above.
        user_message = prompt.template.format(
            question=query.content,
            expand_to_n=n_alternatives,
            separator=separator,
        )
        response = api.chat_completion(
            [{"role": "user", "content": user_message}]
        )

        # ``content`` may be None (e.g. an empty completion); treat that as
        # "no alternatives" instead of failing on ``.split``.
        result = response.choices[0].message.content or ""

        queries = [query]
        queries += [
            query.replace_content(stripped_content)
            for content in result.split(separator)
            if (stripped_content := content.strip())
        ]
        return queries