jhj0517 commited on
Commit
34da350
·
unverified ·
2 Parent(s): 8da8748 184dab0

Merge pull request #190 from jhj0517/fix/translation-long-input

Browse files
app.py CHANGED
@@ -20,7 +20,7 @@ class App:
20
  print(f"Device \"{self.whisper_inf.device}\" is detected")
21
  self.nllb_inf = NLLBInference(
22
  model_dir=self.args.nllb_model_dir,
23
- output_dir=self.args.output_dir
24
  )
25
  self.deepl_api = DeepLAPI(
26
  output_dir=self.args.output_dir
@@ -375,6 +375,8 @@ class App:
375
  choices=self.nllb_inf.available_source_langs)
376
  dd_nllb_targetlang = gr.Dropdown(label="Target Language",
377
  choices=self.nllb_inf.available_target_langs)
 
 
378
  with gr.Row():
379
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
380
  interactive=True)
@@ -388,7 +390,7 @@ class App:
388
  md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")
389
 
390
  btn_run.click(fn=self.nllb_inf.translate_file,
391
- inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, cb_timestamp],
392
  outputs=[tb_indicator, files_subtitles])
393
 
394
  btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
 
20
  print(f"Device \"{self.whisper_inf.device}\" is detected")
21
  self.nllb_inf = NLLBInference(
22
  model_dir=self.args.nllb_model_dir,
23
+ output_dir=os.path.join(self.args.output_dir, "translations")
24
  )
25
  self.deepl_api = DeepLAPI(
26
  output_dir=self.args.output_dir
 
375
  choices=self.nllb_inf.available_source_langs)
376
  dd_nllb_targetlang = gr.Dropdown(label="Target Language",
377
  choices=self.nllb_inf.available_target_langs)
378
+ with gr.Row():
379
+ nb_max_length = gr.Number(label="Max Length Per Line", value=200, precision=0)
380
  with gr.Row():
381
  cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
382
  interactive=True)
 
390
  md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")
391
 
392
  btn_run.click(fn=self.nllb_inf.translate_file,
393
+ inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, nb_max_length, cb_timestamp],
394
  outputs=[tb_indicator, files_subtitles])
395
 
396
  btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
modules/translation/nllb_inference.py CHANGED
@@ -21,9 +21,13 @@ class NLLBInference(TranslationBase):
21
  self.pipeline = None
22
 
23
  def translate(self,
24
- text: str
 
25
  ):
26
- result = self.pipeline(text)
 
 
 
27
  return result[0]['translation_text']
28
 
29
  def update_model(self,
 
21
  self.pipeline = None
22
 
23
  def translate(self,
24
+ text: str,
25
+ max_length: int
26
  ):
27
+ result = self.pipeline(
28
+ text,
29
+ max_length=max_length
30
+ )
31
  return result[0]['translation_text']
32
 
33
  def update_model(self,
modules/translation/translation_base.py CHANGED
@@ -24,7 +24,8 @@ class TranslationBase(ABC):
24
 
25
  @abstractmethod
26
  def translate(self,
27
- text: str
 
28
  ):
29
  pass
30
 
@@ -42,6 +43,7 @@ class TranslationBase(ABC):
42
  model_size: str,
43
  src_lang: str,
44
  tgt_lang: str,
 
45
  add_timestamp: bool,
46
  progress=gr.Progress()) -> list:
47
  """
@@ -57,6 +59,8 @@ class TranslationBase(ABC):
57
  Source language of the file to translate from gr.Dropdown()
58
  tgt_lang: str
59
  Target language of the file to translate from gr.Dropdown()
 
 
60
  add_timestamp: bool
61
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
62
  progress: gr.Progress
@@ -84,7 +88,7 @@ class TranslationBase(ABC):
84
  total_progress = len(parsed_dicts)
85
  for index, dic in enumerate(parsed_dicts):
86
  progress(index / total_progress, desc="Translating..")
87
- translated_text = self.translate(dic["sentence"])
88
  dic["sentence"] = translated_text
89
  subtitle = get_serialized_srt(parsed_dicts)
90
 
@@ -99,7 +103,7 @@ class TranslationBase(ABC):
99
  total_progress = len(parsed_dicts)
100
  for index, dic in enumerate(parsed_dicts):
101
  progress(index / total_progress, desc="Translating..")
102
- translated_text = self.translate(dic["sentence"])
103
  dic["sentence"] = translated_text
104
  subtitle = get_serialized_vtt(parsed_dicts)
105
 
@@ -124,7 +128,6 @@ class TranslationBase(ABC):
124
  print(f"Error: {str(e)}")
125
  finally:
126
  self.release_cuda_memory()
127
- self.remove_input_files([fileobj.name for fileobj in fileobjs])
128
 
129
  @staticmethod
130
  def get_device():
 
24
 
25
  @abstractmethod
26
  def translate(self,
27
+ text: str,
28
+ max_length: int
29
  ):
30
  pass
31
 
 
43
  model_size: str,
44
  src_lang: str,
45
  tgt_lang: str,
46
+ max_length: int,
47
  add_timestamp: bool,
48
  progress=gr.Progress()) -> list:
49
  """
 
59
  Source language of the file to translate from gr.Dropdown()
60
  tgt_lang: str
61
  Target language of the file to translate from gr.Dropdown()
62
+ max_length: int
63
+ Max length per line to translate
64
  add_timestamp: bool
65
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
66
  progress: gr.Progress
 
88
  total_progress = len(parsed_dicts)
89
  for index, dic in enumerate(parsed_dicts):
90
  progress(index / total_progress, desc="Translating..")
91
+ translated_text = self.translate(dic["sentence"], max_length=max_length)
92
  dic["sentence"] = translated_text
93
  subtitle = get_serialized_srt(parsed_dicts)
94
 
 
103
  total_progress = len(parsed_dicts)
104
  for index, dic in enumerate(parsed_dicts):
105
  progress(index / total_progress, desc="Translating..")
106
+ translated_text = self.translate(dic["sentence"], max_length=max_length)
107
  dic["sentence"] = translated_text
108
  subtitle = get_serialized_vtt(parsed_dicts)
109
 
 
128
  print(f"Error: {str(e)}")
129
  finally:
130
  self.release_cuda_memory()
 
131
 
132
  @staticmethod
133
  def get_device():