nsaintsever commited on
Commit
2fff01c
·
1 Parent(s): 28002b3

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +113 -0
  2. requirements.txt +12 -0
app.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torchaudio
4
+ from audiocraft.models import MusicGen
5
+ import os
6
+ import numpy as np
7
+ import base64
8
+
9
+ @st.cache_resource()
10
+ def load_model():
11
+ model = MusicGen.get_pretrained('facebook/musicgen-small')
12
+ return model
13
+
14
+
15
+ @st.cache_resource()
16
+ def generate_music_tensors(description, duration: int):
17
+ model = load_model()
18
+
19
+ model.set_generation_params(
20
+ use_sampling=True,
21
+ top_k=250,
22
+ duration=duration
23
+ )
24
+
25
+ output = model.generate(
26
+ descriptions=[description],
27
+ progress=True,
28
+ return_tokens=True
29
+ )
30
+ return output[0]
31
+
32
+
33
+ def save_audio(samples: torch.Tensor):
34
+ """Renders an audio player for the given audio samples and saves them to a local directory.
35
+
36
+ Args:
37
+ samples (torch.Tensor): a Tensor of decoded audio samples
38
+ with shapes [B, C, T] or [C, T]
39
+ sample_rate (int): sample rate audio should be displayed with.
40
+ save_path (str): path to the directory where audio should be saved.
41
+ """
42
+
43
+ print("Samples (inside function): ", samples)
44
+ sample_rate = 30000
45
+ save_path = "audio_output/"
46
+ assert samples.dim() == 2 or samples.dim() == 3
47
+
48
+ samples = samples.detach().cpu()
49
+ if samples.dim() == 2:
50
+ samples = samples[None, ...]
51
+
52
+ for idx, audio in enumerate(samples):
53
+ audio_path = os.path.join(save_path, f"audio_{idx}.wav")
54
+ torchaudio.save(audio_path, audio, sample_rate)
55
+
56
+ def get_binary_file_downloader_html(bin_file, file_label='File'):
57
+ with open(bin_file, 'rb') as f:
58
+ data = f.read()
59
+ bin_str = base64.b64encode(data).decode()
60
+ href = f'<a href="data:application/octet-stream;base64,{bin_str}" download="{os.path.basename(bin_file)}">Download {file_label}</a>'
61
+ return href
62
+
63
+ st.set_page_config(
64
+ page_icon= "musical_note",
65
+ page_title= "Music Gen"
66
+ )
67
+
68
+ def main():
69
+ with st.sidebar:
70
+ st.header("""⚙️ Parameters ⚙️""",divider="rainbow")
71
+ st.text("")
72
+ st.subheader("1. Enter your music description.......")
73
+ text_area = st.text_area('Ex : 80s rock song with guitar and drums')
74
+ st.text('')
75
+ st.subheader("2. Select time duration (In Seconds)")
76
+
77
+ time_slider = st.slider("Select time duration (In Seconds)", 0, 20, 10)
78
+
79
+ st.title("""🎵 Text to Music Generator 🎵""")
80
+ st.text('')
81
+ left_co,right_co = st.columns(2)
82
+ left_co.write("""Music Generation using Meta AI, through a prompt""")
83
+ left_co.write(("""PS : First generation may take some time as it loads the full model and requirements"""))
84
+ #container1 = st.container()
85
+ #container1.write("""Music coupled with Image Generation using a prompt""")
86
+ #container1.write("""PS : First generation may take some time as it loads the full model and requirements""")
87
+
88
+
89
+ if st.sidebar.button('Generate !'):
90
+ gif_url = "https://media.giphy.com/media/26Fffy7jqQW8gVg8o/giphy.gif"
91
+ with right_co:
92
+ with st.spinner("Generating"):
93
+ st.image(gif_url,width=250)
94
+ with left_co:
95
+ st.text('')
96
+ st.text('')
97
+ st.text('')
98
+ st.text('')
99
+ st.text('')
100
+ st.text('')
101
+ st.subheader("Generated Music")
102
+
103
+ music_tensors = generate_music_tensors(text_area, time_slider)
104
+ save_music_file = save_audio(music_tensors)
105
+ audio_filepath = 'audio_output/audio_0.wav'
106
+ audio_file = open(audio_filepath, 'rb')
107
+ audio_bytes = audio_file.read()
108
+ st.audio(audio_bytes)
109
+ st.markdown(get_binary_file_downloader_html(audio_filepath, 'Audio'), unsafe_allow_html=True)
110
+
111
+
112
+ if __name__ == "__main__":
113
+ main()
requirements.txt ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ audiocraft
2
+ numpy==1.23.5
3
+ torch==2.0.1
4
+ torchaudio==2.0.2
5
+ huggingface_hub
6
+ transformers==4.33.3
7
+ torchmetrics
8
+ encodec==0.1.1
9
+ xformers==0.0.22
10
+ streamlit
11
+ librosa
12
+ protobuf==3.20.0