import glob
import os

import gradio as gr
from langchain.prompts import PromptTemplate
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import UnstructuredMarkdownLoader
from langchain_community.embeddings import HuggingFaceInferenceAPIEmbeddings
from langchain_community.vectorstores import FAISS
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEndpoint

# Hugging Face API token, read from the "hf_token" secret.
secret_value_hf = os.getenv("hf_token")

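# Document chunks are embedded remotely through the Hugging Face Inference API,
# so no embedding model has to be loaded locally.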
hf_embeddings = HuggingFaceInferenceAPIEmbeddings(
    api_key=secret_value_hf,
    model_name="sentence-transformers/all-MiniLM-L6-v2",
)

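# Load every Markdown file in md_files/ and flatten the per-file document lists.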
md_path = glob.glob("md_files/*.md")
docs = [UnstructuredMarkdownLoader(md).load() for md in md_path]
docs_list = [item for sublist in docs for item in sublist]

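# Split the documents into overlapping chunks, sized by tiktoken token counts.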
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=1000, chunk_overlap=200
)
doc_splits = text_splitter.split_documents(docs_list)

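# Index the chunks in an in-memory FAISS vector store.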
db = FAISS.from_documents(doc_splits, hf_embeddings)

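# Prompt shared by the RAG and plain-LLM chains; without RAG the {context} slot is left empty.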
prompt_template = '''You are an assistant for question-answering tasks.
Here is the context to use to answer the question:
{context}
Think carefully about the above context.
Now, review the user question:
{question}
Provide an answer to this question using only the above context.
Use three sentences maximum and keep the answer concise.
Answer:'''

prompt = PromptTemplate(
    input_variables=["context", "question"],
    template=prompt_template,
)


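# All three selectable models are served through the Hugging Face Inference API
# with the same settings, so a single generation path covers them.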
def get_output(model_name: str, is_RAG: str, questions: str):
    """Answer a question with the selected model, with or without retrieval."""
    llm = HuggingFaceEndpoint(
        repo_id=model_name,
        # max_new_tokens replaces the original max_length=4096, which is not a
        # HuggingFaceEndpoint parameter; 1024 is an assumed cap on answer length.
        max_new_tokens=1024,
        temperature=0.2,
        huggingfacehub_api_token=secret_value_hf,
    )
    llm_chain = prompt | llm | StrOutputParser()
    # Retrieve the 4 most similar chunks; the returned list of Documents is
    # stringified directly into the prompt's {context} slot.
    retriever = db.as_retriever(
        search_type="similarity",
        search_kwargs={"k": 4},
    )
    rag_chain = (
        {"context": retriever, "question": RunnablePassthrough()}
        | llm_chain
    )
    if is_RAG == "RAG":
        return rag_chain.invoke(questions)
    # Without RAG, query the model directly with an empty context.
    return llm_chain.invoke({"context": "", "question": questions})


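# Custom styling for the response box (dark background, yellow heading).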
custom_css = """ |
|
#output_area { |
|
background-color: #1e1e1e; /* Dark background */ |
|
color: #ffffff; /* White text */ |
|
padding: 10px; |
|
border-radius: 5px; |
|
border: 1px solid #333333; /* Dark border */ |
|
margin-top: 10px; |
|
} |
|
|
|
#output_area h3 { |
|
color: #ffcc00; /* Yellow title color */ |
|
margin-bottom: 10px; |
|
} |
|
""" |
|
|
|
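# Gradio UI: model selector, RAG toggle, question box, and response area.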
with gr.Blocks(title="Ask Questions on Chalcogenide Perovskites", theme=gr.themes.Ocean(), css=custom_css) as demo:
    gr.Markdown("""
## Retrieval-Augmented Generation for Chalcogenide Perovskites

This space implements Retrieval-Augmented Generation (RAG) using large language models, based on Hui Haolei's work on chalcogenide perovskite papers. You can select different models and choose whether to use RAG to enhance the responses.

For more details, see the [README on GitHub](https://github.com/HaoleiH/AI-driven-projects/blob/main/RAG-using-Llama3.2-3b/README_RAG.md).
    """)

    with gr.Row():
        model_name = gr.Radio(
            choices=["mistralai/Mistral-7B-Instruct-v0.2", "meta-llama/Llama-3.2-3B-Instruct", "Qwen/Qwen2.5-72B-Instruct"],
            value="mistralai/Mistral-7B-Instruct-v0.2",
            label="Model Name",
            info="Select the model you want to use."
        )

    with gr.Row():
        rag = gr.Radio(
            choices=["RAG", "No RAG"],
            value="RAG",
            label="RAG or Not",
            info="Choose whether to use Retrieval-Augmented Generation."
        )

    with gr.Row():
        question = gr.Textbox(
            label="Input Question",
            placeholder="Enter your question about chalcogenide perovskites here...",
            lines=2
        )

    with gr.Row():
        submit_button = gr.Button("Submit")

    with gr.Row():
        output = gr.Textbox(
            label="Response",
            lines=10,
            elem_id="output_area"
        )

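    # Run get_output when the Submit button is clicked.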
    submit_button.click(
        fn=get_output,
        inputs=[model_name, rag, question],
        outputs=output
    )

    gr.Examples(
        examples=[
            ["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "What is the advantage of BaZrS3?"],
            ["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "What is the bandgap of SrHfS3?"],
            ["mistralai/Mistral-7B-Instruct-v0.2", "RAG", "Why is chalcogenide perovskite important?"]
        ],
        fn=get_output,
        inputs=[model_name, rag, question],
        outputs=output,
        cache_examples=False
    )


demo.launch()