File size: 1,657 Bytes
a0db56d
c174689
 
dddea8e
a0db56d
c174689
44a72dc
 
011418d
002b689
 
2fb86d3
 
 
 
 
 
 
 
c174689
 
 
 
 
 
 
 
 
011418d
002b689
011418d
c174689
 
 
 
002b689
011418d
 
002b689
011418d
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
import gradio as gr
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torch
import torchaudio

# Hugging Face Hub checkpoint for the Darija wav2vec2 model; the processor and the
# CTC model weights are both loaded from this single repository id.
_CHECKPOINT = "boumehdi/wav2vec2-large-xlsr-moroccan-darija"
processor = Wav2Vec2Processor.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)

def transcribe_audio(audio_file):
    """Transcribe a Darija speech recording to text.

    Args:
        audio_file: Path to the audio file, as supplied by
            ``gr.Audio(type="filepath")``; ``None`` when the user submits
            without uploading anything.

    Returns:
        The decoded transcription string, or ``""`` when no file was given.
    """
    if audio_file is None:
        # Nothing uploaded: avoid crashing inside torchaudio.load(None).
        return ""

    # torchaudio.load returns a (channels, frames) float tensor plus the file's native rate.
    waveform, sampling_rate = torchaudio.load(audio_file, normalize=True)

    # Down-mix to mono. A bare squeeze() leaves stereo audio as (2, frames), which the
    # processor would not treat as one utterance; mean() also turns (1, frames) into (frames,).
    waveform = waveform.mean(dim=0)

    # The feature extractor raises if the given rate differs from the rate it was configured
    # with (typically 16 kHz for wav2vec2 checkpoints), so resample uploads such as 44.1/48 kHz.
    target_rate = processor.feature_extractor.sampling_rate
    if sampling_rate != target_rate:
        waveform = torchaudio.functional.resample(waveform, sampling_rate, target_rate)

    input_values = processor(
        waveform.numpy(), sampling_rate=target_rate, return_tensors="pt"
    ).input_values

    with torch.no_grad():
        logits = model(input_values).logits

    # Greedy CTC decoding: take the most likely token per frame and let the
    # processor's tokenizer turn the id sequence into text.
    predicted_ids = torch.argmax(logits, dim=-1)
    return processor.batch_decode(predicted_ids)[0]

# Web UI: an audio upload (delivered to the callback as a file path) goes into
# transcribe_audio, and the returned string is shown as plain text.
interface = gr.Interface(
    title="Darija ASR Transcription",
    description="Upload an audio file in Darija, and the ASR model will transcribe it into text.",
    fn=transcribe_audio,
    inputs=gr.Audio(type="filepath"),
    outputs="text",
)

# Only start the server when run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()