jhj0517 committed on
Commit
71eeeca
·
1 Parent(s): c8ead6d

refactor translation model class with abstract class

Browse files
Files changed (1) hide show
  1. modules/nllb_inference.py +31 -123
modules/nllb_inference.py CHANGED
@@ -1,141 +1,49 @@
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
3
- import torch
4
  import os
5
- from datetime import datetime
6
 
7
- from .base_interface import BaseInterface
8
- from modules.subtitle_manager import *
9
 
10
- DEFAULT_MODEL_SIZE = "facebook/nllb-200-1.3B"
11
- NLLB_MODELS = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
12
 
13
-
14
- class NLLBInference(BaseInterface):
15
  def __init__(self):
16
- super().__init__()
17
- self.default_model_size = DEFAULT_MODEL_SIZE
18
- self.current_model_size = None
19
- self.model = None
20
  self.tokenizer = None
21
- self.available_models = NLLB_MODELS
22
  self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
23
  self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
24
- self.device = 0 if torch.cuda.is_available() else -1
25
  self.pipeline = None
26
 
27
- def translate_text(self, text):
 
 
28
  result = self.pipeline(text)
29
  return result[0]['translation_text']
30
 
31
- def translate_file(self,
32
- fileobjs: list,
33
- model_size: str,
34
- src_lang: str,
35
- tgt_lang: str,
36
- add_timestamp: bool,
37
- progress=gr.Progress()) -> list:
38
- """
39
- Translate subtitle file from source language to target language
40
-
41
- Parameters
42
- ----------
43
- fileobjs: list
44
- List of files to transcribe from gr.Files()
45
- model_size: str
46
- Whisper model size from gr.Dropdown()
47
- src_lang: str
48
- Source language of the file to translate from gr.Dropdown()
49
- tgt_lang: str
50
- Target language of the file to translate from gr.Dropdown()
51
- add_timestamp: bool
52
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
53
- progress: gr.Progress
54
- Indicator to show progress directly in gradio.
55
- I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
56
-
57
- Returns
58
- ----------
59
- A List of
60
- String to return to gr.Textbox()
61
- Files to return to gr.Files()
62
- """
63
- try:
64
- if model_size != self.current_model_size or self.model is None:
65
- print("\nInitializing NLLB Model..\n")
66
- progress(0, desc="Initializing NLLB Model..")
67
- self.current_model_size = model_size
68
- self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
69
- cache_dir=os.path.join("models", "NLLB"))
70
- self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
71
- cache_dir=os.path.join("models", "NLLB", "tokenizers"))
72
-
73
- src_lang = NLLB_AVAILABLE_LANGS[src_lang]
74
- tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
75
-
76
- self.pipeline = pipeline("translation",
77
- model=self.model,
78
- tokenizer=self.tokenizer,
79
- src_lang=src_lang,
80
- tgt_lang=tgt_lang,
81
- device=self.device)
82
-
83
- files_info = {}
84
- for fileobj in fileobjs:
85
- file_path = fileobj.name
86
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
87
- if file_ext == ".srt":
88
- parsed_dicts = parse_srt(file_path=file_path)
89
- total_progress = len(parsed_dicts)
90
- for index, dic in enumerate(parsed_dicts):
91
- progress(index / total_progress, desc="Translating..")
92
- translated_text = self.translate_text(dic["sentence"])
93
- dic["sentence"] = translated_text
94
- subtitle = get_serialized_srt(parsed_dicts)
95
-
96
- timestamp = datetime.now().strftime("%m%d%H%M%S")
97
- if add_timestamp:
98
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
99
- else:
100
- output_path = os.path.join("outputs", "translations", f"{file_name}")
101
- output_path += '.srt'
102
-
103
- write_file(subtitle, output_path)
104
-
105
- elif file_ext == ".vtt":
106
- parsed_dicts = parse_vtt(file_path=file_path)
107
- total_progress = len(parsed_dicts)
108
- for index, dic in enumerate(parsed_dicts):
109
- progress(index / total_progress, desc="Translating..")
110
- translated_text = self.translate_text(dic["sentence"])
111
- dic["sentence"] = translated_text
112
- subtitle = get_serialized_vtt(parsed_dicts)
113
-
114
- timestamp = datetime.now().strftime("%m%d%H%M%S")
115
- if add_timestamp:
116
- output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
117
- else:
118
- output_path = os.path.join("outputs", "translations", f"{file_name}")
119
- output_path += '.vtt'
120
-
121
- write_file(subtitle, output_path)
122
-
123
- files_info[file_name] = subtitle
124
-
125
- total_result = ''
126
- for file_name, subtitle in files_info.items():
127
- total_result += '------------------------------------\n'
128
- total_result += f'{file_name}\n\n'
129
- total_result += f'{subtitle}'
130
-
131
- gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
132
- return [gr_str, output_path]
133
- except Exception as e:
134
- print(f"Error: {str(e)}")
135
- finally:
136
- self.release_cuda_memory()
137
- self.remove_input_files([fileobj.name for fileobj in fileobjs])
138
-
139
 
140
  NLLB_AVAILABLE_LANGS = {
141
  "Acehnese (Arabic script)": "ace_Arab",
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
 
3
  import os
 
4
 
5
+ from modules.translation_base import TranslationBase
 
6
 
 
 
7
 
8
+ class NLLBInference(TranslationBase):
 
9
  def __init__(self):
10
+ super().__init__(
11
+ model_dir=os.path.join("models", "NLLB")
12
+ )
 
13
  self.tokenizer = None
14
+ self.available_models = ["facebook/nllb-200-3.3B", "facebook/nllb-200-1.3B", "facebook/nllb-200-distilled-600M"]
15
  self.available_source_langs = list(NLLB_AVAILABLE_LANGS.keys())
16
  self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
 
17
  self.pipeline = None
18
 
19
+ def translate(self,
20
+ text: str
21
+ ):
22
  result = self.pipeline(text)
23
  return result[0]['translation_text']
24
 
25
+ def update_model(self,
26
+ model_size: str,
27
+ src_lang: str,
28
+ tgt_lang: str,
29
+ progress: gr.Progress
30
+ ):
31
+ if model_size != self.current_model_size or self.model is None:
32
+ print("\nInitializing NLLB Model..\n")
33
+ progress(0, desc="Initializing NLLB Model..")
34
+ self.current_model_size = model_size
35
+ self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
36
+ cache_dir=self.model_dir)
37
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
38
+ cache_dir=os.path.join(self.model_dir, "tokenizers"))
39
+ src_lang = NLLB_AVAILABLE_LANGS[src_lang]
40
+ tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
41
+ self.pipeline = pipeline("translation",
42
+ model=self.model,
43
+ tokenizer=self.tokenizer,
44
+ src_lang=src_lang,
45
+ tgt_lang=tgt_lang,
46
+ device=self.device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
  NLLB_AVAILABLE_LANGS = {
49
  "Acehnese (Arabic script)": "ace_Arab",