jhj0517 committed on
Commit
6cee2a2
·
1 Parent(s): 79933ea

fix type hint

Browse files
modules/diarize/audio_loader.py CHANGED
@@ -24,32 +24,43 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
24
  TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
25
 
26
 
27
- def load_audio(file: str, sr: int = SAMPLE_RATE):
28
  """
29
- Open an audio file and read as mono waveform, resampling as necessary
30
 
31
  Parameters
32
  ----------
33
- file: str
34
- The audio file to open
35
 
36
  sr: int
37
- The sample rate to resample the audio if necessary
38
 
39
  Returns
40
  -------
41
  A NumPy array containing the audio waveform, in float32 dtype.
42
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  try:
44
- # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
45
- # Requires the ffmpeg CLI to be installed.
46
  cmd = [
47
  "ffmpeg",
48
  "-nostdin",
49
  "-threads",
50
  "0",
51
  "-i",
52
- file,
53
  "-f",
54
  "s16le",
55
  "-ac",
@@ -63,6 +74,9 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
63
  out = subprocess.run(cmd, capture_output=True, check=True).stdout
64
  except subprocess.CalledProcessError as e:
65
  raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
 
 
 
66
 
67
  return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
68
 
 
24
  TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN) # 20ms per audio token
25
 
26
 
27
+ def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
28
  """
29
+ Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
30
 
31
  Parameters
32
  ----------
33
+ file: Union[str, np.ndarray]
34
+ The audio file to open or a numpy array containing the audio data.
35
 
36
  sr: int
37
+ The sample rate to resample the audio if necessary.
38
 
39
  Returns
40
  -------
41
  A NumPy array containing the audio waveform, in float32 dtype.
42
  """
43
+ if isinstance(file, np.ndarray):
44
+ if file.dtype != np.float32:
45
+ file = file.astype(np.float32)
46
+ if file.ndim > 1:
47
+ file = np.mean(file, axis=1)
48
+
49
+ temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
50
+ write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
51
+ temp_file_path = temp_file.name
52
+ temp_file.close()
53
+ else:
54
+ temp_file_path = file
55
+
56
  try:
 
 
57
  cmd = [
58
  "ffmpeg",
59
  "-nostdin",
60
  "-threads",
61
  "0",
62
  "-i",
63
+ temp_file_path,
64
  "-f",
65
  "s16le",
66
  "-ac",
 
74
  out = subprocess.run(cmd, capture_output=True, check=True).stdout
75
  except subprocess.CalledProcessError as e:
76
  raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
77
+ finally:
78
+ if isinstance(file, np.ndarray):
79
+ os.remove(temp_file_path)
80
 
81
  return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
82
 
modules/diarize/diarizer.py CHANGED
@@ -1,6 +1,7 @@
1
  import os
2
  import torch
3
- from typing import List
 
4
  import time
5
  import logging
6
 
@@ -20,7 +21,7 @@ class Diarizer:
20
  self.pipe = None
21
 
22
  def run(self,
23
- audio: str,
24
  transcribed_result: List[dict],
25
  use_auth_token: str,
26
  device: str
 
1
  import os
2
  import torch
3
+ from typing import List, Union, BinaryIO
4
+ import numpy as np
5
  import time
6
  import logging
7
 
 
21
  self.pipe = None
22
 
23
  def run(self,
24
+ audio: Union[str, BinaryIO, np.ndarray],
25
  transcribed_result: List[dict],
26
  use_auth_token: str,
27
  device: str