feat(run): Propagates model_name from the gateway request through the runtime and persistence stack to the SQLite database. (#2775)
* feat(run): propagate model_name from gateway request context to persistence layer Pass model_name through the full run creation pipeline — from RunCreateRequest.context in the gateway, through RunManager, to the RunStore interface and SQL persistence. This enables client-specified model selection to be recorded per-run in the database. * feat(run): add model allowlist validation and effective model name capture - Validate model_name against allowlist in gateway services.py using get_app_config().get_model_config() - Truncate model_name to 128 chars to match DB column constraint - In worker.py, capture effective model name from agent.metadata after agent creation and persist if resolved differently than requested * feat(run): add defense-in-depth model_name normalization and round-trip persistence tests - Add _normalize_model_name() to RunRepository for whitespace stripping and 128-char truncation before DB writes. - Add round-trip unit tests for model_name creation and default None in test_run_manager.py. * fix(run): coerce non-string model_name values before strip/truncate in _normalize_model_name * fix(gateway): add runtime type guard for model_name coercion in gateway services Add isinstance check and str() coercion before calling .strip() to prevent AttributeError when non-string types (int, None, etc.) flow through the gateway. Paired with SQL integration test for end-to-end model_name persistence across gateway → langgraph → persistence layer. * fix(run): drop Alembic migration for model_name (no-op) and expose public update method on RunManager - Drop a1b2c3d4e5f6 migration: model_name already exists in RunRow schema and is auto-created via Base.metadata.create_all() at startup - Add update_model_name() public method to RunManager to replace the private _persist_to_store call in worker.py, preserving internal locking/persistence
This commit is contained in:
@@ -23,6 +23,18 @@ class RunRepository(RunStore):
|
||||
def __init__(self, session_factory: async_sessionmaker[AsyncSession]) -> None:
|
||||
self._sf = session_factory
|
||||
|
||||
@staticmethod
|
||||
def _normalize_model_name(model_name: str | None) -> str | None:
|
||||
"""Normalize model_name for storage: strip whitespace, truncate to 128 chars."""
|
||||
if model_name is None:
|
||||
return None
|
||||
if not isinstance(model_name, str):
|
||||
model_name = str(model_name)
|
||||
normalized = model_name.strip()
|
||||
if len(normalized) > 128:
|
||||
normalized = normalized[:128]
|
||||
return normalized
|
||||
|
||||
@staticmethod
|
||||
def _safe_json(obj: Any) -> Any:
|
||||
"""Ensure obj is JSON-serializable. Falls back to model_dump() or str()."""
|
||||
@@ -70,6 +82,7 @@ class RunRepository(RunStore):
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
user_id: str | None | _AutoSentinel = AUTO,
|
||||
model_name: str | None = None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -85,6 +98,7 @@ class RunRepository(RunStore):
|
||||
thread_id=thread_id,
|
||||
assistant_id=assistant_id,
|
||||
user_id=resolved_user_id,
|
||||
model_name=self._normalize_model_name(model_name),
|
||||
status=status,
|
||||
multitask_strategy=multitask_strategy,
|
||||
metadata_json=self._safe_json(metadata) or {},
|
||||
|
||||
@@ -36,6 +36,7 @@ class RunRecord:
|
||||
abort_event: asyncio.Event = field(default_factory=asyncio.Event, repr=False)
|
||||
abort_action: str = "interrupt"
|
||||
error: str | None = None
|
||||
model_name: str | None = None
|
||||
|
||||
|
||||
class RunManager:
|
||||
@@ -65,6 +66,7 @@ class RunManager:
|
||||
metadata=record.metadata or {},
|
||||
kwargs=record.kwargs or {},
|
||||
created_at=record.created_at,
|
||||
model_name=record.model_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning("Failed to persist run %s to store", record.run_id, exc_info=True)
|
||||
@@ -137,6 +139,18 @@ class RunManager:
|
||||
logger.warning("Failed to persist status update for run %s", run_id, exc_info=True)
|
||||
logger.info("Run %s -> %s", run_id, status.value)
|
||||
|
||||
async def update_model_name(self, run_id: str, model_name: str | None) -> None:
|
||||
"""Update the model name for a run."""
|
||||
async with self._lock:
|
||||
record = self._runs.get(run_id)
|
||||
if record is None:
|
||||
logger.warning("update_model_name called for unknown run %s", run_id)
|
||||
return
|
||||
record.model_name = model_name
|
||||
record.updated_at = _now_iso()
|
||||
await self._persist_to_store(record)
|
||||
logger.info("Run %s model_name=%s", run_id, model_name)
|
||||
|
||||
async def cancel(self, run_id: str, *, action: str = "interrupt") -> bool:
|
||||
"""Request cancellation of a run.
|
||||
|
||||
@@ -171,6 +185,7 @@ class RunManager:
|
||||
metadata: dict | None = None,
|
||||
kwargs: dict | None = None,
|
||||
multitask_strategy: str = "reject",
|
||||
model_name: str | None = None,
|
||||
) -> RunRecord:
|
||||
"""Atomically check for inflight runs and create a new one.
|
||||
|
||||
@@ -221,6 +236,7 @@ class RunManager:
|
||||
kwargs=kwargs or {},
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
model_name=model_name,
|
||||
)
|
||||
self._runs[run_id] = record
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ class RunStore(abc.ABC):
|
||||
thread_id: str,
|
||||
assistant_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
model_name: str | None = None,
|
||||
status: str = "pending",
|
||||
multitask_strategy: str = "reject",
|
||||
metadata: dict[str, Any] | None = None,
|
||||
|
||||
@@ -22,6 +22,7 @@ class MemoryRunStore(RunStore):
|
||||
thread_id,
|
||||
assistant_id=None,
|
||||
user_id=None,
|
||||
model_name=None,
|
||||
status="pending",
|
||||
multitask_strategy="reject",
|
||||
metadata=None,
|
||||
@@ -35,6 +36,7 @@ class MemoryRunStore(RunStore):
|
||||
"thread_id": thread_id,
|
||||
"assistant_id": assistant_id,
|
||||
"user_id": user_id,
|
||||
"model_name": model_name,
|
||||
"status": status,
|
||||
"multitask_strategy": multitask_strategy,
|
||||
"metadata": metadata or {},
|
||||
|
||||
@@ -230,6 +230,17 @@ async def run_agent(
|
||||
else:
|
||||
agent = agent_factory(config=runnable_config)
|
||||
|
||||
# Capture the effective (resolved) model name from the agent's metadata.
|
||||
# _resolve_model_name in agent.py may return the default model if the
|
||||
# requested name is not in the allowlist — this update ensures the
|
||||
# persisted model_name reflects the actual model used.
|
||||
if record.model_name is not None:
|
||||
resolved = getattr(agent, "metadata", {}) or {}
|
||||
if isinstance(resolved, dict):
|
||||
effective = resolved.get("model_name")
|
||||
if effective and effective != record.model_name:
|
||||
await run_manager.update_model_name(record.run_id, effective)
|
||||
|
||||
# 4. Attach checkpointer and store
|
||||
if checkpointer is not None:
|
||||
agent.checkpointer = checkpointer
|
||||
|
||||
Reference in New Issue
Block a user