Pierre Chapuis commited on
Commit
6ca913f
·
unverified ·
1 Parent(s): 61a384d

use API client

Browse files
Files changed (5) hide show
  1. pyproject.toml +1 -3
  2. requirements.lock +12 -10
  3. requirements.txt +1 -2
  4. src/app.py +1 -2
  5. src/fg.py +0 -265
pyproject.toml CHANGED
@@ -9,11 +9,10 @@ dependencies = [
9
  "gradio>=4.41.0,<5", # gradio-imageslider requires <5
10
  "environs>=11.0.0",
11
  "gradio-image-annotation>=0.2.5",
12
- "httpx>=0.27.0",
13
  "pillow>=10.4.0",
14
  "gradio-imageslider>=0.0.20",
15
  "pillow-heif>=0.18.0",
16
- "httpx-sse>=0.4.0",
17
  ]
18
  readme = "README.md"
19
  requires-python = ">= 3.12, <3.13"
@@ -51,4 +50,3 @@ select = [
51
  [tool.pyright]
52
  include = ["src"]
53
  exclude = ["**/__pycache__"]
54
- strict = ["src/fg.py"]
 
9
  "gradio>=4.41.0,<5", # gradio-imageslider requires <5
10
  "environs>=11.0.0",
11
  "gradio-image-annotation>=0.2.5",
 
12
  "pillow>=10.4.0",
13
  "gradio-imageslider>=0.0.20",
14
  "pillow-heif>=0.18.0",
15
+ "finegrain @ git+ssh://[email protected]/finegrain-ai/finegrain-python@08e74457ec3390c1401609931941c8f80284efd6#subdirectory=finegrain",
16
  ]
17
  readme = "README.md"
18
  requires-python = ">= 3.12, <3.13"
 
50
  [tool.pyright]
51
  include = ["src"]
52
  exclude = ["**/__pycache__"]
 
requirements.lock CHANGED
@@ -33,13 +33,15 @@ cycler==0.12.1
33
  # via matplotlib
34
  environs==14.1.0
35
  # via eraser
36
- fastapi==0.115.6
37
  # via gradio
38
  ffmpy==0.5.0
39
  # via gradio
40
- filelock==3.16.1
41
  # via huggingface-hub
42
- fonttools==4.55.3
 
 
43
  # via matplotlib
44
  fsspec==2024.12.0
45
  # via gradio-client
@@ -60,11 +62,11 @@ h11==0.14.0
60
  httpcore==1.0.7
61
  # via httpx
62
  httpx==0.28.1
63
- # via eraser
64
  # via gradio
65
  # via gradio-client
66
  httpx-sse==0.4.0
67
- # via eraser
68
  huggingface-hub==0.27.1
69
  # via gradio
70
  # via gradio-client
@@ -83,7 +85,7 @@ markdown-it-py==3.0.0
83
  markupsafe==2.1.5
84
  # via gradio
85
  # via jinja2
86
- marshmallow==3.25.1
87
  # via environs
88
  matplotlib==3.10.0
89
  # via gradio
@@ -112,7 +114,7 @@ pillow==10.4.0
112
  # via pillow-heif
113
  pillow-heif==0.21.0
114
  # via eraser
115
- pydantic==2.10.5
116
  # via fastapi
117
  # via gradio
118
  pydantic-core==2.27.2
@@ -139,7 +141,7 @@ requests==2.32.3
139
  # via huggingface-hub
140
  rich==13.9.4
141
  # via typer
142
- ruff==0.9.2
143
  # via gradio
144
  semantic-version==2.10.0
145
  # via gradio
@@ -149,7 +151,7 @@ six==1.17.0
149
  # via python-dateutil
150
  sniffio==1.3.1
151
  # via anyio
152
- starlette==0.41.3
153
  # via fastapi
154
  tomlkit==0.12.0
155
  # via gradio
@@ -166,7 +168,7 @@ typing-extensions==4.12.2
166
  # via pydantic
167
  # via pydantic-core
168
  # via typer
169
- tzdata==2024.2
170
  # via pandas
171
  urllib3==2.3.0
172
  # via gradio
 
33
  # via matplotlib
34
  environs==14.1.0
35
  # via eraser
36
+ fastapi==0.115.7
37
  # via gradio
38
  ffmpy==0.5.0
39
  # via gradio
40
+ filelock==3.17.0
41
  # via huggingface-hub
