Pierre Chapuis commited on
Commit
61a384d
·
unverified ·
1 Parent(s): 7c9213f

better errors

Browse files
Files changed (2) hide show
  1. src/app.py +12 -3
  2. src/fg.py +7 -7
src/app.py CHANGED
@@ -50,6 +50,14 @@ class ProcessParams:
50
  bbox: tuple[int, int, int, int] | None = None
51
 
52
 
 
 
 
 
 
 
 
 
53
  async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
54
  with io.BytesIO() as f:
55
  params.image.save(f, format="JPEG")
@@ -61,14 +69,15 @@ async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
61
  segment_params = {"bbox": list(params.bbox)}
62
  else:
63
  assert params.prompt
64
- segment_input_st = await ctx.call_skill(
 
65
  f"infer-bbox/{st_input}",
66
  {"product_name": params.prompt},
67
  )
68
  segment_params = {}
69
 
70
- st_mask = await ctx.call_skill(f"segment/{segment_input_st}", segment_params)
71
- st_erased = await ctx.call_skill(f"erase/{st_input}/{st_mask}", {"mode": "free"})
72
 
73
  response = await ctx.request(
74
  "GET",
 
50
  bbox: tuple[int, int, int, int] | None = None
51
 
52
 
53
+ async def call_or_raise(ctx: EditorAPIContext, uri: str, params: dict[str, Any]) -> str:
54
+ st, ok = await ctx.call_skill(uri, params)
55
+ if ok:
56
+ return st
57
+ meta = await ctx.get_meta(st)
58
+ raise RuntimeError(f"skill {uri} failed with {st}: {meta}")
59
+
60
+
61
  async def _process(ctx: EditorAPIContext, params: ProcessParams) -> Image.Image:
62
  with io.BytesIO() as f:
63
  params.image.save(f, format="JPEG")
 
69
  segment_params = {"bbox": list(params.bbox)}
70
  else:
71
  assert params.prompt
72
+ segment_input_st = await call_or_raise(
73
+ ctx,
74
  f"infer-bbox/{st_input}",
75
  {"product_name": params.prompt},
76
  )
77
  segment_params = {}
78
 
79
+ st_mask = await call_or_raise(ctx, f"segment/{segment_input_st}", segment_params)
80
+ st_erased = await call_or_raise(ctx, f"erase/{st_input}/{st_mask}", {"mode": "free"})
81
 
82
  response = await ctx.request(
83
  "GET",
src/fg.py CHANGED
@@ -172,7 +172,7 @@ class EditorAPIContext:
172
  self._sse_last_event_id = ""
173
  self._sse_retry_ms = 0
174
  self._sse_task = asyncio.create_task(self._sse_loop())
175
- await self.sse_await("_sse_loop")
176
  self._sse_failures = 0
177
 
178
  async def sse_recover(self) -> bool:
@@ -181,7 +181,7 @@ class EditorAPIContext:
181
  return False
182
  self._sse_task = asyncio.create_task(self._sse_loop())
183
  try:
184
- await self.sse_await("_sse_loop")
185
  return True
186
  except SSELoopStopped:
187
  pass
@@ -192,7 +192,7 @@ class EditorAPIContext:
192
  await self._sse_task
193
  self._sse_task = None
194
 
195
- async def sse_await(self, state_id: str, timeout: float | None = None) -> None:
196
  assert self._sse_task
197
  future = self._sse_futures[state_id]
198
 
@@ -217,7 +217,7 @@ class EditorAPIContext:
217
 
218
  jdata = future.result()
219
  del self._sse_futures[state_id]
220
- assert jdata["status"] == "ok", f"state {state_id} is {jdata['status']}"
221
 
222
  async def get_meta(self, state_id: str) -> dict[str, Any]:
223
  response = await self.request("GET", f"state/meta/{state_id}")
@@ -257,9 +257,9 @@ class EditorAPIContext:
257
  uri: str,
258
  params: dict[str, Any] | None,
259
  timeout: float | None = None,
260
- ) -> str:
261
  params = {"priority": self.priority} | (params or {})
262
  response = await self.request("POST", f"skills/{uri}", json=params)
263
  state_id = response.json()["state"]
264
- await self.sse_await(state_id, timeout=timeout)
265
- return state_id
 
172
  self._sse_last_event_id = ""
173
  self._sse_retry_ms = 0
174
  self._sse_task = asyncio.create_task(self._sse_loop())
175
+ assert await self.sse_await("_sse_loop")
176
  self._sse_failures = 0
177
 
178
  async def sse_recover(self) -> bool:
 
181
  return False
182
  self._sse_task = asyncio.create_task(self._sse_loop())
183
  try:
184
+ assert await self.sse_await("_sse_loop")
185
  return True
186
  except SSELoopStopped:
187
  pass
 
192
  await self._sse_task
193
  self._sse_task = None
194
 
195
+ async def sse_await(self, state_id: str, timeout: float | None = None) -> bool:
196
  assert self._sse_task
197
  future = self._sse_futures[state_id]
198
 
 
217
 
218
  jdata = future.result()
219
  del self._sse_futures[state_id]
220
+ return jdata["status"] == "ok"
221
 
222
  async def get_meta(self, state_id: str) -> dict[str, Any]:
223
  response = await self.request("GET", f"state/meta/{state_id}")
 
257
  uri: str,
258
  params: dict[str, Any] | None,
259
  timeout: float | None = None,
260
+ ) -> tuple[str, bool]:
261
  params = {"priority": self.priority} | (params or {})
262
  response = await self.request("POST", f"skills/{uri}", json=params)
263
  state_id = response.json()["state"]
264
+ status = await self.sse_await(state_id, timeout=timeout)
265
+ return state_id, status