import json

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from typing import Dict, List, Any


class EndpointHandler:
    def __init__(self, path=""):
        # Load the model and tokenizer, placing the model on GPU if available
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model = (
            AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16)
            .to(self.device)
            .eval()
        )
        self.tokenizer = AutoTokenizer.from_pretrained(path)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # Handle the incoming request; expects {"inputs": {"text": ..., "template": ...}}
        input_text = data["inputs"]["text"]
        template = data["inputs"]["template"]

        # Run extraction on the single input text
        output = self.predict_NuExtract([input_text], template)
        return [{"extracted_information": output}]

    def predict_NuExtract(self, texts, template, batch_size=1, max_length=10_000, max_new_tokens=4_000):
        # Pretty-print the JSON template and build one prompt per input text
        template = json.dumps(json.loads(template), indent=4)
        prompts = [
            f"""<|input|>\n### Template:\n{template}\n### Text:\n{text}\n\n<|output|>"""
            for text in texts
        ]

        outputs = []
        with torch.no_grad():
            for i in range(0, len(prompts), batch_size):
                batch_prompts = prompts[i:i + batch_size]
                batch_encodings = self.tokenizer(
                    batch_prompts,
                    return_tensors="pt",
                    truncation=True,
                    padding=True,
                    max_length=max_length,
                ).to(self.device)

                pred_ids = self.model.generate(**batch_encodings, max_new_tokens=max_new_tokens)
                outputs += self.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)

        # Keep only the text generated after the <|output|> marker
        return [output.split("<|output|>")[1] for output in outputs]
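

# Minimal local smoke test: a sketch for checking the handler outside the endpoint,
# not part of the deployed code. MODEL_PATH and the sample text/template below are
# assumptions; replace them with your own checkpoint path and extraction schema.
if __name__ == "__main__":
    MODEL_PATH = "numind/NuExtract"  # assumed model id / local path

    handler = EndpointHandler(path=MODEL_PATH)
    payload = {
        "inputs": {
            "text": "John Smith joined Acme Corp as CTO in 2021.",
            "template": '{"name": "", "company": "", "role": "", "year": ""}',
        }
    }
    print(handler(payload))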