Update inference.py
Browse files- inference.py +12 -6
inference.py
CHANGED
@@ -71,16 +71,22 @@ class InferenceRecipe:
|
|
71 |
"""Load and preprocess audio."""
|
72 |
try:
|
73 |
# Convert to tensor
|
74 |
-
wav = torch.from_numpy(audio_array).float().unsqueeze(0)
|
75 |
|
76 |
# Resample if needed
|
77 |
if sample_rate != self.sample_rate:
|
78 |
logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
|
79 |
-
|
|
|
80 |
orig_freq=sample_rate,
|
81 |
new_freq=self.sample_rate
|
82 |
-
)(
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
84 |
# Ensure frame alignment
|
85 |
frame_size = int(self.sample_rate / self.frame_rate)
|
86 |
orig_length = wav.shape[-1]
|
@@ -89,11 +95,11 @@ class InferenceRecipe:
|
|
89 |
logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
|
90 |
|
91 |
return wav
|
92 |
-
|
93 |
except Exception as e:
|
94 |
logger.error(f"Audio loading failed: {str(e)}")
|
95 |
raise
|
96 |
-
|
97 |
def _pad_codes(self, all_codes, time_seconds=30):
|
98 |
try:
|
99 |
min_frames = int(time_seconds * self.frame_rate)
|
|
|
71 |
"""Load and preprocess audio."""
|
72 |
try:
|
73 |
# Convert to tensor
|
74 |
+
wav = torch.from_numpy(audio_array).float().unsqueeze(0)
|
75 |
|
76 |
# Resample if needed
|
77 |
if sample_rate != self.sample_rate:
|
78 |
logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
|
79 |
+
# Create resampler on same device as input will be
|
80 |
+
resampler = torchaudio.transforms.Resample(
|
81 |
orig_freq=sample_rate,
|
82 |
new_freq=self.sample_rate
|
83 |
+
).to(self.device)
|
84 |
+
# Move wav to device before resampling
|
85 |
+
wav = resampler(wav.to(self.device))
|
86 |
+
else:
|
87 |
+
# If no resampling needed, still ensure wav is on correct device
|
88 |
+
wav = wav.to(self.device)
|
89 |
+
|
90 |
# Ensure frame alignment
|
91 |
frame_size = int(self.sample_rate / self.frame_rate)
|
92 |
orig_length = wav.shape[-1]
|
|
|
95 |
logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
|
96 |
|
97 |
return wav
|
98 |
+
|
99 |
except Exception as e:
|
100 |
logger.error(f"Audio loading failed: {str(e)}")
|
101 |
raise
|
102 |
+
|
103 |
def _pad_codes(self, all_codes, time_seconds=30):
|
104 |
try:
|
105 |
min_frames = int(time_seconds * self.frame_rate)
|