42
+ finegrain @ git+ssh://git@github.com/finegrain-ai/finegrain-python@08e74457ec3390c1401609931941c8f80284efd6#subdirectory=finegrain
43
+ # via eraser
44
+ fonttools==4.55.5
45
  # via matplotlib
46
  fsspec==2024.12.0
47
  # via gradio-client
 
62
  httpcore==1.0.7
63
  # via httpx
64
  httpx==0.28.1
65
+ # via finegrain
66
  # via gradio
67
  # via gradio-client
68
  httpx-sse==0.4.0
69
+ # via finegrain
70
  huggingface-hub==0.27.1
71
  # via gradio
72
  # via gradio-client
 
85
  markupsafe==2.1.5
86
  # via gradio
87
  # via jinja2
88
+ marshmallow==3.26.0
89
  # via environs
90
  matplotlib==3.10.0
91
  # via gradio
 
114
  # via pillow-heif
115
  pillow-heif==0.21.0
116
  # via eraser
117
+ pydantic==2.10.6
118
  # via fastapi
119
  # via gradio
120
  pydantic-core==2.27.2
 
141
  # via huggingface-hub
142
  rich==13.9.4
143
  # via typer
144
+ ruff==0.9.3
145
  # via gradio
146
  semantic-version==2.10.0
147
  # via gradio
 
151
  # via python-dateutil
152
  sniffio==1.3.1
153
  # via anyio
154
+ starlette==0.45.3
155
  # via fastapi
156
  tomlkit==0.12.0
157
  # via gradio
 
168
  # via pydantic
169
  # via pydantic-core
170
  # via typer
171
+ tzdata==2025.1
172
  # via pandas
173
  urllib3==2.3.0
174
  # via gradio
requirements.txt CHANGED
@@ -1,7 +1,6 @@
1
  gradio_image_annotation>=0.2.5
2
  gradio_imageslider>=0.0.20
3
  environs>=11.0.0
4
- httpx>=0.27.0
5
- httpx-sse>=0.4.0
6
  pillow>=10.4.0
7
  pillow-heif>=0.18.0
 
 
1
  gradio_image_annotation>=0.2.5
2
  gradio_imageslider>=0.0.20
3
  environs>=11.0.0
 
 
4
  pillow>=10.4.0
5
  pillow-heif>=0.18.0
6
+ git+https://github.com/finegrain-ai/finegrain-python@08e74457ec3390c1401609931941c8f80284efd6#subdirectory=finegrain
src/app.py CHANGED
@@ -6,12 +6,11 @@ from typing import Any
6
  import gradio as gr
7
  import pillow_heif
8
  from environs import Env
 
9
  from gradio_image_annotation import image_annotator
10
  from gradio_imageslider import ImageSlider
11
  from PIL import Image
12
 
13
- from fg import EditorAPIContext
14
-
15
  pillow_heif.register_heif_opener()
16
  pillow_heif.register_avif_opener()
17
 
 
6
  import gradio as gr
7
  import pillow_heif
8
  from environs import Env
9
+ from finegrain import EditorAPIContext
10
  from gradio_image_annotation import image_annotator
11
  from gradio_imageslider import ImageSlider
12
  from PIL import Image
13
 
 
 
14
  pillow_heif.register_heif_opener()
15
  pillow_heif.register_avif_opener()
16
 
