# space02 / vec_db.py — uploaded by thefish1 (commit 6dcc10e, message "u")
import os

# Redirect the Matplotlib and Hugging Face caches to writable locations
# (e.g. when running in a container with a read-only home directory).
# These MUST be set before `transformers` is imported, because the library
# reads TRANSFORMERS_CACHE at import time — setting them afterwards (as the
# original code did) can leave the default, unwritable cache path in use.
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'

import json
import subprocess  # used by the (currently commented-out) Docker launcher below

import pandas as pd
import torch
import weaviate
from transformers import AutoTokenizer, AutoModel
#
# try:
#     # Command that runs the Weaviate Docker container
#     command = [
#         "docker", "run",
#         "-p", "8080:8080",
#         "-p", "50051:50051",
#         "cr.weaviate.io/semitechnologies/weaviate:1.24.20"
#     ]
#
#     # Execute the command
#     subprocess.run(command, check=True)
#     print("Docker container is running.")
#
# except subprocess.CalledProcessError as e:
#     print(f"An error occurred: {e}")
# Name of the Weaviate class (collection) that holds the sentence vectors.
class_name = 'Lhnjames123321'
# SECURITY NOTE(review): the API key and cluster URL are hardcoded in source;
# move them to environment variables / a secrets store before sharing this file.
auth_config = weaviate.AuthApiKey(api_key="8wNsHV3Enc2PNVL8Bspadh21qYAfAvnK2ux3")
client = weaviate.Client(
    url="https://3a8sbx3s66by10yxginaa.c0.asia-southeast1.gcp.weaviate.cloud",
    auth_client_secret=auth_config
)
# Run the encoder on GPU when available, otherwise fall back to CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Chinese BERT used as the sentence encoder; downloaded on first run into
# the TRANSFORMERS_CACHE directory configured above.
model = AutoModel.from_pretrained("bert-base-chinese").to(device)
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
def encode_sentences(sentences, model, tokenizer, device):
    """Encode a batch of sentences into fixed-size vectors by mean pooling.

    Args:
        sentences: list of strings to encode.
        model: encoder whose output exposes ``last_hidden_state`` of shape
            (batch, seq_len, hidden).
        tokenizer: callable producing a batch with a ``.to(device)`` method
            (e.g. a Hugging Face tokenizer returning a ``BatchEncoding``).
        device: torch.device to run the forward pass on.

    Returns:
        numpy.ndarray of shape (batch, hidden) — one mean-pooled token
        embedding per input sentence, moved back to CPU.
    """
    inputs = tokenizer(sentences, return_tensors='pt',
                       padding=True, truncation=True, max_length=512)
    # Assign the result of .to(): the original discarded it and relied on
    # BatchEncoding mutating itself in place, which is not the documented API.
    inputs = inputs.to(device)
    with torch.no_grad():  # inference only — no gradients needed
        outputs = model(**inputs)
    # Mean-pool over the sequence dimension to get one vector per sentence.
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.cpu().numpy()
# def class_exists(client, class_name):
# existing_classes = client.schema.get_classes()
# return any(cls['class'] == class_name for cls in existing_classes)
def init_weaviate(file_path='data.json'):
    """Load responses from a JSON-lines file, embed them, import into Weaviate.

    Args:
        file_path: path to a JSONL file; each line is a JSON object whose
            'response' field holds the sentence to index (default 'data.json',
            matching the previously hard-coded path).

    Side effects:
        Adds one object per sentence to the Weaviate class ``class_name``,
        with a 1-based ``sentence_id``, the raw ``sentence`` text, and the
        BERT embedding supplied as a custom vector.
    """
    sentences = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                record = json.loads(line.strip())
                sentences.append(record.get('response', ''))
            except json.JSONDecodeError as e:
                # Skip malformed lines rather than aborting the whole import.
                print(f"Error parsing JSON: {e}")
                continue
    embeddings = encode_sentences(sentences, model, tokenizer, device)
    total = len(sentences)
    # Use the batch handle yielded by the context manager; the original
    # reached back through `client.batch` from inside its own `with` block.
    # (The intermediate pandas DataFrame was dropped — it only shuttled the
    # two lists around.)
    with client.batch(batch_size=100) as batch:
        for i, (sentence, vector) in enumerate(
                zip(sentences, embeddings.tolist()), start=1):
            print(f'importing data: {i}/{total}')
            batch.add_data_object(
                {'sentence_id': i, 'sentence': sentence},
                class_name=class_name,
                vector=vector,
            )
    print('import completed')
def use_weaviate(input_str, top_k=5):
    """Return the sentences most similar to ``input_str`` from Weaviate.

    Args:
        input_str: query text, embedded with the same BERT encoder used at
            import time.
        top_k: maximum number of nearest neighbours to return (default 5,
            matching the previously hard-coded limit).

    Returns:
        list[str]: the ``sentence`` property of the nearest objects, closest
        first (cosine distance, per the class's vector index configuration).
    """
    query_vector = encode_sentences([input_str], model, tokenizer, device)[0].tolist()
    response = (
        client.query
        .get(class_name, ['sentence_id', 'sentence'])
        .with_near_vector({'vector': query_vector})
        .with_limit(top_k)
        .with_additional(['distance'])
        .do()
    )
    print(response)  # debug output, kept from the original implementation
    # NOTE(review): raises KeyError if the query failed (no 'data' key);
    # callers currently rely on the raw exception surfacing.
    results = response['data']['Get'][class_name]
    return [result['sentence'] for result in results]
if __name__ == '__main__':
    # Build the vector index, then serve one interactive similarity query.
    init_weaviate()
    query_text = input("请输入查询的文本:")
    print("查询结果:", use_weaviate(query_text))