# Author: Ricardo Lisboa Santos
# Creation date: 2024-01-10

import torch
# import torch_directml
from transformers import pipeline


def getDevice(DEVICE):
    """Translate a device name into a ``torch.device``.

    Args:
        DEVICE: Device name. ``"cpu"`` and ``"cuda"`` are recognised.

    Returns:
        The matching ``torch.device``, or ``None`` for any other name.
        ``loadGenerator`` accepts ``None`` and then uses the pipeline's
        default device.
    """
    # NOTE: this used to also pick a dtype (float32 on CPU, float16 on CUDA)
    # that was never returned or used; it has been dropped as dead code.
    if DEVICE == "cpu":
        return torch.device("cpu")
    if DEVICE == "cuda":
        return torch.device("cuda")
    # DirectML support is disabled (it needs the torch_directml package):
    # if DEVICE == "directml":
    #     return torch_directml.device()
    return None


def loadGenerator(device, model=None):
    """Create a Hugging Face question-answering pipeline.

    Args:
        device: Where to place the model: a ``torch.device`` (e.g. from
            ``getDevice``), a device string, a GPU index, or ``None`` for
            the pipeline's default device.
        model: Optional model id or local directory, for example the
            directory written by ``clearCache``. ``None`` lets
            ``transformers`` choose its default question-answering model.

    Returns:
        The ``transformers`` question-answering pipeline.
    """
    # ``device`` used to be ignored (``.to(device)`` was commented out), so
    # the model always stayed on the default device. The pipeline places
    # the model itself when it is given ``device``.
    return pipeline("question-answering", model=model, device=device)


def query(generator, question, context):
    """Answer *question* from *context* using a question-answering pipeline.

    Args:
        generator: Pipeline returned by ``loadGenerator``.
        question: The question to answer.
        context: The text the answer should be extracted from.

    Returns:
        The pipeline's output. For a single question/context pair this is
        typically a dict with ``score``, ``start``, ``end`` and ``answer``
        keys.
    """
    return generator(question=question, context=context)


def clearCache(DEVICE, generator, path="cache"):
    """Save the pipeline's tokenizer and model to *path* and release GPU cache.

    Despite its name, this function persists the model rather than deleting
    it. The tokenizer and model are written with ``save_pretrained`` so they
    can be reloaded later, e.g. ``loadGenerator(device, model=path)``.

    Args:
        DEVICE: Device name as passed to ``getDevice``. On ``"cuda"`` the
            unused blocks held by PyTorch's CUDA caching allocator are
            released.
        generator: Pipeline returned by ``loadGenerator``.
        path: Output directory for the saved tokenizer and model
            (default ``"cache"``).

    Note:
        The model's own memory is only freed once the caller drops its
        reference to *generator*. Deleting the parameter inside this
        function would only unbind the local name.
    """
    generator.tokenizer.save_pretrained(path)
    generator.model.save_pretrained(path)
    if DEVICE == "cuda":
        torch.cuda.empty_cache()
    # if DEVICE == "directml":
    #     torch_directml.empty_cache()