thefish1 commited on
Commit
bd58e00
·
1 Parent(s): 9a0fbfc
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -324,9 +324,10 @@ def encode_keywords_to_avg(keywords, model, tokenizer, device):
324
 
325
  def encode_keywords_to_list(keywords, model, tokenizer, device):
326
  embeddings = []
 
327
  for keyword in tqdm(keywords):
328
  inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
329
- inputs.to(device)
330
  with torch.no_grad():
331
  outputs = model(**inputs)
332
  embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().tolist())
 
324
 
325
  def encode_keywords_to_list(keywords, model, tokenizer, device):
326
  embeddings = []
327
+ model.to(device)
328
  for keyword in tqdm(keywords):
329
  inputs = tokenizer(keyword, return_tensors='pt', padding=True, truncation=True, max_length=512)
330
+ inputs = {key: value.to(device) for key, value in inputs.items()}
331
  with torch.no_grad():
332
  outputs = model(**inputs)
333
  embeddings.append(outputs.last_hidden_state.mean(dim=1).squeeze().tolist())