app.py
CHANGED
@@ -324,9 +324,10 @@ def encode_keywords_to_avg(keywords, model, tokenizer, device):
|
|
324 |
|
325 |
def encode_keywords_to_list(keywords, model, tokenizer, device):
|
326 |
embeddings = []
|
|
|
327 |
for keyword in tqdm(keywords):
|
328 |
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
329 |
-
inputs.to(device)
|
330 |
with torch.no_grad():
|
331 |
outputs = model(**inputs)
|
332 |
embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().tolist())
|
|
|
324 |
|
325 |
def encode_keywords_to_list(keywords, model, tokenizer, device):
|
326 |
embeddings = []
|
327 |
+
model.to(device)
|
328 |
for keyword in tqdm(keywords):
|
329 |
inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
|
330 |
+
inputs = {key: value.to(device) for key, value in inputs.items()}
|
331 |
with torch.no_grad():
|
332 |
outputs = model(**inputs)
|
333 |
embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().tolist())
|