jhj0517 commited on
Commit
007ba83
·
1 Parent(s): 6bc716a

Refactor parameter

Browse files
modules/translation/translation_base.py CHANGED
@@ -46,8 +46,8 @@ class TranslationBase(ABC):
46
  model_size: str,
47
  src_lang: str,
48
  tgt_lang: str,
49
- max_length: int,
50
- add_timestamp: bool,
51
  progress=gr.Progress()) -> list:
52
  """
53
  Translate subtitle file from source language to target language
@@ -77,6 +77,9 @@ class TranslationBase(ABC):
77
  Files to return to gr.Files()
78
  """
79
  try:
 
 
 
80
  self.cache_parameters(model_size=model_size,
81
  src_lang=src_lang,
82
  tgt_lang=tgt_lang,
@@ -90,10 +93,9 @@ class TranslationBase(ABC):
90
 
91
  files_info = {}
92
  for fileobj in fileobjs:
93
- file_path = fileobj.name
94
- file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
95
  if file_ext == ".srt":
96
- parsed_dicts = parse_srt(file_path=file_path)
97
  total_progress = len(parsed_dicts)
98
  for index, dic in enumerate(parsed_dicts):
99
  progress(index / total_progress, desc="Translating..")
@@ -102,7 +104,7 @@ class TranslationBase(ABC):
102
  subtitle = get_serialized_srt(parsed_dicts)
103
 
104
  elif file_ext == ".vtt":
105
- parsed_dicts = parse_vtt(file_path=file_path)
106
  total_progress = len(parsed_dicts)
107
  for index, dic in enumerate(parsed_dicts):
108
  progress(index / total_progress, desc="Translating..")
 
46
  model_size: str,
47
  src_lang: str,
48
  tgt_lang: str,
49
+ max_length: int = 200,
50
+ add_timestamp: bool = True,
51
  progress=gr.Progress()) -> list:
52
  """
53
  Translate subtitle file from source language to target language
 
77
  Files to return to gr.Files()
78
  """
79
  try:
80
+ if fileobjs and isinstance(fileobjs[0], gr.utils.NamedString):
81
+ fileobjs = [file.name for file in fileobjs]
82
+
83
  self.cache_parameters(model_size=model_size,
84
  src_lang=src_lang,
85
  tgt_lang=tgt_lang,
 
93
 
94
  files_info = {}
95
  for fileobj in fileobjs:
96
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj))
 
97
  if file_ext == ".srt":
98
+ parsed_dicts = parse_srt(file_path=fileobj)
99
  total_progress = len(parsed_dicts)
100
  for index, dic in enumerate(parsed_dicts):
101
  progress(index / total_progress, desc="Translating..")
 
104
  subtitle = get_serialized_srt(parsed_dicts)
105
 
106
  elif file_ext == ".vtt":
107
+ parsed_dicts = parse_vtt(file_path=fileobj)
108
  total_progress = len(parsed_dicts)
109
  for index, dic in enumerate(parsed_dicts):
110
  progress(index / total_progress, desc="Translating..")