jhj0517 committed on
Commit 5e73da1 · 1 Parent(s): 1a50cf4

add `download_model()`

modules/insanely_fast_whisper_inference.py CHANGED

@@ -3,11 +3,10 @@ import time
 import numpy as np
 from typing import BinaryIO, Union, Tuple, List
 import torch
-import transformers
 from transformers import pipeline
 from transformers.utils import is_flash_attn_2_available
-import whisper
 import gradio as gr
+import wget
 
 from modules.whisper_parameter import *
 from modules.whisper_base import WhisperBase
@@ -53,16 +52,16 @@ class InsanelyFastWhisperInference(WhisperBase):
         if params.lang == "Automatic Detection":
             params.lang = None
 
-        def progress_callback(progress_value):
-            progress(progress_value, desc="Transcribing..")
-
-        segments_result = self.model(
+        progress(0, desc="Transcribing...")
+        segments = self.model(
            inputs=audio,
            chunk_length_s=30,
            batch_size=24,
            return_timestamps=True,
        )
-        segments_result = self.format_result(transcribed_result=segments_result)
+        segments_result = self.format_result(
+            transcribed_result=segments,
+        )
         elapsed_time = time.time() - start_time
         return segments_result, elapsed_time
 
@@ -85,6 +84,14 @@ class InsanelyFastWhisperInference(WhisperBase):
             Indicator to show progress directly in gradio.
         """
         progress(0, desc="Initializing Model..")
+        model_path = os.path.join(self.model_dir, model_size)
+        if not os.path.isdir(model_path) or not os.listdir(model_path):
+            self.download_model(
+                model_size=model_size,
+                download_root=model_path,
+                progress=progress
+            )
+
         self.current_compute_type = compute_type
         self.current_model_size = model_size
 
@@ -97,7 +104,9 @@ class InsanelyFastWhisperInference(WhisperBase):
         )
 
     @staticmethod
-    def format_result(transcribed_result: dict) -> List[dict]:
+    def format_result(
+            transcribed_result: dict
+    ) -> List[dict]:
         """
         Format the transcription result of insanely_fast_whisper as the same with other implementation.
 
@@ -105,6 +114,8 @@ class InsanelyFastWhisperInference(WhisperBase):
         ----------
         transcribed_result: dict
             Transcription result of the insanely_fast_whisper
+        progress: gr.Progress
+            Indicator to show progress directly in gradio.
 
         Returns
         ----------
@@ -118,3 +129,31 @@ class InsanelyFastWhisperInference(WhisperBase):
             item["end"] = end
         return result
 
+    @staticmethod
+    def download_model(
+            model_size: str,
+            download_root: str,
+            progress: gr.Progress
+    ):
+        progress(0, 'Initializing model..')
+        print(f'Downloading {model_size} to "{download_root}"....')
+
+        os.makedirs(download_root, exist_ok=True)
+        download_list = [
+            "model.safetensors",
+            "config.json",
+            "generation_config.json",
+            "preprocessor_config.json",
+            "tokenizer.json",
+            "tokenizer_config.json",
+            "added_tokens.json",
+            "special_tokens_map.json",
+            "vocab.json",
+        ]
+
+        download_host = f"https://huggingface.co/openai/whisper-{model_size}/resolve/main"
+        for item in download_list:
+            wget.download(
+                download_host+"/"+item,
+                download_root
+            )
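For context, the new `download_model()` is a `@staticmethod`, so it can also be called on its own to pre-fetch weights before the UI starts. Below is a minimal usage sketch, not part of the commit: the `"base"` model size and the local directory are illustrative, and a no-op callable stands in for `gr.Progress`, since the method only calls `progress(...)` for UI feedback.

```python
# Minimal usage sketch for the new static method added in this commit.
# The model size and download_root are assumptions, not values taken from the repo;
# any callable accepting (value, description) can stand in for gr.Progress here.
from modules.insanely_fast_whisper_inference import InsanelyFastWhisperInference

InsanelyFastWhisperInference.download_model(
    model_size="base",
    download_root="models/insanely-fast-whisper/base",  # assumed target directory
    progress=lambda *args, **kwargs: None                # no-op stand-in for gr.Progress
)
```

Once the files are in place, the new branch in `load_model()` is skipped, because `os.path.isdir(model_path)` is true and `os.listdir(model_path)` is non-empty.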
requirements.txt CHANGED
@@ -4,4 +4,5 @@ git+https://github.com/jhj0517/jhj0517-whisper.git
 faster-whisper==1.0.2
 transformers
 gradio==4.29.0
-pytube
+pytube
+wget==3.2
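The new `wget==3.2` pin provides the `wget.download()` call used by `download_model()` above. A quick sketch of that API, fetching one of the listed files as an example (the URL is only illustrative):

```python
import wget

# wget.download(url, out=None) saves the file and returns the local path;
# when "out" is omitted it writes into the current working directory.
path = wget.download("https://huggingface.co/openai/whisper-base/resolve/main/config.json")
print(path)  # e.g. "config.json"
```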