Linsad committed on
Commit
ab65af6
·
1 Parent(s): 750f8c6

Update handler.py

Browse files
Files changed (1) hide show
  1. handler.py +10 -9
handler.py CHANGED
@@ -3,21 +3,22 @@ from transformers import AutoTokenizer, AutoModel
3
 
4
 
5
  class EndpointHandler:
6
- def __init__(self):
7
- self.tokenizer = AutoTokenizer.from_pretrained("chatglm2-6b-int4", trust_remote_code=True)
8
- self.model = AutoModel.from_pretrained("chatglm2-6b-int4", trust_remote_code=True).half().cuda()
 
9
  self.model = self.model.eval()
10
 
11
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
12
  """
13
- data args:
14
- inputs (:obj: `str`)
15
- Return:
16
- A :obj:`list` | `dict`: will be serialized and returned
17
  """
18
  # get inputs
19
  inputs = data.pop("inputs", data)
20
-
21
  response, history = self.model.chat(self.tokenizer, inputs, history=[])
22
-
23
  return [{'response': response, 'history': history}]
 
3
 
4
 
5
  class EndpointHandler:
6
+ def __init__(self, path=""):
7
+ print('path is' + path)
8
+ self.tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
9
+ self.model = AutoModel.from_pretrained(path, trust_remote_code=True).half().cuda()
10
  self.model = self.model.eval()
11
 
12
  def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
13
  """
14
+ data args:
15
+ inputs (:obj: `str`)
16
+ Return:
17
+ A :obj:`list` | `dict`: will be serialized and returned
18
  """
19
  # get inputs
20
  inputs = data.pop("inputs", data)
21
+
22
  response, history = self.model.chat(self.tokenizer, inputs, history=[])
23
+
24
  return [{'response': response, 'history': history}]