import torch
from transformers import AutoTokenizer, AutoModel
import os

class TextExtractor:
    def __init__(self, model_name, proxy=None):
        """
        Initialize the TextExtractor with a specified model and optional proxy settings.

        Parameters:
        - model_name (str): The name of the pre-trained model to load from the HuggingFace Hub.
        - proxy (str, optional): The proxy address to use for HTTP and HTTPS requests.
        """
        # if proxy is None:
        #     proxy = 'http://localhost:8234'
        # if proxy:
        #     os.environ['HTTP_PROXY'] = proxy
        #     os.environ['HTTPS_PROXY'] = proxy
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModel.from_pretrained(model_name)
        except Exception:
            # Download failed (e.g. no network access); fall back to the local cache.
            print('Falling back to local_files_only=True')
            self.tokenizer = AutoTokenizer.from_pretrained(model_name, local_files_only=True)
            self.model = AutoModel.from_pretrained(model_name, local_files_only=True)
        self.model.eval()
    def extract(self, sentences):
        """
        Extract sentence embeddings for the provided sentences.

        Parameters:
        - sentences (list of str): A list of sentences to extract embeddings for.

        Returns:
        - torch.Tensor: The normalized sentence embeddings.
        """
        encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
        with torch.no_grad():
            model_output = self.model(**encoded_input)
            # Use the first ([CLS]) token's hidden state as the sentence representation.
            sentence_embeddings = model_output[0][:, 0]
        # L2-normalize so that dot products equal cosine similarities.
        sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
        return sentence_embeddings
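
# Illustrative usage sketch (the model name is the one used in __main__ below;
# the sentences are placeholders, not project data):
#
#     extractor = TextExtractor('BAAI/bge-small-zh-v1.5')
#     vecs = extractor.extract(["first sentence", "second sentence"])
#     # vecs.shape == (2, hidden_size); each row is L2-normalized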
import pandas as pd

def get_qas(excel_file=None):
    """
    Load question/answer pairs from an Excel file.

    Parameters:
    - excel_file (str, optional): Path to the Excel file; defaults to 'data/output_fixid.xlsx'.

    Returns:
    - list of dict: One item per row, of the form {"id": ..., "value": {"texts": [question, summary]}}.
    """
    default_excel_file = 'data/output_fixid.xlsx'
    if excel_file is None:
        excel_file = default_excel_file
    # Read the Excel file and drop rows missing a question or summary.
    df = pd.read_excel(excel_file)
    df = df[df["question"].notna()]
    df = df[df["summary"].notna()]
    datas = []
    # Iterate over each row of the DataFrame.
    for index, row in df.iterrows():
        id = row['id']
        question = row['question']
        short_answer = row['summary']
        category = row['category']
        texts = [question, short_answer]
        data_value = {
            "texts": texts,
        }
        data = {
            "id": id,
            "value": data_value
        }
        datas.append(data)
    return datas
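
# Expected spreadsheet layout (inferred from the column accesses above; the exact
# sheet contents are an assumption):
#
#     id | question | summary | category
#     ---+----------+---------+---------
#      1 | ...      | ...     | ...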
from tqdm import tqdm

def extract_embedding(datas, text_extractor):
    """
    Extract embeddings for each item in the provided data.

    Parameters:
    - datas (list of dict): A list of dictionaries containing text data.
    - text_extractor (TextExtractor): The extractor used to compute the embeddings.

    Returns:
    - list of dict: The input data with added embeddings.
    """
    for data in tqdm(datas):
        texts = data["value"]["texts"]
        # Join question and summary into a single string before embedding.
        text = "。".join(texts)
        embeddings = text_extractor.extract(text)
        embeddings_list = embeddings.tolist()  # Convert the (1, dim) tensor to a nested list
        data["value"]["embedding"] = embeddings_list
    return datas
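
# Because extract() is given one joined string, each stored "embedding" field is a
# nested list with a single row ([[f0, f1, ...]]); get_id2embedding() below takes
# element [0] to recover the flat vector.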
def save_parquet(datas, file_path):
    """
    Save the provided data to a Parquet file.

    Parameters:
    - datas (list of dict): A list of dictionaries containing text data and embeddings.
    - file_path (str): The path to the output Parquet file.
    """
    # Flatten the data for easier conversion to a DataFrame.
    flattened_data = []
    for data in datas:
        id = data["id"]
        texts = data["value"]["texts"]
        text = "。".join(texts)
        embedding = data["value"]["embedding"]
        flattened_data.append({
            "id": id,
            "text": text,
            "embedding": embedding
        })
    # Create the DataFrame and save it to Parquet.
    df = pd.DataFrame(flattened_data)
    df.to_parquet(file_path, index=False)
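
# Note: DataFrame.to_parquet() requires a Parquet engine (pyarrow or fastparquet)
# to be installed. Reading the file back is symmetric:
#
#     df = pd.read_parquet(file_path)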
import pandas as pd
import os

def get_id2embedding(regen=False, parquet_file='data/qa_with_embedding.parquet'):
    """
    Get a dictionary mapping IDs to embeddings. Regenerate embeddings if specified.

    Parameters:
    - regen (bool): Whether to regenerate embeddings.
    - parquet_file (str): The path to the Parquet file.

    Returns:
    - dict: A dictionary mapping IDs to lists of float embeddings.
    """
    if regen or not os.path.exists(parquet_file):
        print("Regenerating embeddings...")
        model_name = 'BAAI/bge-small-zh-v1.5'
        text_extractor = TextExtractor(model_name)
        datas = get_qas()
        print("Extracting embeddings for", len(datas), "data items")
        datas = extract_embedding(datas, text_extractor)
        save_parquet(datas, parquet_file)
    df = pd.read_parquet(parquet_file)
    id2embedding = {}
    for index, row in df.iterrows():
        id = row['id']
        embedding = row['embedding']
        # The stored embedding is a single-row nested list; take row 0 as the vector.
        id2embedding[id] = embedding[0]
    return id2embedding
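
# For reference, BAAI/bge-small-zh-v1.5 produces 512-dimensional vectors, so each
# value in id2embedding should be a list of 512 floats (the __main__ block below
# prints this length as a sanity check).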
import torch
from sklearn.metrics.pairwise import cosine_similarity
import heapq

def __get_id2top30map(id2embedding):
    """
    Get a dictionary mapping IDs to their top 30 nearest neighbors based on cosine similarity.

    Parameters:
    - id2embedding (dict): A dictionary mapping IDs to lists of float embeddings.

    Returns:
    - dict: A dictionary mapping each ID to a list of the top 30 nearest neighbor IDs.
    """
    ids = list(id2embedding.keys())
    embeddings = torch.tensor([id2embedding[id] for id in ids])
    # Compute the full pairwise cosine similarity matrix (N x N).
    cos_sim_matrix = cosine_similarity(embeddings)
    id2top30map = {}
    for i, id in enumerate(ids):
        # Get the similarity scores for the current ID.
        sim_scores = cos_sim_matrix[i]
        # Take the top 31 indices, then drop the current ID itself (its self-similarity is 1.0).
        top_indices = heapq.nlargest(31, range(len(sim_scores)), key=lambda x: sim_scores[x])
        if i in top_indices:
            top_indices.remove(i)
        # Map the indices back to IDs.
        top_30_ids = [ids[idx] for idx in top_indices[:30]]
        id2top30map[id] = top_30_ids
    return id2top30map
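
# A vectorized alternative sketch (same result up to similarity ties, no heapq;
# purely illustrative, not part of the pipeline):
#
#     import numpy as np
#     order = np.argsort(-cos_sim_matrix, axis=1)   # neighbors sorted by descending similarity
#     id2top30map = {ids[i]: [ids[j] for j in order[i] if j != i][:30]
#                    for i in range(len(ids))}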
import pickle

def get_id2top30map(id2embedding=None):
    """
    Load the ID -> top-30-neighbors map from the pickle cache, or build and cache it if needed.
    """
    default_save_pkl = "data/id2top30map.pkl"
    if id2embedding is None:
        if os.path.exists(default_save_pkl):
            with open(default_save_pkl, 'rb') as f:
                id2top30map = pickle.load(f)
        else:
            print("No embedding found, generating a new one...")
            id2embedding = get_id2embedding(regen=False)
            id2top30map = __get_id2top30map(id2embedding)
            with open(default_save_pkl, 'wb') as f:
                pickle.dump(id2top30map, f)
    else:
        id2top30map = __get_id2top30map(id2embedding)
    return id2top30map
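
# Typical call (mirrors the __main__ block below):
#
#     id2top30map = get_id2top30map()
#     print(id2top30map[4])   # list of the 30 IDs most similar to item 4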
if __name__ == '__main__':
    if False:
        # Example usage:
        model_name = 'BAAI/bge-small-zh-v1.5'
        sentences = ["样例数据-1", "样例数据-2"]
        text_extractor = TextExtractor(model_name)
        embeddings = text_extractor.extract(sentences)
        print("Sentence embeddings:", embeddings)
        datas = get_qas()
        print("Extracting embeddings for", len(datas), "items")
        datas = extract_embedding(datas, text_extractor)
        default_parquet_save_name = "data/qa_with_embedding.parquet"
        save_parquet(datas, default_parquet_save_name)
    if True:
        id2embedding = get_id2embedding(regen=False)
        print(len(id2embedding[4]))
        id2top30map = get_id2top30map(None)
        print("ID to Top 30 Neighbors dictionary:", id2top30map[4])
    if True:
        # Breadth-first expansion over the similarity graph: starting from one QA item,
        # repeatedly visit up to `expand_num` previously unvisited neighbors of the current node.
        start_id = 332
        visited_ids = [start_id]
        current_queue = [start_id]
        expand_num = 5
        for iteration in range(10):
            current_node = current_queue.pop(0)
            top30 = id2top30map[current_node]
            current_expand = []
            for id in top30:
                if id not in visited_ids:
                    visited_ids.append(id)
                    current_queue.append(id)
                    current_expand.append(id)
                if len(current_expand) >= expand_num:
                    break
            display_text = f"{current_node} | ->" + ",".join([str(i) for i in current_expand])
            print(display_text)
        from get_qa_and_image import get_qa_and_image
        import shutil

        # Map each item ID to its index in the image data, then collect the image
        # paths for every visited ID that has an associated image.
        image_datas = get_qa_and_image()
        id2index = {}
        for i, data in enumerate(image_datas):
            id2index[data['id']] = i
        indexes = [id2index[i] for i in visited_ids if i in id2index]
        image_names = [image_datas[index]['value']['image'] for index in indexes]
        target_copy_folder = "data/asso_collection"
        # Copy the associated images into the target folder.
        os.makedirs(target_copy_folder, exist_ok=True)
        for image_name in image_names:
            shutil.copy(image_name, target_copy_folder)