feat(run): Propagates model_name from the gateway request through the runtime and persistence stack to the SQLite database. (#2775)

* feat(run): propagate model_name from gateway request context to persistence layer

Pass model_name through the full run creation pipeline — from
RunCreateRequest.context in the gateway, through RunManager, to the
RunStore interface and SQL persistence. This enables client-specified
model selection to be recorded per-run in the database.

* feat(run): add model allowlist validation and effective model name capture

- Validate model_name against allowlist in gateway services.py using
  get_app_config().get_model_config()
- Truncate model_name to 128 chars to match DB column constraint
- In worker.py, capture effective model name from agent.metadata after
  agent creation and persist if resolved differently than requested

* feat(run): add defense-in-depth model_name normalization and round-trip persistence tests

- Add _normalize_model_name() to RunRepository for whitespace stripping
  and 128-char truncation before DB writes.
- Add round-trip unit tests for model_name creation and default None
  in test_run_manager.py.

* fix(run): coerce non-string model_name values before strip/truncate in _normalize_model_name

* fix(gateway): add runtime type guard for model_name coercion in gateway services

Add isinstance check and str() coercion before calling .strip() to prevent
AttributeError when non-string types (int, None, etc.) flow through the
gateway. Paired with SQL integration test for end-to-end model_name
persistence across gateway → langgraph → persistence layer.

* fix(run): drop Alembic migration for model_name (no-op) and expose public update method on RunManager

- Drop a1b2c3d4e5f6 migration: model_name already exists in RunRow schema
  and is auto-created via Base.metadata.create_all() at startup
- Add update_model_name() public method to RunManager to replace the private
  _persist_to_store call in worker.py, preserving internal locking/persistence
This commit is contained in:
Yi Tang
2026-05-11 14:45:18 +01:00
committed by GitHub
parent 2eb11f97ab
commit de253e4a0a
8 changed files with 143 additions and 0 deletions
@@ -23,6 +23,18 @@ class RunRepository(RunStore):
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
self._sf = session_factory
@staticmethod
def _normalize_model_name(model_name: str | None) -> str | None:
"""Normalize model_name for storage: strip whitespace, truncate to 128 chars."""
if model_name is None:
return None
if not isinstance(model_name, str):
model_name = str(model_name)
normalized = model_name.strip()
if len(normalized) > 128:
normalized = normalized[:128]
return normalized
@staticmethod
def _safe_json(obj: Any) -> Any:
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
@@ -70,6 +82,7 @@ class RunRepository(RunStore):
thread_id,
assistant_id=None,
user_id: str | None | _AutoSentinel = AUTO,
model_name: str | None = None,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -85,6 +98,7 @@ class RunRepository(RunStore):
thread_id=thread_id,
assistant_id=assistant_id,
user_id=resolved_user_id,
model_name=self._normalize_model_name(model_name),
status=status,
multitask_strategy=multitask_strategy,
metadata_json=self._safe_json(metadata) or {},
@@ -36,6 +36,7 @@ class RunRecord:
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
abort_action: str = "interrupt"
error: str | None = None
model_name: str | None = None
class RunManager:
@@ -65,6 +66,7 @@ class RunManager:
metadata=record.metadata or {},
kwargs=record.kwargs or {},
created_at=record.created_at,
model_name=record.model_name,
)
except Exception:
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
@@ -137,6 +139,18 @@ class RunManager:
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
logger.info("Run %s -> %s", run_id, status.value)
async def update_model_name(self, run_id: str, model_name: str | None) -> None:
"""Update the model name for a run."""
async with self._lock:
record = self._runs.get(run_id)
if record is None:
logger.warning("update_model_name called for unknown run %s", run_id)
return
record.model_name = model_name
record.updated_at = _now_iso()
await self._persist_to_store(record)
logger.info("Run %s model_name=%s", run_id, model_name)
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
"""Request cancellation of a run.
@@ -171,6 +185,7 @@ class RunManager:
metadata: dict | None = None,
kwargs: dict | None = None,
multitask_strategy: str = "reject",
model_name: str | None = None,
) -> RunRecord:
"""Atomically check for inflight runs and create a new one.
@@ -221,6 +236,7 @@ class RunManager:
kwargs=kwargs or {},
created_at=now,
updated_at=now,
model_name=model_name,
)
self._runs[run_id] = record
@@ -23,6 +23,7 @@ class RunStore(abc.ABC):
thread_id: str,
assistant_id: str | None = None,
user_id: str | None = None,
model_name: str | None = None,
status: str = "pending",
multitask_strategy: str = "reject",
metadata: dict[str, Any] | None = None,
@@ -22,6 +22,7 @@ class MemoryRunStore(RunStore):
thread_id,
assistant_id=None,
user_id=None,
model_name=None,
status="pending",
multitask_strategy="reject",
metadata=None,
@@ -35,6 +36,7 @@ class MemoryRunStore(RunStore):
"thread_id": thread_id,
"assistant_id": assistant_id,
"user_id": user_id,
"model_name": model_name,
"status": status,
"multitask_strategy": multitask_strategy,
"metadata": metadata or {},
@@ -230,6 +230,17 @@ async def run_agent(
else:
agent = agent_factory(config=runnable_config)
# Capture the effective (resolved) model name from the agent's metadata.
# _resolve_model_name in agent.py may return the default model if the
# requested name is not in the allowlist — this update ensures the
# persisted model_name reflects the actual model used.
if record.model_name is not None:
resolved = getattr(agent, "metadata", {}) or {}
if isinstance(resolved, dict):
effective = resolved.get("model_name")
if effective and effective != record.model_name:
await run_manager.update_model_name(record.run_id, effective)
# 4. Attach checkpointer and store
if checkpointer is not None:
agent.checkpointer = checkpointer