Oysiyl commited on
Commit
75f5f49
·
1 Parent(s): ee395c3

fix modal analytics source attribution

Browse files
modal_service.py CHANGED
@@ -99,7 +99,7 @@ image = (
99
  )
100
  @modal.asgi_app()
101
  def api():
102
- from fastapi import FastAPI, HTTPException
103
  from fastapi.concurrency import run_in_threadpool
104
 
105
  state: dict[str, Any] = {
@@ -148,20 +148,24 @@ def api():
148
  }
149
 
150
  @web_app.post("/generate")
151
- async def generate(request: GenerateRequest) -> dict[str, Any]:
 
 
152
  if not state["ready"] or state["backend"] is None:
153
  raise HTTPException(
154
  status_code=503, detail=state["import_error"] or "Backend not ready"
155
  )
156
 
157
- actual_seed = resolve_request_seed(request)
158
- prepared_request = request.model_copy(
159
  update={"seed": actual_seed, "use_custom_seed": True}
160
  )
161
 
162
  def _run_generation() -> tuple[Any, str, dict[str, Any] | None]:
163
  backend = state["backend"]
164
- kwargs = build_generation_kwargs(prepared_request)
 
 
165
  if prepared_request.mode == "artistic":
166
  generator = backend.generate_artistic_qr(**kwargs)
167
  else:
@@ -178,7 +182,7 @@ def api():
178
  return build_response_payload(
179
  image_obj,
180
  final_status,
181
- request,
182
  actual_seed=actual_seed,
183
  elapsed=elapsed,
184
  settings=settings,
 
99
  )
100
  @modal.asgi_app()
101
  def api():
102
+ from fastapi import FastAPI, HTTPException, Request
103
  from fastapi.concurrency import run_in_threadpool
104
 
105
  state: dict[str, Any] = {
 
148
  }
149
 
150
  @web_app.post("/generate")
151
+ async def generate(
152
+ payload: GenerateRequest, raw_request: Request
153
+ ) -> dict[str, Any]:
154
  if not state["ready"] or state["backend"] is None:
155
  raise HTTPException(
156
  status_code=503, detail=state["import_error"] or "Backend not ready"
157
  )
158
 
159
+ actual_seed = resolve_request_seed(payload)
160
+ prepared_request = payload.model_copy(
161
  update={"seed": actual_seed, "use_custom_seed": True}
162
  )
163
 
164
  def _run_generation() -> tuple[Any, str, dict[str, Any] | None]:
165
  backend = state["backend"]
166
+ kwargs = build_generation_kwargs(
167
+ prepared_request, runtime_request=raw_request
168
+ )
169
  if prepared_request.mode == "artistic":
170
  generator = backend.generate_artistic_qr(**kwargs)
171
  else:
 
182
  return build_response_payload(
183
  image_obj,
184
  final_status,
185
+ payload,
186
  actual_seed=actual_seed,
187
  elapsed=elapsed,
188
  settings=settings,
qr_modal_contract.py CHANGED
@@ -66,7 +66,9 @@ class GenerateRequest(BaseModel):
66
  tile_pyrup_iters: int = Field(default=3, ge=1, le=4)
67
 
68
 
69
- def build_generation_kwargs(request: GenerateRequest) -> dict[str, Any]:
 
 
70
  common = {
71
  "prompt": request.prompt,
72
  "negative_prompt": request.negative_prompt,
@@ -93,7 +95,7 @@ def build_generation_kwargs(request: GenerateRequest) -> dict[str, Any]:
93
  "gradient_strength": request.gradient_strength,
94
  "variation_steps": request.variation_steps,
95
  "progress": None,
96
- "request": None,
97
  }
98
 
99
  if request.mode == "artistic":
 
66
  tile_pyrup_iters: int = Field(default=3, ge=1, le=4)
67
 
68
 
69
+ def build_generation_kwargs(
70
+ request: GenerateRequest, runtime_request: Any | None = None
71
+ ) -> dict[str, Any]:
72
  common = {
73
  "prompt": request.prompt,
74
  "negative_prompt": request.negative_prompt,
 
95
  "gradient_strength": request.gradient_strength,
96
  "variation_steps": request.variation_steps,
97
  "progress": None,
98
+ "request": runtime_request,
99
  }
100
 
101
  if request.mode == "artistic":
tests-unit/modal_service_test/test_qr_modal_contract.py CHANGED
@@ -22,6 +22,19 @@ def test_build_generation_kwargs_includes_short_link_flag_and_wrapper_fields():
22
  assert kwargs["analytics_opt_in"] is False
23
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def test_build_generation_kwargs_for_standard_omits_artistic_only_fields():
26
  request = GenerateRequest(
27
  mode="standard",
 
22
  assert kwargs["analytics_opt_in"] is False
23
 
24
 
25
+ def test_build_generation_kwargs_passes_request_object_when_provided():
26
+ request = GenerateRequest(
27
+ mode="standard",
28
+ prompt="poster",
29
+ qr_text="https://example.com",
30
+ )
31
+ runtime_request = object()
32
+
33
+ kwargs = build_generation_kwargs(request, runtime_request=runtime_request)
34
+
35
+ assert kwargs["request"] is runtime_request
36
+
37
+
38
  def test_build_generation_kwargs_for_standard_omits_artistic_only_fields():
39
  request = GenerateRequest(
40
  mode="standard",