prithivMLmods commited on
Commit
a657082
·
verified ·
1 Parent(s): 9021a40

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +196 -0
app.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import subprocess
4
+ import sys
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+ from typing import Optional, Tuple
8
+ from urllib.request import urlopen, urlretrieve
9
+
10
+ import gradio as gr
11
+ from huggingface_hub import HfApi, whoami
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+
17
@dataclass
class Config:
    """Application configuration.

    Holds the Hugging Face credentials plus the pinned transformers.js
    version and download URLs used to fetch the conversion scripts.
    """

    hf_token: str
    hf_username: str
    transformers_version: str = "3.0.0"
    hf_base_url: str = "https://huggingface.co"
    transformers_base_url: str = (
        "https://github.com/xenova/transformers.js/archive/refs"
    )
    repo_path: Path = Path("./transformers.js")

    @classmethod
    def from_env(cls, user_token: Optional[str] = None) -> "Config":
        """Create config from environment variables and secrets.

        Args:
            user_token: Optional user-supplied HF write token. When omitted,
                falls back to a token previously stashed on ``gr.session_state``
                (if such an attribute exists) and then to the ``HF_TOKEN``
                environment variable.

        Returns:
            A populated :class:`Config`.

        Raises:
            ValueError: If no token can be resolved from any source.
        """
        system_token = os.getenv("HF_TOKEN")

        # BUG FIX: `gr.session_state` does not exist in Gradio (it is a
        # Streamlit concept), so the original unconditional attribute access
        # raised AttributeError. Look it up defensively instead.
        if user_token is None:
            session = getattr(gr, "session_state", None)
            if session is not None:
                user_token = session.get("user_hf_token")

        if user_token:
            hf_username = whoami(token=user_token)["name"]
        else:
            hf_username = (
                os.getenv("SPACE_AUTHOR_NAME") or whoami(token=system_token)["name"]
            )

        hf_token = user_token or system_token
        if not hf_token:
            raise ValueError("HF_TOKEN must be set")

        return cls(hf_token=hf_token, hf_username=hf_username)
47
+
48
+
49
class ModelConverter:
    """Handles model conversion and upload operations."""

    def __init__(self, config: "Config"):
        """Store the config and build an authenticated HF API client."""
        self.config = config
        self.api = HfApi(token=config.hf_token)

    def _get_ref_type(self) -> str:
        """Return the git ref type ("tags" or "heads") for the pinned version.

        Probes the GitHub tag archive URL; any failure falls back to "heads".
        """
        url = f"{self.config.transformers_base_url}/tags/{self.config.transformers_version}.tar.gz"
        try:
            return "tags" if urlopen(url).getcode() == 200 else "heads"
        except Exception as e:
            logger.warning(f"Failed to check tags, defaulting to heads: {e}")
            return "heads"

    def setup_repository(self) -> None:
        """Download and extract the transformers.js repository if needed.

        No-op when ``config.repo_path`` already exists. The downloaded
        archive is always removed, even on failure.

        Raises:
            RuntimeError: If the download or extraction fails.
        """
        if self.config.repo_path.exists():
            return

        ref_type = self._get_ref_type()
        archive_url = f"{self.config.transformers_base_url}/{ref_type}/{self.config.transformers_version}.tar.gz"
        archive_path = Path(f"./transformers_{self.config.transformers_version}.tar.gz")

        try:
            urlretrieve(archive_url, archive_path)
            self._extract_archive(archive_path)
            logger.info("Repository downloaded and extracted successfully")
        except Exception as e:
            # Chain the original exception for easier debugging.
            raise RuntimeError(f"Failed to setup repository: {e}") from e
        finally:
            archive_path.unlink(missing_ok=True)

    def _extract_archive(self, archive_path: Path) -> None:
        """Extract the downloaded archive into ``config.repo_path``.

        The archive is assumed to contain a single top-level folder, which
        becomes ``config.repo_path``.
        """
        import shutil
        import tarfile
        import tempfile

        with tempfile.TemporaryDirectory() as tmp_dir:
            with tarfile.open(archive_path, "r:gz") as tar:
                tar.extractall(tmp_dir)

            extracted_folder = next(Path(tmp_dir).iterdir())
            # BUG FIX: Path.rename raises OSError when the temporary
            # directory lives on a different filesystem (e.g. tmpfs) than the
            # destination; shutil.move handles the cross-device case.
            shutil.move(str(extracted_folder), str(self.config.repo_path))

    def convert_model(self, input_model_id: str) -> Tuple[bool, Optional[str]]:
        """Convert the model to ONNX format via transformers.js' script.

        Args:
            input_model_id: HF repo id of the model to convert.

        Returns:
            ``(success, stderr)`` — on success, ``stderr`` carries the
            converter's diagnostic output; on failure, its error output.
        """
        try:
            result = subprocess.run(
                [
                    sys.executable,
                    "-m",
                    "scripts.convert",
                    "--quantize",
                    "--model_id",
                    input_model_id,
                ],
                cwd=self.config.repo_path,
                capture_output=True,
                text=True,
                # NOTE(review): env={} strips PATH/HOME/HF_TOKEN from the
                # child process — presumably deliberate isolation, but it can
                # break conversion of private/gated models. Confirm intent.
                env={},
            )

            if result.returncode != 0:
                return False, result.stderr

            return True, result.stderr

        except Exception as e:
            return False, str(e)

    def upload_model(self, input_model_id: str) -> Optional[str]:
        """Upload the converted model to Hugging Face under the onnx/ folder.

        Args:
            input_model_id: Target HF repo id (same repo as the source model).

        Returns:
            ``None`` on success, otherwise an error description string.
        """
        # BUG FIX: bind the path before the try block so the `finally`
        # clause can never hit a NameError when an early statement fails.
        model_folder_path = self.config.repo_path / "models" / input_model_id
        try:
            onnx_folder_path = model_folder_path / "onnx"

            # Create the onnx folder if it doesn't exist (parents=True also
            # covers a missing intermediate directory).
            onnx_folder_path.mkdir(parents=True, exist_ok=True)

            # Move the converted .onnx files into the onnx subfolder.
            for file in model_folder_path.iterdir():
                if file.is_file() and file.suffix == ".onnx":
                    file.rename(onnx_folder_path / file.name)

            # Upload only the onnx folder, into the repo's onnx/ path.
            self.api.upload_folder(
                folder_path=str(onnx_folder_path), repo_id=input_model_id, path_in_repo="onnx"
            )
            return None
        except Exception as e:
            return str(e)
        finally:
            import shutil

            # Best-effort cleanup of the local working copy.
            shutil.rmtree(model_folder_path, ignore_errors=True)
