jhj0517 commited on
Commit
c8ead6d
·
1 Parent(s): e76c01c

add base abstract class for translation model

Browse files
Files changed (1) hide show
  1. modules/translation_base.py +148 -0
modules/translation_base.py ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ from abc import ABC, abstractmethod
5
+ from typing import List
6
+ from datetime import datetime
7
+
8
+ from modules.whisper_parameter import *
9
+ from modules.subtitle_manager import *
10
+
11
+
12
+ class TranslationBase(ABC):
13
+ def __init__(self,
14
+ model_dir: str):
15
+ super().__init__()
16
+ self.model = None
17
+ self.model_dir = model_dir
18
+ os.makedirs(self.model_dir, exist_ok=True)
19
+ self.current_model_size = None
20
+ self.device = self.get_device()
21
+
22
+ @abstractmethod
23
+ def translate(self,
24
+ text: str
25
+ ):
26
+ pass
27
+
28
+ @abstractmethod
29
+ def update_model(self,
30
+ model_size: str,
31
+ src_lang: str,
32
+ tgt_lang: str,
33
+ progress: gr.Progress
34
+ ):
35
+ pass
36
+
37
+ def translate_file(self,
38
+ fileobjs: list,
39
+ model_size: str,
40
+ src_lang: str,
41
+ tgt_lang: str,
42
+ add_timestamp: bool,
43
+ progress=gr.Progress()) -> list:
44
+ """
45
+ Translate subtitle file from source language to target language
46
+
47
+ Parameters
48
+ ----------
49
+ fileobjs: list
50
+ List of files to transcribe from gr.Files()
51
+ model_size: str
52
+ Whisper model size from gr.Dropdown()
53
+ src_lang: str
54
+ Source language of the file to translate from gr.Dropdown()
55
+ tgt_lang: str
56
+ Target language of the file to translate from gr.Dropdown()
57
+ add_timestamp: bool
58
+ Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
59
+ progress: gr.Progress
60
+ Indicator to show progress directly in gradio.
61
+ I use a forked version of whisper for this. To see more info : https://github.com/jhj0517/jhj0517-whisper/tree/add-progress-callback
62
+
63
+ Returns
64
+ ----------
65
+ A List of
66
+ String to return to gr.Textbox()
67
+ Files to return to gr.Files()
68
+ """
69
+ try:
70
+ self.update_model(model_size=model_size,
71
+ src_lang=src_lang,
72
+ tgt_lang=tgt_lang,
73
+ progress=progress)
74
+
75
+ files_info = {}
76
+ for fileobj in fileobjs:
77
+ file_path = fileobj.name
78
+ file_name, file_ext = os.path.splitext(os.path.basename(fileobj.name))
79
+ if file_ext == ".srt":
80
+ parsed_dicts = parse_srt(file_path=file_path)
81
+ total_progress = len(parsed_dicts)
82
+ for index, dic in enumerate(parsed_dicts):
83
+ progress(index / total_progress, desc="Translating..")
84
+ translated_text = self.translate(dic["sentence"])
85
+ dic["sentence"] = translated_text
86
+ subtitle = get_serialized_srt(parsed_dicts)
87
+
88
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
89
+ if add_timestamp:
90
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
91
+ else:
92
+ output_path = os.path.join("outputs", "translations", f"{file_name}.srt")
93
+
94
+ elif file_ext == ".vtt":
95
+ parsed_dicts = parse_vtt(file_path=file_path)
96
+ total_progress = len(parsed_dicts)
97
+ for index, dic in enumerate(parsed_dicts):
98
+ progress(index / total_progress, desc="Translating..")
99
+ translated_text = self.translate(dic["sentence"])
100
+ dic["sentence"] = translated_text
101
+ subtitle = get_serialized_vtt(parsed_dicts)
102
+
103
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
104
+ if add_timestamp:
105
+ output_path = os.path.join("outputs", "translations", f"{file_name}-{timestamp}")
106
+ else:
107
+ output_path = os.path.join("outputs", "translations", f"{file_name}.vtt")
108
+
109
+ write_file(subtitle, output_path)
110
+ files_info[file_name] = subtitle
111
+
112
+ total_result = ''
113
+ for file_name, subtitle in files_info.items():
114
+ total_result += '------------------------------------\n'
115
+ total_result += f'{file_name}\n\n'
116
+ total_result += f'{subtitle}'
117
+
118
+ gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
119
+ return [gr_str, output_path]
120
+ except Exception as e:
121
+ print(f"Error: {str(e)}")
122
+ finally:
123
+ self.release_cuda_memory()
124
+ self.remove_input_files([fileobj.name for fileobj in fileobjs])
125
+
126
+ @staticmethod
127
+ def get_device():
128
+ if torch.cuda.is_available():
129
+ return "cuda"
130
+ elif torch.backends.mps.is_available():
131
+ return "mps"
132
+ else:
133
+ return "cpu"
134
+
135
+ @staticmethod
136
+ def release_cuda_memory():
137
+ if torch.cuda.is_available():
138
+ torch.cuda.empty_cache()
139
+ torch.cuda.reset_max_memory_allocated()
140
+
141
+ @staticmethod
142
+ def remove_input_files(file_paths: List[str]):
143
+ if not file_paths:
144
+ return
145
+
146
+ for file_path in file_paths:
147
+ if file_path and os.path.exists(file_path):
148
+ os.remove(file_path)