Pierre Chapuis
commited on
better errors
Browse files- src/app.py +12 -3
- src/fg.py +7 -7
src/app.py
CHANGED
@@ -50,6 +50,14 @@ class ProcessParams:
|
|
50 |
bbox: tuple[int, int, int, int] | None = None
|
51 |
|
52 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
53 |
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
54 |
with io.BytesIO() as f:
|
55 |
params.image.save(f, format="JPEG")
|
@@ -61,14 +69,15 @@ async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
|
61 |
segment_params = {"bbox": list(params.bbox)}
|
62 |
else:
|
63 |
assert params.prompt
|
64 |
-
segment_input_st = await
|
|
|
65 |
f"infer-bbox/{st_input}",
|
66 |
{"product_name": params.prompt},
|
67 |
)
|
68 |
segment_params = {}
|
69 |
|
70 |
-
st_mask = await ctx
|
71 |
-
st_erased = await ctx
|
72 |
|
73 |
response = await ctx.request(
|
74 |
"GET",
|
|
|
50 |
bbox: tuple[int, int, int, int] | None = None
|
51 |
|
52 |
|
53 |
+
async def call_or_raise(ctx: EditorAPIContext, uri: str, params: dict[str, Any]) -> str:
|
54 |
+
st, ok = await ctx.call_skill(uri, params)
|
55 |
+
if ok:
|
56 |
+
return st
|
57 |
+
meta = await ctx.get_meta(st)
|
58 |
+
raise RuntimeError(f"skill {uri} failed with {st}: {meta}")
|
59 |
+
|
60 |
+
|
61 |
async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
|
62 |
with io.BytesIO() as f:
|
63 |
params.image.save(f, format="JPEG")
|
|
|
69 |
segment_params = {"bbox": list(params.bbox)}
|
70 |
else:
|
71 |
assert params.prompt
|
72 |
+
segment_input_st = await call_or_raise(
|
73 |
+
ctx,
|
74 |
f"infer-bbox/{st_input}",
|
75 |
{"product_name": params.prompt},
|
76 |
)
|
77 |
segment_params = {}
|
78 |
|
79 |
+
st_mask = await call_or_raise(ctx, f"segment/{segment_input_st}", segment_params)
|
80 |
+
st_erased = await call_or_raise(ctx, f"erase/{st_input}/{st_mask}", {"mode": "free"})
|
81 |
|
82 |
response = await ctx.request(
|
83 |
"GET",
|
src/fg.py
CHANGED
@@ -172,7 +172,7 @@ class EditorAPIContext:
|
|
172 |
self._sse_last_event_id = ""
|
173 |
self._sse_retry_ms = 0
|
174 |
self._sse_task = asyncio.create_task(self._sse_loop())
|
175 |
-
await self.sse_await("_sse_loop")
|
176 |
self._sse_failures = 0
|
177 |
|
178 |
async def sse_recover(self) -> bool:
|
@@ -181,7 +181,7 @@ class EditorAPIContext:
|
|
181 |
return False
|
182 |
self._sse_task = asyncio.create_task(self._sse_loop())
|
183 |
try:
|
184 |
-
await self.sse_await("_sse_loop")
|
185 |
return True
|
186 |
except SSELoopStopped:
|
187 |
pass
|
@@ -192,7 +192,7 @@ class EditorAPIContext:
|
|
192 |
await self._sse_task
|
193 |
self._sse_task = None
|
194 |
|
195 |
-
async def sse_await(self, state_id: str, timeout: float | None = None) ->
|
196 |
assert self._sse_task
|
197 |
future = self._sse_futures[state_id]
|
198 |
|
@@ -217,7 +217,7 @@ class EditorAPIContext:
|
|
217 |
|
218 |
jdata = future.result()
|
219 |
del self._sse_futures[state_id]
|
220 |
-
|
221 |
|
222 |
async def get_meta(self, state_id: str) -> dict[str, Any]:
|
223 |
response = await self.request("GET", f"state/meta/{state_id}")
|
@@ -257,9 +257,9 @@ class EditorAPIContext:
|
|
257 |
uri: str,
|
258 |
params: dict[str, Any] | None,
|
259 |
timeout: float | None = None,
|
260 |
-
) -> str:
|
261 |
params = {"priority": self.priority} | (params or {})
|
262 |
response = await self.request("POST", f"skills/{uri}", json=params)
|
263 |
state_id = response.json()["state"]
|
264 |
-
await self.sse_await(state_id, timeout=timeout)
|
265 |
-
return state_id
|
|
|
172 |
self._sse_last_event_id = ""
|
173 |
self._sse_retry_ms = 0
|
174 |
self._sse_task = asyncio.create_task(self._sse_loop())
|
175 |
+
assert await self.sse_await("_sse_loop")
|
176 |
self._sse_failures = 0
|
177 |
|
178 |
async def sse_recover(self) -> bool:
|
|
|
181 |
return False
|
182 |
self._sse_task = asyncio.create_task(self._sse_loop())
|
183 |
try:
|
184 |
+
assert await self.sse_await("_sse_loop")
|
185 |
return True
|
186 |
except SSELoopStopped:
|
187 |
pass
|
|
|
192 |
await self._sse_task
|
193 |
self._sse_task = None
|
194 |
|
195 |
+
async def sse_await(self, state_id: str, timeout: float | None = None) -> bool:
|
196 |
assert self._sse_task
|
197 |
future = self._sse_futures[state_id]
|
198 |
|
|
|
217 |
|
218 |
jdata = future.result()
|
219 |
del self._sse_futures[state_id]
|
220 |
+
return jdata["status"] == "ok"
|
221 |
|
222 |
async def get_meta(self, state_id: str) -> dict[str, Any]:
|
223 |
response = await self.request("GET", f"state/meta/{state_id}")
|
|
|
257 |
uri: str,
|
258 |
params: dict[str, Any] | None,
|
259 |
timeout: float | None = None,
|
260 |
+
) -> tuple[str, bool]:
|
261 |
params = {"priority": self.priority} | (params or {})
|
262 |
response = await self.request("POST", f"skills/{uri}", json=params)
|
263 |
state_id = response.json()["state"]
|
264 |
+
status = await self.sse_await(state_id, timeout=timeout)
|
265 |
+
return state_id, status
|