atsushieee commited on
Commit
a629fc9
·
1 Parent(s): ad51a72

Update whisper/inference.py

Browse files
Files changed (1) hide show
  1. whisper/inference.py +36 -1
whisper/inference.py CHANGED
@@ -3,6 +3,8 @@ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
3
  import numpy as np
4
  import argparse
5
  import torch
 
 
6
 
7
  from whisper.model import Whisper, ModelDimensions
8
  from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram
@@ -29,6 +31,37 @@ def load_model(path, device) -> Whisper:
29
  return model
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  def pred_ppg(whisper: Whisper, wavPath, ppgPath, device):
33
  audio = load_audio(wavPath)
34
  audln = audio.shape[0]
@@ -74,5 +107,7 @@ if __name__ == "__main__":
74
  ppgPath = args.ppg
75
 
76
  device = "cuda" if torch.cuda.is_available() else "cpu"
77
- whisper = load_model(os.path.join("/tmp/large-v2.pt"), device)
 
 
78
  pred_ppg(whisper, wavPath, ppgPath, device)
 
3
  import numpy as np
4
  import argparse
5
  import torch
6
+ import requests
7
+ from tqdm import tqdm
8
 
9
  from whisper.model import Whisper, ModelDimensions
10
  from whisper.audio import load_audio, pad_or_trim, log_mel_spectrogram
 
31
  return model
32
 
33
 
34
+ def check_and_download_model():
35
+ temp_dir = "/tmp"
36
+ model_path = os.path.join(temp_dir, "large-v2.pt")
37
+
38
+ if os.path.exists(model_path):
39
+ return f"モデルは既に存在します: {model_path}"
40
+
41
+ url = "https://openaipublic.azureedge.net/main/whisper/models/81f7c96c852ee8fc832187b0132e569d6c3065a3252ed18e56effd0b6a73e524/large-v2.pt"
42
+
43
+ try:
44
+ response = requests.get(url, stream=True)
45
+ response.raise_for_status()
46
+ total_size = int(response.headers.get('content-length', 0))
47
+
48
+ with open(model_path, 'wb') as f, tqdm(
49
+ desc=model_path,
50
+ total=total_size,
51
+ unit='iB',
52
+ unit_scale=True,
53
+ unit_divisor=1024,
54
+ ) as pbar:
55
+ for data in response.iter_content(chunk_size=1024):
56
+ size = f.write(data)
57
+ pbar.update(size)
58
+
59
+ return f"モデルのダウンロードが完了しました: {model_path}"
60
+
61
+ except Exception as e:
62
+ return f"エラーが発生しました: {e}"
63
+
64
+
65
  def pred_ppg(whisper: Whisper, wavPath, ppgPath, device):
66
  audio = load_audio(wavPath)
67
  audln = audio.shape[0]
 
107
  ppgPath = args.ppg
108
 
109
  device = "cuda" if torch.cuda.is_available() else "cpu"
110
+
111
+ _ =check_and_download_model()
112
+ whisper = load_model("/tmp/large-v2.pt", device)
113
  pred_ppg(whisper, wavPath, ppgPath, device)