alexneakameni commited on
Commit
67dab86
·
verified ·
1 Parent(s): 3ed22f6

Remove cuda check that might causes error

Browse files
Files changed (1) hide show
  1. ocr_engine.py +2 -6
ocr_engine.py CHANGED
@@ -1,6 +1,5 @@
1
  # Load model directly
2
  import os
3
- import torch
4
  from transformers import AutoModel, AutoTokenizer
5
  from PIL import Image
6
  import uuid
@@ -18,10 +17,8 @@ class OCRModel:
18
 
19
  def __init__(self):
20
  self.tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
21
- self.model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cuda' if torch.cuda.is_available() else "cpu", use_safetensors=True, pad_token_id=self.tokenizer.eos_token_id)
22
  self.model = self.model.eval()
23
- if torch.cuda.is_available():
24
- self.model = self.model.cuda()
25
 
26
  def chat(self, image: Image.Image) -> str:
27
  unique_id = str(uuid.uuid4())
@@ -31,8 +28,7 @@ class OCRModel:
31
  image.save(image_path)
32
 
33
  res = self.model.chat(self.tokenizer, image_path, ocr_type='ocr')
34
- with open(result_path, 'w') as f:
35
- f.write(res)
36
  return res
37
 
38
  ocr_model = OCRModel()
 
1
  # Load model directly
2
  import os
 
3
  from transformers import AutoModel, AutoTokenizer
4
  from PIL import Image
5
  import uuid
 
17
 
18
  def __init__(self):
19
  self.tokenizer = AutoTokenizer.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True)
20
+ self.model = AutoModel.from_pretrained('ucaslcl/GOT-OCR2_0', trust_remote_code=True, low_cpu_mem_usage=True, device_map='cpu', use_safetensors=True, pad_token_id=self.tokenizer.eos_token_id)
21
  self.model = self.model.eval()
 
 
22
 
23
  def chat(self, image: Image.Image) -> str:
24
  unique_id = str(uuid.uuid4())
 
28
  image.save(image_path)
29
 
30
  res = self.model.chat(self.tokenizer, image_path, ocr_type='ocr')
31
+ os.remove(image_path) # delete file create to avoid memory issue and data shared online
 
32
  return res
33
 
34
  ocr_model = OCRModel()