Merge branch 'import_sorbobot' into main

- .gitattributes +1 -0
- .gitignore +161 -0
- Dockerfile +56 -0
- README.md +7 -36
- app.py +0 -65
- docker-entrypoint-initdb.d/01-restore.sh +7 -0
- SU_CSV.db → docker-entrypoint-initdb.d/dump.pgdata +2 -2
- docs/docker.md +31 -0
- docs/sorbobot.md +55 -0
- execution.sh +7 -0
- pyproject.toml +91 -0
- requirements.txt +14 -11
- setup.py +10 -0
- sorbobotapp/__init__.py +0 -0
- sorbobotapp/app.py +190 -0
- sorbobotapp/chain.py +39 -0
- sorbobotapp/chat_history.py +49 -0
- sorbobotapp/connection.py +10 -0
- sorbobotapp/conversation_retrieval_chain.py +100 -0
- sorbobotapp/css.py +7 -0
- sorbobotapp/keyword_extraction.py +58 -0
- sorbobotapp/message.py +16 -0
- sorbobotapp/model.py +17 -0
- sorbobotapp/models/article.py +17 -0
- sorbobotapp/models/distance.py +15 -0
- sorbobotapp/static/ai_icon.png +0 -0
- sorbobotapp/static/styles.css +36 -0
- sorbobotapp/static/user_icon.png +0 -0
- sorbobotapp/utils.py +20 -0
- sorbobotapp/vector_store.py +364 -0
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 SU_CSV.db filter=lfs diff=lfs merge=lfs -text
+*.pgdata filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,161 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+.ruff_cache
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
Dockerfile ADDED
@@ -0,0 +1,56 @@
+FROM postgres:14.9-bookworm
+
+WORKDIR /app
+
+RUN apt update && \
+    apt install -y --no-install-recommends \
+    build-essential \
+    python3 \
+    python3-pip \
+    python3-dev \
+    postgresql-server-dev-14 \
+    libpq-dev \
+    libblas-dev \
+    htop \
+    git
+
+COPY ./ /app/
+
+RUN pip3 install -r ./requirements.txt --break-system-packages
+
+EXPOSE 5432
+EXPOSE 7860
+
+ENV POSTGRES_USER=postgres
+ENV POSTGRES_PASSWORD=pwd
+ENV POSTGRES_DB=sorbobot
+
+# User
+RUN useradd -m -u 1000 user
+ENV HOME /home/user
+ENV PATH $HOME/.local/bin:$PATH
+
+# Install PGVector
+WORKDIR /tmp
+RUN git clone --branch v0.5.1 https://github.com/pgvector/pgvector.git
+WORKDIR /tmp/pgvector
+RUN make
+RUN make install # may need sudo
+WORKDIR $HOME
+COPY ./ $HOME
+
+COPY "execution.sh" "/usr/local/bin/"
+
+COPY ./docker-entrypoint-initdb.d/ /docker-entrypoint-initdb.d/
+
+RUN chown -R user:user /var/lib/postgresql/data
+
+USER user
+
+ENTRYPOINT ["execution.sh"]
+
+STOPSIGNAL SIGINT
+
+HEALTHCHECK CMD curl --fail http://localhost:7860/_stcore/health
+
+CMD ["postgres"]
README.md CHANGED
@@ -1,42 +1,13 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk:
-
+title: Test Streamlit
+emoji: 📈
+colorFrom: gray
+colorTo: indigo
+sdk: docker
+app_port: 7860
+sdk_version: 1.27.2
 app_file: app.py
 pinned: false
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
-
-Unfortunately, for lack of time, I was not able to finish fine-tuning the GPT-2 and BERT models in time, as the latter required more than 130 hours of compute.
-
-You can relaunch the GPT-2 fine-tuning by going to the project repository and running `python3 GPT2_Lorafinetune.py`.
-
-Once the fine-tuning process is finished, you can create a model and deploy it on HuggingFace for later use. To do so:
-
-- Click your profile icon on HuggingFace -> New model (choose a name / public or private) -> create model. Once the model is created, go to "Files" and add the necessary files.
-
-You can refer to this tutorial: https://huggingface.co/transformers/v4.0.1/model_sharing.html
-
-Make sure there are no unnecessary files in the directory you are going to upload. It must contain only:
-
-- A config.json file, which records your model's configuration.
-
-- A pytorch_model.bin file, which is the PyTorch checkpoint (unless you cannot provide it for some reason).
-
-- A tf_model.h5 file, which is the TensorFlow checkpoint (unless you cannot provide it for some reason).
-
-- A special_tokens_map.json file, which is part of your tokenizer save.
-
-- A tokenizer_config.json file, which is part of your tokenizer save.
-
-- Files named vocab.json, vocab.txt, merges.txt or similar, which contain your tokenizer's vocabulary and are part of your tokenizer save.
-
-- Optionally, an added_tokens.json file, which is part of your tokenizer save.
-
-Once you have created and deployed your model on HuggingFace, you must replace the model in app.py with the path to your model.
-
-Reference: https://huggingface.co/learn/nlp-course/chapter4/2?fw=pt
app.py DELETED
@@ -1,65 +0,0 @@
-import streamlit as st
-from transformers import pipeline
-import sqlite3
-from sentence_transformers import SentenceTransformer
-from sklearn.feature_extraction.text import CountVectorizer
-from sklearn.metrics.pairwise import cosine_similarity
-import numpy as np
-pipe=pipeline('sentiment-analysis')
-
-text = """
-Welcome to SorboBot, a Hugging Face Space designed to revolutionize the way you find published articles.
-
-Powered by a full export from ScanR and Hal at Sorbonne University, SorboBot utilizes advanced language model technology to provide you with a list of published articles based on your prompt
-
-Work in progress
-
-Write your request:
-"""
-text=st.text_area(text)
-
-
-if text:
-    n_gram_range = (2, 2)
-    stop_words = "english"
-    # Extract candidate words/phrases
-    count = CountVectorizer(ngram_range=n_gram_range, stop_words=stop_words).fit([text])
-    candidates = count.get_feature_names_out()
-    model = SentenceTransformer('distilbert-base-nli-mean-tokens')
-    doc_embedding = model.encode([text])
-    candidate_embeddings = model.encode(candidates)
-    top_n = 5
-    distances = cosine_similarity(doc_embedding, candidate_embeddings)
-    keywords = [candidates[index] for index in distances.argsort()[0][-top_n:]]
-    conn = sqlite3.connect('SU_CSV.db')
-    cursor = conn.cursor()
-
-    mots_cles_recherches = keywords
-
-    # Build the SQL query
-    query = f"SELECT title_s FROM BDD_Provisoire_SU WHERE {' OR '.join(['keyword_s LIKE ?'] * len(mots_cles_recherches))}"
-    params = ['%' + mot + '%' for mot in mots_cles_recherches]
-
-    cursor.execute(query, params)
-    resultats = cursor.fetchall()
-
-    # Display the titles of articles found
-    if resultats:
-        st.write("Titles of articles corresponding to your search:")
-        for row in resultats[:3]:
-            st.json(row[0])
-    else:
-        st.write("No article found in the database\n\n")
-        st.json({})
-
-
-    conn.close()
-    generator = pipeline("text-generation", model="gpt2") # to modify for another model
-    txt = generator(
-        text,
-        max_length=150,
-        num_return_sequences=1,
-    )[0]["generated_text"]
-
-    st.write("Model output")
-    st.write(txt)
docker-entrypoint-initdb.d/01-restore.sh ADDED
@@ -0,0 +1,7 @@
+#!/bin/bash
+
+file="/docker-entrypoint-initdb.d/dump.pgdata"
+dbname=sorbobot
+
+echo "Restoring DB using $file"
+pg_restore -U postgres --dbname=$dbname --verbose --single-transaction < "$file" || exit 1
SU_CSV.db → docker-entrypoint-initdb.d/dump.pgdata RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:48d2146bac1789085d813baa28930706bcf618b516d2233d7a6146b36b4ed6e9
+size 917294274
docs/docker.md ADDED
@@ -0,0 +1,31 @@
+# Docker and Deployment Documentation
+
+## Introduction
+
+This document outlines the Docker configuration and deployment strategy for Sorbobot, a chatbot designed for Sorbonne Université to assist in locating academic experts. The application is containerized using Docker and hosted on Hugging Face Spaces.
+
+## Docker Configuration
+
+### Base Image and Dependencies
+
+The Docker environment for Sorbobot is based on a Postgres image, supplemented with the dependencies needed to run a Streamlit server. This approach integrates the database and the front-end interface within the same Docker container, so the application can be deployed on Hugging Face, which supports only a single Docker image per Space.
+
+### Database Initialization
+
+The database for Sorbobot is initialized from a dump that was previously created and stored in a Git repository. This repository, located at [https://git.isir.upmc.fr/sorbobot/sorbobot](https://git.isir.upmc.fr/sorbobot/sorbobot), contains the data and schema information needed to set up the Postgres database correctly. During the Docker container's initialization phase, this dump is used to populate the database with the required structure and data, ensuring that the chatbot has immediate access to all the information needed for expert retrieval.
+
+## Deployment Process
+
+### Hosting on Hugging Face Spaces
+
+Sorbobot is hosted on Hugging Face Spaces, a platform specifically designed for machine learning models and applications. This hosting choice offers seamless integration and an effective showcase of the chatbot's capabilities.
+
+### Continuous Deployment via Git
+
+The deployment of Sorbobot is managed through an automated process linked with Git. Every `git push` to the repository initiates an automatic update and deployment sequence.
+
+### Deployment Workflow
+
+1. **Code Updates**: Developers push the latest code changes to the Git repository.
+2. **Automatic Docker Build**: The new changes trigger the Docker build process, incorporating recent updates into the Docker container.
+3. **Deployment on Hugging Face Spaces**: Following the build, the updated version of Sorbobot is automatically deployed to Hugging Face Spaces, making it accessible to users with the latest features and improvements.
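As a quick sanity check after the container starts, the same Streamlit endpoint that the Dockerfile's HEALTHCHECK polls can be probed directly. A minimal sketch, assuming the container runs locally with port 7860 published:

```python
# Probe the health endpoint used by the Dockerfile's HEALTHCHECK.
# The localhost URL/port are assumptions for a local run.
import urllib.request

with urllib.request.urlopen("http://localhost:7860/_stcore/health") as resp:
    print(resp.status, resp.read().decode())  # a 200 response means Streamlit is up
```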
docs/sorbobot.md ADDED
@@ -0,0 +1,55 @@
+# Sorbobot: Expert Finder Chatbot Documentation
+
+## Overview
+
+Sorbobot is a chatbot designed for Sorbonne Université to assist their administration in locating academic experts within the university. This document outlines the structure, functionality, and implementation details of Sorbobot.
+
+### Context
+
+Sorbobot centers on identifying experts with precision, avoiding confusion with individuals sharing similar names. It leverages HAL unique identifiers to distinguish between experts.
+
+## System Architecture
+
+Sorbobot operates on a Retrieval Augmented Generation (RAG) system, composed of two primary steps:
+
+1. **Retrieval**: Identifies the publications most similar to the user's query.
+2. **Generation**: Produces responses based on the context extracted from the relevant publications.
+
+## Implementation Details
+
+### Programming Language and Libraries
+
+- **Language**: Python
+- **Frontend**: Streamlit
+- **Database**: PostgreSQL with pgvector for similarity search
+- **NLP Processing**: langchain and GPT4All libraries
+
+### Database
+
+- **Postgres with pgvector**: Used for storing data and performing similarity searches based on cosine similarity metrics.
+
+### Natural Language Processing
+
+- **Abstracts as Data Source**: The chatbot uses publication abstracts to identify experts.
+- **GPT4All for Word Embedding**: Converts text from author publications into word embeddings, improving the accuracy of expert identification.
+
+### Retrieval Process
+
+1. **Query Processing**: User queries are processed to extract key terms.
+2. **Similarity Search**: The system searches the database using pgvector to find publications with a low cosine distance to the query.
+3. **Expert Identification**: The system identifies the authors of these publications, ensuring unique identification of experts.
+
+### Generation Process
+
+1. **Context Extraction**: Relevant information is extracted from the identified publications.
+2. **Response Generation**: Uses an LLM to generate informative responses based on the extracted context.
+
+## User Interaction Flow
+
+1. **Query Submission**: Users submit queries related to their expert search.
+2. **Chatbot Processing**: Sorbobot processes the query, retrieves relevant publications, and identifies experts.
+3. **Response Presentation**: The system presents a list of experts, including unique identifiers and relevant publication abstracts.
+
+## Conclusion
+
+Sorbobot is a powerful tool for Sorbonne Université, streamlining the process of finding academic experts. Its advanced NLP capabilities, combined with a robust database and an intelligent retrieval-generation framework, ensure accurate and efficient expert identification.
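To make the "Similarity Search" step above concrete, here is a minimal sketch of a pgvector cosine-distance query, assuming the `article` table and `abstract_embedding_en` column used by `sorbobotapp/vector_store.py` and a locally reachable Postgres; it shows only the retrieval step, not the app's full pipeline:

```python
import psycopg2
from langchain.embeddings import GPT4AllEmbeddings

# Embed the user query with the same model used for the stored abstracts.
query_embedding = GPT4AllEmbeddings().embed_query("quantum computing")

conn = psycopg2.connect("postgresql://postgres@localhost/sorbobot")
with conn.cursor() as cur:
    # pgvector's <=> operator computes cosine distance; lower means closer.
    cur.execute(
        "SELECT id, title_en, abstract_embedding_en <=> %s::vector AS distance "
        "FROM article ORDER BY distance LIMIT 5;",
        (str(query_embedding),),
    )
    for article_id, title, distance in cur.fetchall():
        print(article_id, title, round(distance, 3))
```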
execution.sh ADDED
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+bash /usr/local/bin/docker-entrypoint.sh "$@" &
+postgres &
+sleep 2
+
+streamlit run sorbobotapp/app.py --server.port=7860 --server.address=0.0.0.0
pyproject.toml ADDED
@@ -0,0 +1,91 @@
+# Packages configs
+[project]
+name = "sorbobotapp"
+version = "0.0.1"
+requires-python = ">=3.10"
+readme = "README.md"
+
+[build-system]
+requires = ["setuptools"]
+
+## coverage
+
+[tool.coverage.run]
+branch = true
+
+[tool.coverage.report]
+skip_empty = true
+fail_under = 70.00
+precision = 2
+
+## black
+
+[tool.black]
+target-version = ['py310']
+
+## ruff
+# Recommended ruff config for now, to be updated as we go along.
+[tool.ruff]
+target-version = 'py310'
+
+# See all rules at https://beta.ruff.rs/docs/rules/
+select = [
+    "E",   # pycodestyle
+    "W",   # pycodestyle
+    "F",   # Pyflakes
+    "B",   # flake8-bugbear
+    "C4",  # flake8-comprehensions
+    "D",   # flake8-docstrings
+    "I",   # isort
+    "SIM", # flake8-simplify
+    "TCH", # flake8-type-checking
+    "TID", # flake8-tidy-imports
+    "Q",   # flake8-quotes
+    "UP",  # pyupgrade
+    "PT",  # flake8-pytest-style
+    "RUF", # Ruff-specific rules
+]
+
+ignore = [
+    "E501",   # "Line too long"
+    # -> line length already regulated by black
+    "PT011",  # "pytest.raises() should specify expected exception"
+    # -> would imply to update tests every time you update exception message
+    "SIM102", # "Use a single `if` statement instead of nested `if` statements"
+    # -> too restrictive
+    "D100",
+]
+
+[tool.ruff.pydocstyle]
+# Automatically disable rules that are incompatible with Google docstring convention
+convention = "google"
+
+[tool.ruff.pycodestyle]
+max-doc-length = 88
+
+[tool.ruff.flake8-tidy-imports]
+ban-relative-imports = "all"
+
+[tool.ruff.flake8-type-checking]
+strict = true
+runtime-evaluated-base-classes = ["pydantic.BaseModel"]
+# Pydantic needs to be able to evaluate types at runtime
+# see https://pypi.org/project/flake8-type-checking/ for flake8-type-checking documentation
+# see https://beta.ruff.rs/docs/settings/#flake8-type-checking-runtime-evaluated-base-classes for ruff documentation
+
+[tool.ruff.per-file-ignores]
+# Allow missing docstrings for tests
+"tests/**/*.py" = ["D100", "D103"]
+
+## mypy
+
+[tool.mypy]
+python_version = "3.10"
+# Enable all optional error checking flags, providing stricter type checking; see https://mypy.readthedocs.io/en/stable/getting_started.html#strict-mode-and-configuration
+strict = true
+
+# Type-check the interiors of functions without type annotations; if missing, mypy won't check function bodies without type hints, for instance those coming from third-party libraries
+check_untyped_defs = true
+
+# Make __init__.py file optional for package definitions; if missing, mypy requires __init__.py at packages roots, see https://mypy.readthedocs.io/en/stable/running_mypy.html#mapping-file-paths-to-modules
+explicit_package_bases = true
requirements.txt CHANGED
@@ -1,11 +1,14 @@
-
-
-
-
-
-
-
-
-
-
-
+black==23.11.0
+gpt4all==1.0.12
+langchain==0.0.313
+openai==0.28.1
+pandas==2.1.1
+pgvector==0.2.3
+pre-commit==3.5.0
+psycopg2-binary==2.9.9
+psycopg2==2.9.9
+streamlit==1.27.2
+streamlit-chat==0.1.1
+SQLAlchemy==2.0.22
+sqlite-vss==0.1.2
+tiktoken==0.5.1
setup.py ADDED
@@ -0,0 +1,10 @@
+#!/usr/bin/env python
+
+from distutils.core import setup
+
+setup(
+    name="sorbobotapp",
+    version="0.0.1",
+    authors=["Leo Bourrel <[email protected]>"],
+    package_dir={"": "sorbobotapp"},
+)
sorbobotapp/__init__.py ADDED
File without changes
sorbobotapp/app.py ADDED
@@ -0,0 +1,190 @@
+import json
+import os
+
+import streamlit as st
+import streamlit.components.v1 as components
+from chain import get_chain
+from chat_history import insert_chat_history, insert_chat_history_articles
+from connection import connect
+from css import load_css
+from langchain.callbacks import get_openai_callback
+from message import Message
+
+st.set_page_config(layout="wide")
+
+st.title("Sorbobot - Le futur de la recherche scientifique interactive")
+
+chat_column, doc_column = st.columns([2, 1])
+
+conn = connect()
+
+
+def initialize_session_state():
+    if "history" not in st.session_state:
+        st.session_state.history = []
+    if "token_count" not in st.session_state:
+        st.session_state.token_count = 0
+    if "conversation" not in st.session_state:
+        st.session_state.conversation = get_chain(conn)
+
+
+def send_message_callback():
+    with st.spinner("Wait for it..."):
+        with get_openai_callback() as cb:
+            human_prompt = st.session_state.human_prompt.strip()
+            if len(human_prompt) == 0:
+                return
+            llm_response = st.session_state.conversation(human_prompt)
+            st.session_state.history.append(Message("human", human_prompt))
+            st.session_state.history.append(
+                Message(
+                    "ai",
+                    llm_response["answer"],
+                    documents=llm_response["source_documents"],
+                )
+            )
+            st.session_state.token_count += cb.total_tokens
+            if os.environ.get("ENVIRONMENT") == "dev":
+                history_id = insert_chat_history(
+                    conn, human_prompt, llm_response["answer"]
+                )
+                insert_chat_history_articles(
+                    conn, history_id, llm_response["source_documents"]
+                )
+
+
+def exemple_message_callback_button(args):
+    st.session_state.human_prompt = args
+    send_message_callback()
+    st.session_state.human_prompt = ""
+
+
+def clear_history():
+    st.session_state.history.clear()
+    st.session_state.token_count = 0
+    st.session_state.conversation.memory.clear()
+
+
+load_css()
+initialize_session_state()
+
+exemples = [
+    "Who has published influential research on quantum computing?",
+    "List any prominent authors in the field of artificial intelligence ethics?",
+    "Who are the leading experts on climate change mitigation strategies?",
+]
+
+with chat_column:
+    chat_placeholder = st.container()
+    prompt_placeholder = st.form("chat-form", clear_on_submit=True)
+    information_placeholder = st.container()
+
+    with chat_placeholder:
+        div = f"""
+            <div class="chat-row">
+                <img class="chat-icon" src="https://cdn-icons-png.flaticon.com/512/1129/1129398.png" width=32 height=32>
+                <div class="chat-bubble ai-bubble">
+                    Welcome to SorboBot, a Hugging Face Space designed to revolutionize the way you find published articles. <br/>
+                    Powered by a full export from ScanR and Hal at Sorbonne University, SorboBot utilizes advanced language model technology
+                    to provide you with a list of published articles based on your prompt.
+                </div>
+            </div>
+        """
+        st.markdown(div, unsafe_allow_html=True)
+
+        for chat in st.session_state.history:
+            div = f"""
+                <div class="chat-row
+                    {'' if chat.origin == 'ai' else 'row-reverse'}">
+                    <img class="chat-icon" src="https://cdn-icons-png.flaticon.com/512/{
+                        '1129/1129398.png' if chat.origin == 'ai'
+                        else '1077/1077012.png'}"
+                        width=32 height=32>
+                    <div class="chat-bubble
+                        {'ai-bubble' if chat.origin == 'ai' else 'human-bubble'}">
+                        &#8203;{chat.message}
+                    </div>
+                </div>
+            """
+            st.markdown(div, unsafe_allow_html=True)
+
+        for _ in range(3):
+            st.markdown("")
+
+    with prompt_placeholder:
+        st.markdown("**Chat**")
+        cols = st.columns((6, 1))
+        cols[0].text_input(
+            "Chat",
+            label_visibility="collapsed",
+            key="human_prompt",
+        )
+        cols[1].form_submit_button(
+            "Submit",
+            type="primary",
+            on_click=send_message_callback,
+        )
+
+    if st.session_state.token_count == 0:
+        information_placeholder.markdown("### Test me !")
+        for idx_exemple, exemple in enumerate(exemples):
+            information_placeholder.button(
+                exemple,
+                key=f"{idx_exemple}_button",
+                on_click=exemple_message_callback_button,
+                args=(exemple,),
+            )
+
+    st.button(
+        ":new: Start a new conversation", on_click=clear_history, type="secondary"
+    )
+
+    if os.environ.get("ENVIRONMENT") == "dev":
+        information_placeholder.caption(
+            f"""
+            Used {st.session_state.token_count} tokens \n
+            Debug Langchain conversation:
+            {st.session_state.history}
+            """
+        )
+
+    components.html(
+        """
+        <script>
+        const streamlitDoc = window.parent.document;
+
+        const buttons = Array.from(
+            streamlitDoc.querySelectorAll('.stButton > button')
+        );
+        const submitButton = buttons.find(
+            el => el.innerText === 'Submit'
+        );
+
+        streamlitDoc.addEventListener('keydown', function(e) {
+            switch (e.key) {
+                case 'Enter':
+                    submitButton.click();
+                    break;
+            }
+        });
+        </script>
+        """,
+        height=0,
+        width=0,
+    )
+
+with doc_column:
+    st.markdown("**Source documents**")
+    if len(st.session_state.history) > 0:
+        for doc in st.session_state.history[-1].documents:
+            doc_content = json.loads(doc.page_content)
+            doc_metadata = doc.metadata
+
+            expander = st.expander(doc_content["title"])
+            expander.markdown(
+                f"**HalID** : https://hal.science/{doc_metadata['hal_id']}"
+            )
+            expander.markdown(doc_metadata["abstract"])
+            expander.markdown(f"**Authors** : {doc_content['authors']}")
+            expander.markdown(f"**Keywords** : {doc_content['keywords']}")
+            expander.markdown(f"**Distance** : {doc_metadata['distance']}")
sorbobotapp/chain.py ADDED
@@ -0,0 +1,39 @@
+import os
+
+import sqlalchemy
+from conversation_retrieval_chain import CustomConversationalRetrievalChain
+from langchain.chains.conversation.memory import ConversationBufferMemory
+from langchain.embeddings import GPT4AllEmbeddings
+from langchain.llms import OpenAI
+from vector_store import CustomVectorStore
+
+
+def get_chain(conn: sqlalchemy.engine.Connection):
+    embeddings = GPT4AllEmbeddings()
+
+    db = CustomVectorStore(
+        embedding_function=embeddings,
+        table_name="article",
+        column_name="abstract_embedding",
+        connection=conn,
+    )
+
+    retriever = db.as_retriever()
+
+    llm = OpenAI(
+        temperature=0,
+        openai_api_key=os.environ["OPENAI_API_KEY"],
+        model="text-davinci-003",
+    )
+
+    memory = ConversationBufferMemory(
+        output_key="answer", memory_key="chat_history", return_messages=True
+    )
+    return CustomConversationalRetrievalChain.from_llm(
+        llm=llm,
+        retriever=retriever,
+        verbose=True,
+        memory=memory,
+        return_source_documents=True,
+        max_tokens_limit=3700,
+    )
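A usage sketch of `get_chain`, mirroring how `sorbobotapp/app.py` invokes the chain; it assumes `OPENAI_API_KEY` is set and the Postgres service from the Dockerfile is reachable:

```python
from chain import get_chain
from connection import connect

chain = get_chain(connect())
# With memory attached, the question is the chain's only user-supplied input,
# so it can be called with a bare string, exactly as app.py does.
result = chain("Who are the leading experts on climate change mitigation strategies?")
print(result["answer"])
for doc in result["source_documents"]:
    print(doc.metadata["hal_id"], doc.metadata["distance"])
```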
sorbobotapp/chat_history.py ADDED
@@ -0,0 +1,49 @@
+import json
+from typing import List
+
+import sqlalchemy
+from sqlalchemy import text
+from sqlalchemy.orm import Session
+
+
+def insert_chat_history(conn: sqlalchemy.engine.Connection, query: str, answer: str):
+    with Session(conn) as conn:
+        conn.execute(
+            text("INSERT INTO chat_history (query, answer) VALUES (:query, :answer);"),
+            [
+                {
+                    "query": query,
+                    "answer": answer,
+                }
+            ],
+        )
+        conn.commit()
+
+        result = conn.execute(
+            text("SELECT id FROM chat_history ORDER BY id DESC LIMIT 1;")
+        )
+        last_row_id = result.fetchone()[0]
+        conn.commit()
+        return last_row_id
+
+
+def insert_chat_history_articles(
+    conn: sqlalchemy.engine.Connection, chat_history_id: int, articles: List[str]
+):
+    with Session(conn) as conn:
+        conn.execute(
+            text(
+                """
+                INSERT INTO chat_history_articles (chat_history_id, article_id)
+                VALUES (:chat_history_id, :article_id) ON CONFLICT DO NOTHING;
+                """
+            ),
+            [
+                {
+                    "chat_history_id": chat_history_id,
+                    "article_id": article.metadata["id"],
+                }
+                for article in articles
+            ],
+        )
+        conn.commit()
sorbobotapp/connection.py ADDED
@@ -0,0 +1,10 @@
+import sqlalchemy
+
+CONNECTION_STRING = "postgresql+psycopg2://postgres@/sorbobot?host=0.0.0.0"
+
+
+def connect() -> sqlalchemy.engine.Connection:
+    engine = sqlalchemy.create_engine(CONNECTION_STRING, pool_pre_ping=True)
+
+    conn = engine.connect()
+    return conn
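A minimal smoke test for `connect()`, assuming the database restored by `01-restore.sh` (the `article` table comes from the dump):

```python
from sqlalchemy import text

from connection import connect

conn = connect()
# pool_pre_ping re-validates the pooled socket before this statement runs.
print(conn.execute(text("SELECT count(*) FROM article;")).scalar())
```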
sorbobotapp/conversation_retrieval_chain.py ADDED
@@ -0,0 +1,100 @@
+import inspect
+import json
+from typing import Any, Dict, Optional
+
+from keyword_extraction import KeywordExtractor
+from langchain.callbacks.manager import CallbackManagerForChainRun
+from langchain.chains.conversational_retrieval.base import (
+    ConversationalRetrievalChain,
+    _get_chat_history,
+)
+from langchain.schema import Document
+
+
+class CustomConversationalRetrievalChain(ConversationalRetrievalChain):
+    keyword_extractor: KeywordExtractor = KeywordExtractor()
+
+    def _handle_docs(self, docs):
+        if len(docs) == 0:
+            return False, "No documents found. Can you rephrase ?"
+        elif len(docs) == 1:
+            return False, "Only one document found. Can you rephrase ?"
+        elif len(docs) > 10:
+            return False, "Too many documents found. Can you specify your request ?"
+        return True, ""
+
+    def rerank_documents(self, question: str, docs: list[Document]) -> list[Document]:
+        """Rerank documents based on the number of similar keywords.
+
+        Args:
+            question (str): Original question
+            docs (list[Document]): List of documents
+
+        Returns:
+            list[Document]: List of documents sorted by the number of similar keywords
+        """
+        keywords = self.keyword_extractor(question)
+
+        for doc in docs:
+            doc.metadata["similar_keyword"] = 0
+            doc_keywords = json.loads(doc.page_content)["keywords"]
+            if doc_keywords is None:
+                continue
+            doc_keywords = doc_keywords.lower().split(",")
+
+            for kw in keywords:
+                if kw.lower() in doc_keywords:
+                    doc.metadata["similar_keyword"] += 1
+                    print("similar keyword : ", kw)
+
+        docs = sorted(docs, key=lambda x: x.metadata["similar_keyword"])
+        return docs
+
+    def _call(
+        self,
+        inputs: Dict[str, Any],
+        run_manager: Optional[CallbackManagerForChainRun] = None,
+    ) -> Dict[str, Any]:
+        _run_manager = run_manager or CallbackManagerForChainRun.get_noop_manager()
+        question = inputs["question"]
+        get_chat_history = self.get_chat_history or _get_chat_history
+        chat_history_str = get_chat_history(inputs["chat_history"])
+
+        if chat_history_str:
+            callbacks = _run_manager.get_child()
+            new_question = self.question_generator.run(
+                question=question, chat_history=chat_history_str, callbacks=callbacks
+            )
+        else:
+            new_question = question
+        accepts_run_manager = (
+            "run_manager" in inspect.signature(self._get_docs).parameters
+        )
+        if accepts_run_manager:
+            docs = self._get_docs(new_question, inputs, run_manager=_run_manager)
+        else:
+            docs = self._get_docs(new_question, inputs)  # type: ignore[call-arg]
+
+        valid_docs, message = self._handle_docs(docs)
+        if not valid_docs:
+            return {
+                self.output_key: message,
+                "source_documents": docs,
+            }
+
+        # Add reranking
+        docs = self.rerank_documents(new_question, docs)
+
+        new_inputs = inputs.copy()
+        if self.rephrase_question:
+            new_inputs["question"] = new_question
+        new_inputs["chat_history"] = chat_history_str
+        answer = self.combine_docs_chain.run(
+            input_documents=docs, callbacks=_run_manager.get_child(), **new_inputs
+        )
+        output: Dict[str, Any] = {self.output_key: answer}
+        if self.return_source_documents:
+            output["source_documents"] = docs
+        if self.return_generated_question:
+            output["generated_question"] = new_question
+        return output
sorbobotapp/css.py ADDED
@@ -0,0 +1,7 @@
+import streamlit as st
+
+
+def load_css():
+    with open("sorbobotapp/static/styles.css", "r") as f:
+        css = f"<style>{f.read()}</style>"
+        st.markdown(css, unsafe_allow_html=True)
sorbobotapp/keyword_extraction.py ADDED
@@ -0,0 +1,58 @@
+from typing import Any
+
+from langchain.chat_models import ChatOpenAI
+from langchain.output_parsers import NumberedListOutputParser
+from langchain.prompts import ChatPromptTemplate
+from utils import str_to_list
+
+query_template = """
+You are a bi-lingual (french and english) linguistic teacher working at a top-tier university.
+We are conducting a research project that requires the extraction of keywords from chatbot queries.
+Below, you will find a query. Please identify and rank the three most important keywords or phrases (n-grams) based on their relevance to the main topic of the query.
+For each keyword or phrase, assign it to one of the following categories: ["University / Company", "Research domain", "Country", "Name", "Other"].
+An 'n-gram' refers to a contiguous sequence of words, where 'n' can be 1 for a single word, 2 for a pair of words, and so on, up to two words in length.
+Please ensure not to list more than three n-grams in total.
+Your expertise in linguistic analysis is crucial for the success of this project. Thank you for your contribution.
+
+Please attach your ranked list in the following format:
+1. Keyword/Phrase - Category
+2. Keyword/Phrase - Category
+3. Keyword/Phrase - Category
+
+You must be concise and don't need to justify your choices.
+```
+{query}
+```
+"""
+
+output_parser = NumberedListOutputParser()
+format_instructions = output_parser.get_format_instructions()
+
+
+class KeywordExtractor:
+    def __init__(self):
+        super().__init__()
+        self.model = ChatOpenAI()
+        self.prompt = ChatPromptTemplate.from_template(
+            template=query_template,
+        )
+
+        self.chain = self.prompt | self.model  # | output_parser
+
+    def __call__(
+        self, inputs: str, filter_categories: list[str] = ["Research domain"]
+    ) -> Any:
+        output = self.chain.invoke({"query": inputs})
+
+        keywords = output_parser.parse(output.content)
+
+        filtered_keywords = []
+        for keyword in keywords:
+            if " - " not in keyword:
+                continue
+
+            keyword, category = keyword.split(" - ", maxsplit=2)
+            if category in filter_categories:
+                filtered_keywords.append(keyword)
+
+        return filtered_keywords
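A hypothetical call to the extractor (requires `OPENAI_API_KEY`; the exact keywords returned depend on the model):

```python
from keyword_extraction import KeywordExtractor

extractor = KeywordExtractor()
# Only n-grams the model labels "Research domain" survive the default filter.
print(extractor("Who works on reinforcement learning at Sorbonne Université?"))
# e.g. ['reinforcement learning']  (model-dependent)
```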
sorbobotapp/message.py ADDED
@@ -0,0 +1,16 @@
+from dataclasses import dataclass
+from typing import List, Literal, Optional
+
+from langchain.schema import Document
+
+
+@dataclass
+class Message:
+    """Class for keeping track of a chat message."""
+
+    origin: Literal["human", "ai"]
+    message: str
+    documents: Optional[List[Document]] = None
+
+    def __repr__(self) -> str:
+        return f"Message(origin={self.origin}, message={self.message})"
sorbobotapp/model.py ADDED
@@ -0,0 +1,17 @@
+import sqlalchemy
+from pgvector.sqlalchemy import Vector
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()  # type: Any
+
+
+class Article(Base):
+    """Embedding store."""
+
+    __tablename__ = "article"
+
+    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, nullable=False)
+    title = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+    abstract = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+    embedding: Vector = sqlalchemy.Column("abstract_embedding", Vector(None))
+    doi = sqlalchemy.Column(sqlalchemy.String, nullable=True)
sorbobotapp/models/article.py ADDED
@@ -0,0 +1,17 @@
+import sqlalchemy
+from pgvector.sqlalchemy import Vector
+from sqlalchemy.orm import declarative_base
+
+Base = declarative_base()  # type: Any
+
+
+class Article(Base):
+    """Embedding store."""
+
+    __tablename__ = "article"
+
+    id = sqlalchemy.Column(sqlalchemy.Integer, primary_key=True, nullable=False)
+    title = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+    abstract = sqlalchemy.Column(sqlalchemy.String, nullable=True)
+    embedding: Vector = sqlalchemy.Column("abstract_embedding", Vector(None))
+    doi = sqlalchemy.Column(sqlalchemy.String, nullable=True)
sorbobotapp/models/distance.py ADDED
@@ -0,0 +1,15 @@
+import enum
+
+distance_strategy_limit = {
+    "l2": 1.05,
+    "cosine": 0.55,
+    "inner": 1.0,
+}
+
+
+class DistanceStrategy(str, enum.Enum):
+    """Enumerator of the Distance strategies."""
+
+    EUCLIDEAN = "l2"
+    COSINE = "cosine"
+    MAX_INNER_PRODUCT = "inner"
sorbobotapp/static/ai_icon.png ADDED
(binary image)
sorbobotapp/static/styles.css ADDED
@@ -0,0 +1,36 @@
+.chat-row {
+    display: flex;
+    margin: 5px;
+    width: 100%;
+}
+
+.row-reverse {
+    flex-direction: row-reverse;
+}
+
+.appview-container .main .block-container {
+    padding-top: 2rem;
+}
+
+.chat-bubble {
+    font-family: "Source Sans Pro", sans-serif, "Segoe UI", "Roboto", sans-serif;
+    border: 1px solid transparent;
+    padding: 5px 10px;
+    margin: 0px 7px;
+    max-width: 70%;
+}
+
+.ai-bubble {
+    background: rgb(240, 242, 246);
+    border-radius: 10px;
+}
+
+.human-bubble {
+    background: linear-gradient(135deg, rgb(0, 178, 255) 0%, rgb(0, 106, 255) 100%);
+    color: white;
+    border-radius: 20px;
+}
+
+.chat-icon {
+    border-radius: 5px;
+}
sorbobotapp/static/user_icon.png ADDED
(binary image)
sorbobotapp/utils.py ADDED
@@ -0,0 +1,20 @@
+import re
+
+
+def str_to_list(str_input: str) -> list[str]:
+    if isinstance(str_input, list):
+        return str_input
+
+    splits = re.split(r"', '|\", \"|', \"|\", '", str_input)
+    splits = [
+        split.removeprefix("[")
+        .removesuffix("]")
+        .removeprefix("(")
+        .removesuffix(")")
+        .removeprefix("'")
+        .removesuffix("'")
+        .removeprefix('"')
+        .removesuffix('"')
+        for split in splits
+    ]
+    return splits
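A worked example of `str_to_list`, which the vector store uses to unflatten stringified title/abstract lists coming back from SQL:

```python
from utils import str_to_list

# The regex splits on the quote-comma-quote boundaries of a stringified list,
# then strips residual brackets and quotes from each element.
print(str_to_list("['deep learning', 'robotics']"))
# -> ['deep learning', 'robotics']
```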
sorbobotapp/vector_store.py ADDED
@@ -0,0 +1,364 @@
+from __future__ import annotations
+
+import contextlib
+import json
+import logging
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, Type
+
+import pandas as pd
+import sqlalchemy
+from langchain.docstore.document import Document
+from langchain.schema.embeddings import Embeddings
+from langchain.vectorstores.base import VectorStore
+from models.article import Article
+from models.distance import DistanceStrategy, distance_strategy_limit
+from sqlalchemy import delete, text
+from sqlalchemy.orm import Session
+from utils import str_to_list
+
+DEFAULT_DISTANCE_STRATEGY = DistanceStrategy.COSINE
+
+_LANGCHAIN_DEFAULT_COLLECTION_NAME = "langchain"
+
+
+def _results_to_docs(docs_and_scores: Any) -> List[Document]:
+    """Return docs from docs and scores."""
+    return [doc for doc, _ in docs_and_scores]
+
+
+class CustomVectorStore(VectorStore):
+    """`Postgres`/`PGVector` vector store.
+
+    To use, you should have the ``pgvector`` python package installed.
+
+    Args:
+        connection: Postgres connection string.
+        embedding_function: Any embedding function implementing
+            `langchain.embeddings.base.Embeddings` interface.
+        table_name: The name of the collection to use. (default: langchain)
+            NOTE: This is not the name of the table, but the name of the collection.
+            The tables will be created when initializing the store (if not exists)
+            So, make sure the user has the right permissions to create tables.
+        distance_strategy: The distance strategy to use. (default: COSINE)
+        pre_delete_collection: If True, will delete the collection if it exists.
+            (default: False). Useful for testing.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.vectorstores import PGVector
+            from langchain.embeddings.openai import OpenAIEmbeddings
+
+            COLLECTION_NAME = "state_of_the_union_test"
+            embeddings = OpenAIEmbeddings()
+            vectorestore = PGVector.from_documents(
+                embedding=embeddings,
+                documents=docs,
+                table_name=COLLECTION_NAME,
+                connection=connection,
+            )
+
+    """
+
+    def __init__(
+        self,
+        connection: sqlalchemy.engine.Connection,
+        embedding_function: Embeddings,
+        table_name: str,
+        column_name: str,
+        collection_metadata: Optional[dict] = None,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        pre_delete_collection: bool = False,
+        logger: Optional[logging.Logger] = None,
+    ) -> None:
+        self._conn = connection
+        self.embedding_function = embedding_function
+        self.table_name = table_name
+        self.column_name = column_name
+        self.collection_metadata = collection_metadata
+        self._distance_strategy = distance_strategy
+        self.pre_delete_collection = pre_delete_collection
+        self.logger = logger or logging.getLogger(__name__)
+        self.__post_init__()
+
+    def __post_init__(
+        self,
+    ) -> None:
+        """
+        Initialize the store.
+        """
+        # self._conn = self.connect()
+
+        self.EmbeddingStore = Article
+
+    @property
+    def embeddings(self) -> Embeddings:
+        return self.embedding_function
+
+    @contextlib.contextmanager
+    def _make_session(self) -> Generator[Session, None, None]:
+        """Create a context manager for the session, bind to _conn string."""
+        yield Session(self._conn)
+
+    def add_embeddings(
+        self,
+        texts: Iterable[str],
+        embeddings: List[List[float]],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Add embeddings to the vectorstore.
+
+        Args:
+            texts: Iterable of strings to add to the vectorstore.
+            embeddings: List of list of embedding vectors.
+            metadatas: List of metadatas associated with the texts.
+            kwargs: vectorstore specific parameters
+        """
+        if not metadatas:
+            metadatas = [{} for _ in texts]
+
+        with Session(self._conn) as session:
+            for txt, metadata, embedding, id in zip(texts, metadatas, embeddings, ids):
+                embedding_store = self.EmbeddingStore(
+                    embedding=embedding,
+                    document=txt,
+                    cmetadata=metadata,
+                    custom_id=id,
+                )
+                session.add(embedding_store)
+            session.commit()
+
+        return ids
+
+    def add_texts(
+        self,
+        texts: Iterable[str],
+        metadatas: Optional[List[dict]] = None,
+        ids: Optional[List[str]] = None,
+        **kwargs: Any,
+    ) -> List[str]:
+        """Run more texts through the embeddings and add to the vectorstore.
+
+        Args:
+            texts: Iterable of strings to add to the vectorstore.
+            metadatas: Optional list of metadatas associated with the texts.
+            kwargs: vectorstore specific parameters
+
+        Returns:
+            List of ids from adding the texts into the vectorstore.
+        """
+        embeddings = self.embedding_function.embed_documents(list(texts))
+        return self.add_embeddings(
+            texts=texts, embeddings=embeddings, metadatas=metadatas, ids=ids, **kwargs
+        )
+
+    def similarity_search(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[dict] = None,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Run similarity search with PGVector with distance.
+
+        Args:
+            query (str): Query text to search for.
+            k (int): Number of results to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query.
+        """
+        embedding = self.embedding_function.embed_query(text=query)
+        return self.similarity_search_by_vector(
+            embedding=embedding,
+            k=k,
+        )
+
+    def similarity_search_with_score(
+        self,
+        query: str,
+        k: int = 4,
+        filter: Optional[dict] = None,
+    ) -> List[Tuple[Document, float]]:
+        """Return docs most similar to query.
+
+        Args:
+            query: Text to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query and score for each.
+        """
+        embedding = self.embedding_function.embed_query(query)
+        docs = self.similarity_search_with_score_by_vector(embedding=embedding, k=k)
+        return docs
+
+    @property
+    def distance_strategy(self) -> str | None:
+        if self._distance_strategy == DistanceStrategy.EUCLIDEAN:
+            return "<->"
+        elif self._distance_strategy == DistanceStrategy.COSINE:
+            return "<=>"
+        elif self._distance_strategy == DistanceStrategy.MAX_INNER_PRODUCT:
+            return "<#>"
+        else:
+            raise ValueError(
+                f"Got unexpected value for distance: {self._distance_strategy}. "
+                f"Should be one of {', '.join([ds.value for ds in DistanceStrategy])}."
+            )
+
+    def similarity_search_with_score_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+    ) -> List[Tuple[Document, float]]:
+        results = self.__query_collection(embedding=embedding, k=k)
+
+        return self._results_to_docs_and_scores(results)
+
+    @staticmethod
+    def _fetch_title(title: str, abstract: str):
+        if len(title) > 0:
+            return title
+        return abstract.split(".")[0]
+
+    def _results_to_docs_and_scores(self, results: Any) -> List[Tuple[Document, float]]:
+        """Return docs and scores from results."""
+        docs = [
+            (
+                Document(
+                    page_content=json.dumps(
+                        {
+                            "title": self._fetch_title(
+                                result["title"][0], result["abstract"][0]
+                            ),
+                            "authors": result["authors"],
+                            "keywords": result["keywords"],
+                        }
+                    ),
+                    metadata={
+                        "id": result["id"],
+                        "doi": result["doi"],
+                        "hal_id": result["hal_id"],
+                        "distance": result["distance"],
+                        "abstract": result["abstract"][0],
+                    },
+                ),
+                result["distance"] if self.embedding_function is not None else None,
+            )
+            for result in results
+        ]
+        return docs
+
+    def __query_collection(
+        self,
+        embedding: List[float],
+        k: int = 4,
+    ) -> List[Any]:
+        """Query the collection."""
+
+        limit = distance_strategy_limit[self._distance_strategy]
+        with Session(self._conn) as session:
+            results = session.execute(
+                text(
+                    f"""
+                    select
+                        a.id,
+                        a.title_en,
+                        a.doi,
+                        a.hal_id,
+                        a.abstract_en,
+                        string_agg(distinct keyword."name", ', ') as keywords,
+                        string_agg(distinct author."name", ', ') as authors,
+                        abstract_embedding_en {self.distance_strategy} '{str(embedding)}' as distance
+                    from article a
+                    left join article_keyword ON article_keyword.article_id = a.id
+                    left join keyword on article_keyword.keyword_id = keyword.id
+                    left join article_author ON article_author.article_id = a.id
+                    left join author on author.id = article_author.author_id
+                    where
+                        abstract_en != '' and
+                        abstract_en != 'None' and
+                        abstract_embedding_en {self.distance_strategy} '{str(embedding)}' < {limit}
+                    GROUP BY a.id
+                    ORDER BY distance
+                    LIMIT 100;
+                    """
+                )
+            )
+            results = results.fetchall()
+        results = pd.DataFrame(
+            results,
+            columns=[
+                "id",
+                "title",
+                "doi",
+                "hal_id",
+                "abstract",
+                "keywords",
+                "authors",
+                "distance",
+            ],
+        )
+        results["abstract"] = results["abstract"].apply(str_to_list)
+        results["title"] = results["title"].apply(str_to_list)
+        results = results.to_dict(orient="records")
+        return results
+
+    def similarity_search_by_vector(
+        self,
+        embedding: List[float],
+        k: int = 4,
+        **kwargs: Any,
+    ) -> List[Document]:
+        """Return docs most similar to embedding vector.
+
+        Args:
+            embedding: Embedding to look up documents similar to.
+            k: Number of Documents to return. Defaults to 4.
+            filter (Optional[Dict[str, str]]): Filter by metadata. Defaults to None.
+
+        Returns:
+            List of Documents most similar to the query vector.
+        """
+        docs_and_scores = self.similarity_search_with_score_by_vector(
+            embedding=embedding, k=k
+        )
+        return _results_to_docs(docs_and_scores)
+
+    @classmethod
+    def from_texts(
+        cls: Type[CustomVectorStore],
+        texts: List[str],
+        embedding: Embeddings,
+        metadatas: Optional[List[dict]] = None,
+        table_name: str = _LANGCHAIN_DEFAULT_COLLECTION_NAME,
+        distance_strategy: DistanceStrategy = DEFAULT_DISTANCE_STRATEGY,
+        ids: Optional[List[str]] = None,
+        pre_delete_collection: bool = False,
+        **kwargs: Any,
+    ) -> CustomVectorStore:
+        """
+        Return VectorStore initialized from texts and embeddings.
+        Postgres connection string is required
+        "Either pass it as a parameter
+        or set the PGVECTOR_CONNECTION_STRING environment variable.
+        """
+        embeddings = embedding.embed_documents(list(texts))
+
+        return cls.__from(
+            texts,
+            embeddings,
+            embedding,
+            metadatas=metadatas,
+            ids=ids,
+            table_name=table_name,
+            distance_strategy=distance_strategy,
+            pre_delete_collection=pre_delete_collection,
+            **kwargs,
+        )
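A minimal sketch of querying the store directly, mirroring how `get_chain` wires it up in `sorbobotapp/chain.py`; it assumes the restored database and a local GPT4All embedding model:

```python
from langchain.embeddings import GPT4AllEmbeddings

from connection import connect
from vector_store import CustomVectorStore

store = CustomVectorStore(
    connection=connect(),
    embedding_function=GPT4AllEmbeddings(),
    table_name="article",
    column_name="abstract_embedding",
)
# The underlying SQL returns at most the 100 closest rows under the
# cosine-distance cutoff (0.55 in models/distance.py).
for doc in store.similarity_search("graph neural networks", k=4):
    print(doc.metadata["hal_id"], doc.metadata["distance"])
```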