feat(client): support custom middleware injection (#1520)
* feat(client): support custom middleware injection Add support for custom middleware, allowing custom middleware list to be passed when initializing DeerFlowClient. These middleware will be injected after the default middleware when creating the agent, extending the agent's functionality. * feat: inject custom middlewares before ClarificationMiddleware to preserve ordering - Add `custom_middlewares` param to `_build_middlewares` - Inject custom middlewares right before `ClarificationMiddleware` to keep it as the last in the chain - Remove unsafe `.extend()` in `client.py` - Update tests in `test_client.py` and `test_lead_agent_model_resolution.py` to assert correct injection ordering
This commit is contained in:
@@ -63,13 +63,22 @@ class TestClientInit:
|
||||
assert client._agent is None
|
||||
|
||||
def test_custom_params(self, mock_app_config):
|
||||
mock_middleware = MagicMock()
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
c = DeerFlowClient(model_name="gpt-4", thinking_enabled=False, subagent_enabled=True, plan_mode=True, agent_name="test-agent")
|
||||
c = DeerFlowClient(
|
||||
model_name="gpt-4",
|
||||
thinking_enabled=False,
|
||||
subagent_enabled=True,
|
||||
plan_mode=True,
|
||||
agent_name="test-agent",
|
||||
middlewares=[mock_middleware]
|
||||
)
|
||||
assert c._model_name == "gpt-4"
|
||||
assert c._thinking_enabled is False
|
||||
assert c._subagent_enabled is True
|
||||
assert c._plan_mode is True
|
||||
assert c._agent_name == "test-agent"
|
||||
assert c._middlewares == [mock_middleware]
|
||||
|
||||
def test_invalid_agent_name(self, mock_app_config):
|
||||
with patch("deerflow.client.get_app_config", return_value=mock_app_config):
|
||||
@@ -413,6 +422,33 @@ class TestEnsureAgent:
|
||||
|
||||
assert mock_create_agent.call_args.kwargs["checkpointer"] is mock_checkpointer
|
||||
|
||||
def test_injects_custom_middlewares(self, client):
|
||||
mock_agent = MagicMock()
|
||||
mock_custom_middleware = MagicMock()
|
||||
client._middlewares = [mock_custom_middleware]
|
||||
config = client._get_runnable_config("t1")
|
||||
|
||||
mock_clarification = MagicMock()
|
||||
mock_clarification.__class__.__name__ = "ClarificationMiddleware"
|
||||
|
||||
def fake_build_middlewares(*args, **kwargs):
|
||||
custom = kwargs.get("custom_middlewares") or []
|
||||
return [MagicMock()] + custom + [mock_clarification]
|
||||
|
||||
with (
|
||||
patch("deerflow.client.create_chat_model"),
|
||||
patch("deerflow.client.create_agent", return_value=mock_agent) as mock_create_agent,
|
||||
patch("deerflow.client._build_middlewares", side_effect=fake_build_middlewares),
|
||||
patch("deerflow.client.apply_prompt_template", return_value="prompt"),
|
||||
patch.object(client, "_get_tools", return_value=[]),
|
||||
):
|
||||
client._ensure_agent(config)
|
||||
|
||||
called_middlewares = mock_create_agent.call_args.kwargs["middleware"]
|
||||
assert len(called_middlewares) == 3
|
||||
assert called_middlewares[-2] is mock_custom_middleware
|
||||
assert called_middlewares[-1] is mock_clarification
|
||||
|
||||
def test_skips_default_checkpointer_when_unconfigured(self, client):
|
||||
mock_agent = MagicMock()
|
||||
config = client._get_runnable_config("t1")
|
||||
|
||||
Reference in New Issue
Block a user