File size: 3,877 Bytes
6dcc10e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import weaviate
import pandas as pd
import torch
import json
from transformers import AutoTokenizer, AutoModel
import subprocess
import os
# Point Matplotlib's cache at a writable directory (e.g. a read-only container FS).
os.environ['MPLCONFIGDIR'] = '/tmp/matplotlib'
# Point the Hugging Face Transformers cache at a writable directory.
# NOTE(review): transformers was already imported above (L245); some versions
# resolve their cache path at import time, so this may be a no-op — confirm.
os.environ['TRANSFORMERS_CACHE'] = '/tmp/huggingface_cache'
#
# try:
#         # Command to run the Weaviate Docker container
#         command = [
#             "docker", "run",
#             "-p", "8080:8080",
#             "-p", "50051:50051",
#             "cr.weaviate.io/semitechnologies/weaviate:1.24.20"
#         ]
#
#         # Execute the command
#         subprocess.run(command, check=True)
#         print("Docker container is running.")
#
# except subprocess.CalledProcessError as e:
#         print(f"An error occurred: {e}")

# Weaviate class (collection) that stores the sentences and their vectors.
class_name = 'Lhnjames123321'
# NOTE(review): hard-coded API key and cluster URL — move these to environment
# variables / secret management; rotate this key since it is committed here.
auth_config = weaviate.AuthApiKey(api_key="8wNsHV3Enc2PNVL8Bspadh21qYAfAvnK2ux3")
client = weaviate.Client(
  url="https://3a8sbx3s66by10yxginaa.c0.asia-southeast1.gcp.weaviate.cloud",
  auth_client_secret=auth_config
)
# Embeddings are computed locally with bert-base-chinese; use GPU when available.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AutoModel.from_pretrained("bert-base-chinese").to(device)
tokenizer = AutoTokenizer.from_pretrained("bert-base-chinese")

def encode_sentences(sentences, model, tokenizer, device):
    """Embed a batch of sentences with a BERT-style encoder.

    Each sentence is tokenized (padded and truncated to 512 tokens), run
    through ``model`` without gradient tracking, and reduced to one vector by
    mean-pooling the last hidden state over the token dimension.

    Args:
        sentences: list of strings to embed (one vector per sentence).
        model: encoder whose output exposes ``last_hidden_state``.
        tokenizer: HF-style callable producing PyTorch tensors
            (``return_tensors='pt'``).
        device: ``torch.device`` the model lives on.

    Returns:
        numpy.ndarray of shape ``(len(sentences), hidden_size)`` on the CPU.
    """
    inputs = tokenizer(sentences, return_tensors='pt', padding=True, truncation=True, max_length=512)
    # Fix: bind the result instead of relying on in-place mutation —
    # BatchEncoding.to returns the encoding moved to `device`.
    inputs = inputs.to(device)
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        outputs = model(**inputs)
    # Mean-pool over the sequence dimension -> one embedding per sentence.
    embeddings = outputs.last_hidden_state.mean(dim=1)
    return embeddings.cpu().numpy()

# def class_exists(client, class_name):
#     existing_classes = client.schema.get_classes()
#     return any(cls['class'] == class_name for cls in existing_classes)

def init_weaviate():
    """Load sentences from ``data.json``, embed them, and import into Weaviate.

    Reads the file as JSON Lines (one object per line), collects each
    object's ``'response'`` field, embeds all of them with BERT via
    ``encode_sentences``, and batch-imports sentence/vector pairs into the
    module-level ``class_name`` collection using the module-level ``client``.
    """
    # if class_exists(client, class_name)==0:
    #     class_obj = {
    #         'class': class_name,
    #         'vectorIndexConfig': {
    #             'distance': 'cosine'
    #         },
    #     }
    #     client.schema.create_class(class_obj)

    # Input is JSON Lines: one JSON object per line of data.json.
    file_path = 'data.json'
    sentence_data = []

    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            try:
                data = json.loads(line.strip())
                # Only the 'response' field is indexed; missing key -> ''.
                sentence1 = data.get('response', '')
                sentence_data.append(sentence1)
            except json.JSONDecodeError as e:
                # Skip malformed lines rather than aborting the whole import.
                print(f"Error parsing JSON: {e}")
                continue

    # Embed everything in one batch — presumably data.json is small enough to
    # fit in memory; verify for large corpora (TODO confirm).
    sentence_embeddings = encode_sentences(sentence_data, model, tokenizer, device)

    data = {'sentence': sentence_data,
            'embeddings': sentence_embeddings.tolist()}
    df = pd.DataFrame(data)

    # Batch import: the client flushes every `batch_size` objects and on exit.
    with client.batch(batch_size=100) as batch:
        for i in range(df.shape[0]):
            print(f'importing data: {i + 1}/{df.shape[0]}')
            properties = {
                'sentence_id': i + 1,  # 1-based id derived from import order
                'sentence': df.sentence[i],
            }
            # Supply our own BERT vector instead of a server-side vectorizer.
            custom_vector = df.embeddings[i]
            client.batch.add_data_object(
                properties,
                class_name=class_name,
                vector=custom_vector
            )
    print('import completed')


def use_weaviate(input_str):
    """Return the sentences of the 5 nearest neighbours to *input_str*.

    Embeds the query text with the module-level BERT model, runs a
    near-vector search against the ``class_name`` collection, and extracts
    the ``sentence`` property of each hit.
    """
    # Embed the single query sentence and unwrap it to a plain list of floats.
    query_vector = encode_sentences([input_str], model, tokenizer, device)[0].tolist()

    # Build the GraphQL query step by step, then execute it.
    query = client.query.get(class_name, ['sentence_id', 'sentence'])
    query = query.with_near_vector({'vector': query_vector})
    query = query.with_limit(5)
    query = query.with_additional(['distance'])
    response = query.do()

    print(response)  # raw response kept for debugging
    hits = response['data']['Get'][class_name]
    return [hit['sentence'] for hit in hits]

def main():
    """Script entry point: build the index, then answer one interactive query."""
    init_weaviate()
    input_str = input("请输入查询的文本:")
    ans = use_weaviate(input_str)
    print("查询结果:", ans)


if __name__ == '__main__':
    # Guard keeps import side effects limited to definitions; the script body
    # now lives in main() so it can be invoked (or tested) explicitly.
    main()