tezuesh commited on
Commit
5acce69
·
verified ·
1 Parent(s): 7542ba5

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +12 -6
inference.py CHANGED
@@ -71,16 +71,22 @@ class InferenceRecipe:
71
  """Load and preprocess audio."""
72
  try:
73
  # Convert to tensor
74
- wav = torch.from_numpy(audio_array).float().unsqueeze(0).to(self.device)
75
 
76
  # Resample if needed
77
  if sample_rate != self.sample_rate:
78
  logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
79
- wav = torchaudio.transforms.Resample(
 
80
  orig_freq=sample_rate,
81
  new_freq=self.sample_rate
82
- )(wav)
83
-
 
 
 
 
 
84
  # Ensure frame alignment
85
  frame_size = int(self.sample_rate / self.frame_rate)
86
  orig_length = wav.shape[-1]
@@ -89,11 +95,11 @@ class InferenceRecipe:
89
  logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
90
 
91
  return wav
92
-
93
  except Exception as e:
94
  logger.error(f"Audio loading failed: {str(e)}")
95
  raise
96
-
97
  def _pad_codes(self, all_codes, time_seconds=30):
98
  try:
99
  min_frames = int(time_seconds * self.frame_rate)
 
71
  """Load and preprocess audio."""
72
  try:
73
  # Convert to tensor
74
+ wav = torch.from_numpy(audio_array).float().unsqueeze(0)
75
 
76
  # Resample if needed
77
  if sample_rate != self.sample_rate:
78
  logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
79
+ # Create resampler on same device as input will be
80
+ resampler = torchaudio.transforms.Resample(
81
  orig_freq=sample_rate,
82
  new_freq=self.sample_rate
83
+ ).to(self.device)
84
+ # Move wav to device before resampling
85
+ wav = resampler(wav.to(self.device))
86
+ else:
87
+ # If no resampling needed, still ensure wav is on correct device
88
+ wav = wav.to(self.device)
89
+
90
  # Ensure frame alignment
91
  frame_size = int(self.sample_rate / self.frame_rate)
92
  orig_length = wav.shape[-1]
 
95
  logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
96
 
97
  return wav
98
+
99
  except Exception as e:
100
  logger.error(f"Audio loading failed: {str(e)}")
101
  raise
102
+
103
  def _pad_codes(self, all_codes, time_seconds=30):
104
  try:
105
  min_frames = int(time_seconds * self.frame_rate)