import json
import os

import numpy as np
import torch
import weaviate
from transformers import AutoTokenizer, AutoModel

# Redirect cache directories to writable locations (useful in containers or
# other environments where the home directory is read-only).
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
os.makedirs(os.environ['MPLCONFIGDIR'], exist_ok=True)
os.makedirs(os.environ['TRANSFORMERS_CACHE'], exist_ok=True)

auth_config = weaviate.AuthApiKey(api_key="8wNsHV3Enc2PNVL8Bspadh21qYAfAvnK2ux3")

database_client = weaviate.Client(
    url="https://3a8sbx3s66by10yxginaa.c0.asia-southeast1.gcp.weaviate.cloud",
    auth_client_secret=auth_config
)

class_name = "Lhnjames123321"

tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")
model = AutoModel.from_pretrained("bert-base-chinese")
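
# The insert and query code below assumes the class already exists in the
# Weaviate schema. A minimal sketch of creating it for externally supplied
# vectors (vectorizer disabled); the property names match those used below:
if not database_client.schema.exists(class_name):
    database_client.schema.create_class({
        'class': class_name,
        'vectorizer': 'none',
        'properties': [
            {'name': 'keywords', 'dataType': ['text']},
            {'name': 'summary', 'dataType': ['text']},
        ],
    })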

def encode(sentences, model, tokenizer):
    """Encode sentences into mean-pooled BERT embeddings of shape (n, 768)."""
    model.eval()
    embeddings = []

    with torch.no_grad():
        for sentence in sentences:
            inputs = tokenizer(sentence, return_tensors='pt',
                               padding=True, truncation=True, max_length=512)
            outputs = model(**inputs)
            # Mean-pool the token embeddings into one vector per sentence.
            embedding = outputs.last_hidden_state.mean(dim=1).numpy().astype('float32')
            embeddings.append(embedding)

    return np.vstack(embeddings)
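
# Hedged usage sketch: two short sentences yield a (2, 768) float32 matrix,
# since bert-base-chinese has a hidden size of 768 (sentences illustrative):
#
#     vecs = encode(['今天天气很好', '机器学习很有趣'], model, tokenizer)
#     assert vecs.shape == (2, 768)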

def insert_keywords_to_weaviate(database_client, class_name, keywords, summaries, avg_embeddings):
    """Batch-insert objects with precomputed vectors; the batch context
    manager flushes any remaining objects on exit."""
    with database_client.batch(batch_size=100) as batch:
        for keyword, summary, avg_embedding in zip(keywords, summaries, avg_embeddings):
            properties = {
                'keywords': keyword,
                'summary': summary,
            }
            print(f'Inserting: {keyword} with summary: {summary}')
            batch.add_data_object(
                properties,
                class_name=class_name,
                vector=avg_embedding.tolist(),
            )
    print('Insertion completed')

def init_database(database_client, class_name):
    """Read the JSONL training file and insert averaged keyword embeddings."""
    dataset = []
    with open('train_2000_modified.json', 'r', encoding='utf-8') as f:
        for line in f:
            dataset.append(json.loads(line))

    keywords = [item['content'] for item in dataset if 'content' in item]
    summaries = [item['summary'] for item in dataset if 'summary' in item]

    keywords_avg_embeddings = []
    for lst in keywords:
        # Each record stores comma-separated keywords: embed each keyword
        # separately, then average into one representative vector.
        embeddings = encode(lst.split(','), model, tokenizer)
        keywords_avg_embeddings.append(embeddings.mean(axis=0))

    insert_keywords_to_weaviate(database_client, class_name, keywords, summaries, keywords_avg_embeddings)
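
# Each line of train_2000_modified.json is expected to be a standalone JSON
# object; the exact file isn't shown, but the field accesses above imply a
# shape like this (values illustrative):
#
#     {"content": "关键词1,关键词2,关键词3", "summary": "对应的参考摘要"}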

def fetch_summary_from_database(query_keywords, class_name):
    """Return (distance, keywords, summary) of the nearest stored object."""
    keyword_embeddings = []
    for keyword in query_keywords:
        keyword_embeddings.append(encode([keyword], model, tokenizer))

    # Average the per-keyword vectors and flatten to a plain Python list,
    # the format with_near_vector expects.
    avg_embedding = np.mean(keyword_embeddings, axis=0).flatten().tolist()

    response = (
        database_client.query
        .get(class_name, ['keywords', 'summary'])
        .with_near_vector({'vector': avg_embedding})
        .with_limit(1)
        .with_additional(['distance'])
        .do()
    )

    top_hit = response['data']['Get'][class_name][0]
    top_distance = top_hit['_additional']['distance']
    top_keywords_list = top_hit['keywords']
    top_summary = top_hit['summary']

    return top_distance, top_keywords_list, top_summary

if __name__ == '__main__':
    init_database(database_client, class_name)
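    # Hedged follow-up query sketch (keywords illustrative; assumes the class
    # has been populated by init_database above):
    # distance, keywords, summary = fetch_summary_from_database(['天气'], class_name)
    # print(distance, keywords, summary)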