146
+
147
+
148
def convert_and_upload_model(input_model_id: str, user_hf_token: str = None):
    """Convert a HF model to ONNX and upload it under the repo's onnx/ folder.

    Args:
        input_model_id: Repo id of the model to convert (e.g. "org/name").
        user_hf_token: Optional write token; when given, the conversion and
            upload run with the caller's credentials instead of the Space's
            ``HF_TOKEN``.

    Returns:
        A human-readable status string describing success or failure.
    """
    try:
        # BUG FIX: Gradio has no `gr.session_state` (that is a Streamlit
        # concept); bootstrap a plain dict attribute so Config.from_env's
        # session lookup works instead of raising AttributeError.
        if not hasattr(gr, "session_state"):
            gr.session_state = {}
        if user_hf_token:
            gr.session_state["user_hf_token"] = user_hf_token
        config = Config.from_env()

        converter = ModelConverter(config)
        converter.setup_repository()

        # Guard clause: bail out early when the source repo does not exist.
        if not converter.api.repo_exists(input_model_id):
            return f"Model {input_model_id} does not exist on Hugging Face."

        # BUG FIX: `gr.spinner` does not exist in Gradio and `gr.Code(...)`
        # as a bare statement inside a callback renders nothing; report
        # progress via logging and gr.Info notifications instead.
        logger.info("Converting model %s ...", input_model_id)
        success, stderr = converter.convert_model(input_model_id)
        if not success:
            return f"Conversion failed: {stderr}"

        gr.Info("Conversion successful!")
        logger.info("Converter output:\n%s", stderr)

        logger.info("Uploading model %s ...", input_model_id)
        error = converter.upload_model(input_model_id)
        if error:
            return f"Upload failed: {error}"

        gr.Info("Upload successful!")
        return f"Model uploaded to {input_model_id}/onnx"

    except Exception as e:
        logger.exception("Application error")
        return f"An error occurred: {str(e)}"
181
+
182
+
183
# Gradio Interface
# Single-function UI: takes a model id (and an optional write token) and
# shows the status string returned by convert_and_upload_model.
iface = gr.Interface(
    fn=convert_and_upload_model,
    inputs=[
        gr.Textbox(label="Hugging Face Model ID", placeholder="Enter the Hugging Face model ID to convert. Example: `EleutherAI/pythia-14m`"),
        gr.Textbox(label="Hugging Face Write Token (Optional)", type="password", placeholder="Fill it if you want to upload the model under your account.")
    ],
    outputs=gr.Textbox(label="Output"),
    title="Convert a Hugging Face model to ONNX",
    description="This tool converts a Hugging Face model to ONNX format and uploads it to the same model path under an `onnx/` folder."
)

# Launch the app only when run as a script (not when imported as a module).
if __name__ == "__main__":
    iface.launch()