src/fg.py DELETED
@@ -1,265 +0,0 @@
1
- import asyncio
2
- import dataclasses as dc
3
- import json
4
- import logging
5
- from collections import defaultdict
6
- from collections.abc import Awaitable, Callable, Mapping
7
- from typing import Any, Literal, cast
8
-
9
- import httpx
10
- import httpx_sse
11
- from httpx._types import QueryParamTypes, RequestFiles
12
-
13
- logger = logging.getLogger(__name__)
14
-
15
- Priority = Literal["low", "standard", "high"]
16
-
17
-
18
- class SSELoopStopped(RuntimeError):
19
- pass
20
-
21
-
22
- class Futures[T]:
23
- @classmethod
24
- def create_future(cls) -> asyncio.Future[T]:
25
- return asyncio.get_running_loop().create_future()
26
-
27
- def __init__(self, capacity: int = 256) -> None:
28
- self.futures = defaultdict[str, asyncio.Future[T]](self.create_future)
29
- self.capacity = capacity
30
-
31
- def cull(self) -> None:
32
- while len(self.futures) >= self.capacity:
33
- del self.futures[next(iter(self.futures))]
34
-
35
- def __getitem__(self, key: str) -> asyncio.Future[T]:
36
- self.cull()
37
- return self.futures[key]
38
-
39
- def __delitem__(self, key: str) -> None:
40
- try:
41
- del self.futures[key]
42
- except KeyError:
43
- pass
44
-
45
-
46
- @dc.dataclass(kw_only=True)
47
- class EditorAPIContext:
48
- uri: str
49
- user: str
50
- password: str
51
- priority: Priority = "standard"
52
- token: str | None = None
53
- verify: bool | str = True
54
- default_timeout: float = 60.0
55
- logger: logging.Logger = logger
56
- max_sse_failures: int = 5
57
-
58
- _client: httpx.AsyncClient | None = None
59
- _client_ctx_depth: int = 0
60
- _sse_futures: Futures[dict[str, Any]] = dc.field(default_factory=Futures)
61
- _sse_task: asyncio.Task[None] | None = None
62
- _sse_failures: int = 0
63
- _sse_last_event_id: str = ""
64
- _sse_retry_ms: int = 0
65
-
66
- async def __aenter__(self) -> httpx.AsyncClient:
67
- if self._client:
68
- assert self._client_ctx_depth > 0
69
- self._client_ctx_depth += 1
70
- return self._client
71
- assert self._client_ctx_depth == 0
72
- self._client = httpx.AsyncClient(verify=self.verify)
73
- self._client_ctx_depth = 1
74
- return self._client
75
-
76
- async def __aexit__(self, *args: Any) -> None:
77
- if (not self._client) or self._client_ctx_depth <= 0:
78
- raise RuntimeError("unbalanced __aexit__")
79
- self._client_ctx_depth -= 1
80
- if self._client_ctx_depth == 0:
81
- await self._client.__aexit__(*args)
82
- self._client = None
83
-
84
- @property
85
- def auth_headers(self) -> dict[str, str]:
86
- assert self.token
87
- return {"Authorization": f"Bearer {self.token}"}
88
-
89
- async def login(self) -> None:
90
- async with self as client:
91
- response = await client.post(
92
- f"{self.uri}/auth/login",
93
- json={"username": self.user, "password": self.password},
94
- )
95
- response.raise_for_status()
96
- self.logger.debug(f"logged in as {self.user}")
97
- self.token = response.json()["token"]
98
-
99
- async def request(
100
- self,
101
- method: Literal["GET", "POST"],
102
- url: str,
103
- files: RequestFiles | None = None,
104
- params: QueryParamTypes | None = None,
105
- json: dict[str, Any] | None = None,
106
- headers: Mapping[str, str] | None = None,
107
- raise_for_status: bool = True,
108
- ) -> httpx.Response:
109
- async def _q() -> httpx.Response:
110
- return await client.request(
111
- method,
112
- f"{self.uri}/{url}",
113
- headers=dict(headers or {}) | self.auth_headers,
114
- files=files,
115
- params=params,
116
- json=json,
117
- )
118
-
119
- async with self as client:
120
- r = await _q()
121
- if r.status_code == 401:
122
- self.logger.debug("renewing token")
123
- await self.login()
124
- r = await _q()
125
-
126
- if raise_for_status:
127
- r.raise_for_status()
128
- return r
129
-
130
- @classmethod
131
- def decode_json(cls, data: str) -> dict[str, Any] | None:
132
- try:
133
- r = json.loads(data)
134
- except json.JSONDecodeError:
135
- return None
136
- if type(r) is not dict:
137
- return None
138
- return cast(dict[str, Any], r)
139
-
140
- async def _sse_loop(self) -> None:
141
- response = await self.request("POST", "sub-auth")
142
- sub_token = response.json()["token"]
143
- url = f"{self.uri}/sub/{sub_token}"
144
- headers = {"Accept": "text/event-stream"}
145
- if self._sse_last_event_id:
146
- retry_ms = self._sse_retry_ms + 1000 * 2**self._sse_failures
147
- self.logger.info(f"resuming SSE from event {self._sse_last_event_id} in {retry_ms} ms")
148
- await asyncio.sleep(retry_ms / 1000)
149
- headers["Last-Event-ID"] = self._sse_last_event_id
150
- async with (
151
- httpx.AsyncClient(timeout=None, verify=self.verify) as c,
152
- httpx_sse.aconnect_sse(c, "GET", url, headers=headers) as es,
153
- ):
154
- es.response.raise_for_status()
155
- self._sse_futures["_sse_loop"].set_result({"status": "ok"})
156
- try:
157
- async for sse in es.aiter_sse():
158
- self._sse_last_event_id = sse.id
159
- self._sse_retry_ms = sse.retry or 0
160
- jdata = self.decode_json(sse.data)
161
- if (jdata is None) or ("state" not in jdata):
162
- # Note: when the server restarts we typically get an
163
- # empty string here, then the loop exits.
164
- self.logger.warning(f"unexpected SSE data: {sse.data}")
165
- continue
166
- self._sse_futures[jdata["state"]].set_result(jdata)
167
- except asyncio.CancelledError:
168
- pass
169
-
170
- async def sse_start(self) -> None:
171
- assert self._sse_task is None
172
- self._sse_last_event_id = ""
173
- self._sse_retry_ms = 0
174
- self._sse_task = asyncio.create_task(self._sse_loop())
175
- assert await self.sse_await("_sse_loop")
176
- self._sse_failures = 0
177
-
178
- async def sse_recover(self) -> bool:
179
- while True:
180
- if self._sse_failures > self.max_sse_failures:
181
- return False
182
- self._sse_task = asyncio.create_task(self._sse_loop())
183
- try:
184
- assert await self.sse_await("_sse_loop")
185
- return True
186
- except SSELoopStopped:
187
- pass
188
-
189
- async def sse_stop(self) -> None:
190
- assert self._sse_task
191
- self._sse_task.cancel()
192
- await self._sse_task
193
- self._sse_task = None
194
-
195
- async def sse_await(self, state_id: str, timeout: float | None = None) -> bool:
196
- assert self._sse_task
197
- future = self._sse_futures[state_id]
198
-
199
- while True:
200
- done, _ = await asyncio.wait(
201
- {future, self._sse_task},
202
- timeout=timeout or self.default_timeout,
203
- return_when=asyncio.FIRST_COMPLETED,
204
- )
205
- if not done:
206
- raise TimeoutError(f"state {state_id} timed out after {timeout}")
207
- if self._sse_task in done:
208
- self._sse_failures += 1
209
- if state_id != "_sse_loop" and (await self.sse_recover()):
210
- self._sse_failures = 0
211
- continue
212
- exception = self._sse_task.exception()
213
- raise SSELoopStopped(f"SSE loop stopped while waiting for state {state_id}") from exception
214
- break
215
-
216
- assert done == {future}
217
-
218
- jdata = future.result()
219
- del self._sse_futures[state_id]
220
- return jdata["status"] == "ok"
221
-
222
- async def get_meta(self, state_id: str) -> dict[str, Any]:
223
- response = await self.request("GET", f"state/meta/{state_id}")
224
- return response.json()
225
-
226
- async def _run_one[Tin, Tout](
227
- self,
228
- co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
229
- params: Tin,
230
- ) -> Tout:
231
- # This wraps the coroutine in the SSE loop.
232
- # This is mostly useful if you use synchronous Python,
233
- # otherwise you can call the functions directly.
234
- if not self.token:
235
- await self.login()
236
- await self.sse_start()
237
- try:
238
- r = await co(self, params)
239
- return r
240
- finally:
241
- await self.sse_stop()
242
-
243
- def run_one_sync[Tin, Tout](
244
- self,
245
- co: Callable[["EditorAPIContext", Tin], Awaitable[Tout]],
246
- params: Tin,
247
- ) -> Tout:
248
- try:
249
- loop = asyncio.get_event_loop()
250
- except RuntimeError:
251
- loop = asyncio.new_event_loop()
252
- asyncio.set_event_loop(loop)
253
- return loop.run_until_complete(self._run_one(co, params))
254
-
255
- async def call_skill(
256
- self,
257
- uri: str,
258
- params: dict[str, Any] | None,
259
- timeout: float | None = None,
260
- ) -> tuple[str, bool]:
261
- params = {"priority": self.priority} | (params or {})
262
- response = await self.request("POST", f"skills/{uri}", json=params)
263
- state_id = response.json()["state"]
264
- status = await self.sse_await(state_id, timeout=timeout)
265
- return state_id, status