Mohssinibra committed on
Commit
c174689
·
verified ·
1 Parent(s): 002b689
Files changed (1) hide show
  1. app.py +21 -13
app.py CHANGED
@@ -1,24 +1,32 @@
1
  import gradio as gr
2
- from speechbrain.inference.ASR import EncoderASR
 
3
 
4
# Build the SpeechBrain recognizer once, at import time, so every request
# served by the handler below reuses the same instance.
asr_model = EncoderASR.from_hparams(
    savedir="pretrained_models/asr-wav2vec2-dvoice-darija",
    source="speechbrain/asr-wav2vec2-dvoice-darija",
)
9
 
10
  # Function to process the audio file and return transcription
11
  def transcribe_audio(audio_file):
12
- # Transcribe the uploaded audio file
13
- transcription = asr_model.transcribe_file(audio_file)
14
- return transcription
 
 
 
 
 
 
 
 
 
15
 
16
  # Create a Gradio interface
17
  interface = gr.Interface(
18
- fn=transcribe_audio, # Function to call
19
- inputs=gr.Audio(type="filepath"), # Input component (audio file upload)
20
- outputs="text", # Output component (text)
21
- title="Darija ASR Transcription", # Title of the interface
22
  description="Upload an audio file in Darija, and the ASR model will transcribe it into text." # Description
23
  )
24
 
 
1
  import gradio as gr
2
+ from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
3
+ import torch
4
 
5
# Load the processor and the CTC model from one checkpoint id so the two
# cannot drift apart.
# NOTE(review): the checkpoint name says "arabic" while the UI text promises
# Darija -- confirm this checkpoint is the intended one.
_CHECKPOINT = "facebook/wav2vec2-large-xlsr-53-arabic"
processor = Wav2Vec2Processor.from_pretrained(_CHECKPOINT)
model = Wav2Vec2ForCTC.from_pretrained(_CHECKPOINT)
 
 
8
 
9
  # Function to process the audio file and return transcription
10
  def transcribe_audio(audio_file):
11
+ # Load and process the audio file
12
+ audio_input, _ = torchaudio.load(audio_file)
13
+ input_values = processor(audio_input, return_tensors="pt").input_values
14
+
15
+ # Perform transcription
16
+ with torch.no_grad():
17
+ logits = model(input_values).logits
18
+
19
+ # Decode the logits to text
20
+ predicted_ids = torch.argmax(logits, dim=-1)
21
+ transcription = processor.batch_decode(predicted_ids)
22
+ return transcription[0]
23
 
24
  # Create a Gradio interface
25
  interface = gr.Interface(
26
+ fn=transcribe_audio, # Function to call
27
+ inputs=gr.Audio(type="filepath"), # Input component (audio file upload)
28
+ outputs="text", # Output component (text)
29
+ title="Darija ASR Transcription", # Title of the interface
30
  description="Upload an audio file in Darija, and the ASR model will transcribe it into text." # Description
31
  )
32