Compare commits

...

3 Commits

Author SHA1 Message Date
Xun 2ab2876580 fix: Plan.model_validate throws exception in auto_accepted_plan (#1111)
* fix: Plan.model_validate throws exception in auto_accepted_plan

* improve log

* add unit tests (UT)

* fix ci

* revert uv.lock

* add blank

* fix
2026-03-12 17:13:39 +08:00
Willem Jiang 172ba2d7ad Update branch name in workflow configuration (#869) 2026-02-16 21:32:24 +08:00
大猫子 423f5c829c fix: strip <think> tags from reporter output to prevent thinking text leakage (#781) (#862)
* fix: strip <think> tags from LLM output to prevent thinking text leakage (#781)

Some models (e.g. DeepSeek-R1, QwQ via ollama) embed reasoning in
content using <think>...</think> tags instead of the separate
reasoning_content field. This causes thinking text to leak into
both streamed messages and the final report.

Fix at two layers:
- server/app.py: strip <think> tags in _create_event_stream_message
  so ALL streamed content is filtered (coordinator, planner, etc.)
- graph/nodes.py: strip <think> tags in reporter_node before storing
  final_report (which is not streamed through the event layer)

The regex uses a fast-path check ("<think>" in content) to avoid
unnecessary regex calls on normal content.

* refactor: add defensive check for think tag stripping and add reporter_node tests (#781)

- Add isinstance and fast-path check in reporter_node before regex, consistent with app.py
- Add TestReporterNodeThinkTagStripping with 5 test cases covering various scenarios

* chore: re-trigger review
2026-02-16 09:38:17 +08:00
6 changed files with 321 additions and 14 deletions
+2 -2
View File
@@ -3,7 +3,7 @@ name: Publish Containers
on: on:
push: push:
branches: branches:
- main - main-1.x
release: release:
types: [published] types: [published]
workflow_dispatch: workflow_dispatch:
@@ -92,4 +92,4 @@ jobs:
with: with:
subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}}
subject-digest: ${{ steps.push.outputs.digest }} subject-digest: ${{ steps.push.outputs.digest }}
push-to-registry: true push-to-registry: true
+1
View File
@@ -1,4 +1,5 @@
{ {
"python.defaultInterpreterPath": "${workspaceFolder}/.venv/bin/python",
"python.testing.pytestArgs": [ "python.testing.pytestArgs": [
"tests" "tests"
], ],
+55 -9
View File
@@ -4,8 +4,10 @@
import json import json
import logging import logging
import os import os
import re
from functools import partial from functools import partial
from typing import Annotated, Any, Literal from typing import Annotated, Any, Literal
from pydantic import ValidationError
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.runnables import RunnableConfig from langchain_core.runnables import RunnableConfig
@@ -495,8 +497,9 @@ def human_feedback_node(
) )
# if the plan is accepted, run the following node # if the plan is accepted, run the following node
plan_iterations = state["plan_iterations"] if state.get("plan_iterations", 0) else 0 plan_iterations = (state.get("plan_iterations") or 0) + 1
goto = "research_team" goto = "research_team"
configurable = Configuration.from_runnable_config(config)
try: try:
# Safely extract plan content from different types (string, AIMessage, dict) # Safely extract plan content from different types (string, AIMessage, dict)
original_plan = current_plan original_plan = current_plan
@@ -507,18 +510,55 @@ def human_feedback_node(
current_plan = json.loads(current_plan) current_plan = json.loads(current_plan)
current_plan_content = extract_plan_content(current_plan) current_plan_content = extract_plan_content(current_plan)
# increment the plan iterations
plan_iterations += 1
# parse the plan # parse the plan
new_plan = json.loads(repair_json_output(current_plan_content)) new_plan = json.loads(repair_json_output(current_plan_content))
# Some models may return only a raw steps list instead of a full plan object.
# Normalize to Plan schema to avoid ValidationError in Plan.model_validate().
if isinstance(new_plan, list):
logger.warning("Planner returned plan as list; normalizing to dict with inferred metadata")
new_plan = {
"locale": state.get("locale", "en-US"),
"has_enough_context": False,
"thought": "",
"title": state.get("research_topic") or "Research Plan",
"steps": new_plan,
}
elif not isinstance(new_plan, dict):
raise ValueError(f"Unsupported plan type after parsing: {type(new_plan).__name__}")
# Fill required fields if partially missing.
new_plan.setdefault("locale", state.get("locale", "en-US"))
new_plan.setdefault("has_enough_context", False)
new_plan.setdefault("thought", "")
if not new_plan.get("title"):
new_plan["title"] = state.get("research_topic") or "Research Plan"
if "steps" not in new_plan or new_plan.get("steps") is None:
new_plan["steps"] = []
# Validate and fix plan to ensure web search requirements are met # Validate and fix plan to ensure web search requirements are met
configurable = Configuration.from_runnable_config(config) # after normalization so list-shaped plans are also enforced.
new_plan = validate_and_fix_plan(new_plan, configurable.enforce_web_search, configurable.enable_web_search) new_plan = validate_and_fix_plan(
except (json.JSONDecodeError, AttributeError, ValueError) as e: new_plan,
configurable.enforce_web_search,
configurable.enable_web_search,
)
validated_plan = Plan.model_validate(new_plan)
except (json.JSONDecodeError, AttributeError, ValueError, ValidationError) as e:
logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}") logger.warning(f"Failed to parse plan: {str(e)}. Plan data type: {type(current_plan).__name__}")
if isinstance(current_plan, dict) and "content" in original_plan: if isinstance(current_plan, dict) and "content" in original_plan:
logger.warning(f"Plan appears to be an AIMessage object with content field") logger.warning(f"Plan appears to be an AIMessage object with content field")
if plan_iterations > 1: # the plan_iterations is increased before this check if plan_iterations < configurable.max_plan_iterations:
return Command(
update={
"plan_iterations": plan_iterations,
**preserve_state_meta_fields(state),
},
goto="planner"
)
if plan_iterations > 1:
return Command( return Command(
update=preserve_state_meta_fields(state), update=preserve_state_meta_fields(state),
goto="reporter" goto="reporter"
@@ -531,7 +571,7 @@ def human_feedback_node(
# Build update dict with safe locale handling # Build update dict with safe locale handling
update_dict = { update_dict = {
"current_plan": Plan.model_validate(new_plan), "current_plan": validated_plan,
"plan_iterations": plan_iterations, "plan_iterations": plan_iterations,
**preserve_state_meta_fields(state), **preserve_state_meta_fields(state),
} }
@@ -900,7 +940,13 @@ def reporter_node(state: State, config: RunnableConfig):
logger.debug(f"Current invoke messages: {invoke_messages}") logger.debug(f"Current invoke messages: {invoke_messages}")
response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages) response = get_llm_by_type(AGENT_LLM_MAP["reporter"]).invoke(invoke_messages)
response_content = response.content response_content = response.content
logger.info(f"reporter response: {response_content}") # Strip <think>...</think> tags that some models (e.g. QwQ, DeepSeek) embed
# directly in content instead of using the reasoning_content field (#781)
if isinstance(response_content, str) and "<think>" in response_content:
response_content = re.sub(
r"<think>[\s\S]*?</think>", "", response_content
).strip()
logger.debug(f"reporter response length: {len(response_content)}")
return { return {
"final_report": response_content, "final_report": response_content,
+6
View File
@@ -6,6 +6,7 @@ import base64
import json import json
import logging import logging
import os import os
import re
from typing import Annotated, Any, List, Optional, cast from typing import Annotated, Any, List, Optional, cast
from uuid import uuid4 from uuid import uuid4
@@ -423,6 +424,11 @@ def _create_event_stream_message(
if not isinstance(content, str): if not isinstance(content, str):
content = json.dumps(content, ensure_ascii=False) content = json.dumps(content, ensure_ascii=False)
# Strip <think>...</think> tags that some models (e.g. DeepSeek-R1, QwQ via ollama)
# embed directly in content instead of using the reasoning_content field (#781)
if isinstance(content, str) and "<think>" in content:
content = re.sub(r"<think>[\s\S]*?</think>", "", content).strip()
event_stream_message = { event_stream_message = {
"thread_id": thread_id, "thread_id": thread_id,
"agent": agent_name, "agent": agent_name,
+206 -3
View File
@@ -2,6 +2,7 @@ import json
from collections import namedtuple from collections import namedtuple
from unittest.mock import MagicMock, patch from unittest.mock import MagicMock, patch
from pydantic import ValidationError
import pytest import pytest
from src.graph.nodes import ( from src.graph.nodes import (
@@ -825,12 +826,102 @@ def test_human_feedback_node_json_decode_error_first_iteration(
state = dict(mock_state_base) state = dict(mock_state_base)
state["auto_accepted_plan"] = True state["auto_accepted_plan"] = True
state["plan_iterations"] = 0 state["plan_iterations"] = 0
with patch( mock_configurable = MagicMock()
"src.graph.nodes.json.loads", side_effect=json.JSONDecodeError("err", "doc", 0) mock_configurable.max_plan_iterations = 3
with (
patch(
"src.graph.nodes.Configuration.from_runnable_config",
return_value=mock_configurable,
),
patch(
"src.graph.nodes.json.loads",
side_effect=json.JSONDecodeError("err", "doc", 0),
),
): ):
result = human_feedback_node(state, mock_config) result = human_feedback_node(state, mock_config)
assert isinstance(result, Command) assert isinstance(result, Command)
assert result.goto == "__end__" assert result.goto == "planner"
assert result.update["plan_iterations"] == 1
def test_human_feedback_node_model_validate_error(mock_state_base, mock_config):
    """A ValidationError from Plan.model_validate is handled and routes to the planner."""
    from pydantic import BaseModel

    class DummyModel(BaseModel):
        value: int

    # Trigger a genuine pydantic ValidationError so the mock raises a real instance.
    try:
        DummyModel.model_validate({"value": "not_an_int"})
    except ValidationError as validation_error:
        raised_validation_error = validation_error

    state = dict(mock_state_base)
    state.update(auto_accepted_plan=True, plan_iterations=0)

    configurable = MagicMock(
        max_plan_iterations=3,
        enforce_web_search=False,
        enable_web_search=True,
    )
    config_patch = patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=configurable,
    )
    validate_patch = patch(
        "src.graph.nodes.Plan.model_validate",
        side_effect=raised_validation_error,
    )

    with config_patch, validate_patch:
        result = human_feedback_node(state, mock_config)

    assert isinstance(result, Command)
    # Still below max_plan_iterations, so the node asks the planner to retry.
    assert result.goto == "planner"
    assert result.update["plan_iterations"] == 1
def test_human_feedback_node_list_plan_runs_enforcement_after_normalization(
    mock_state_base, mock_config
):
    """Regression: a list-shaped plan is normalized first, then validate_and_fix_plan
    must still be applied to the normalized dict."""
    step_without_type = {
        "need_search": False,
        "title": "Only Step",
        "description": "Collect baseline info",
        # step_type is deliberately left out
    }
    serialized_steps = json.dumps([step_without_type])

    state = dict(mock_state_base)
    state.update(
        auto_accepted_plan=True,
        plan_iterations=0,
        current_plan=json.dumps({"content": [serialized_steps]}),
    )

    configurable = MagicMock(
        max_plan_iterations=3,
        enforce_web_search=True,
        enable_web_search=True,
    )
    config_patch = patch(
        "src.graph.nodes.Configuration.from_runnable_config",
        return_value=configurable,
    )

    with config_patch:
        result = human_feedback_node(state, mock_config)

    assert isinstance(result, Command)
    assert result.goto == "research_team"
    assert result.update["plan_iterations"] == 1

    plan = result.update["current_plan"]
    assert isinstance(plan, dict)
    steps = plan.get("steps")
    assert isinstance(steps, list)
    assert len(steps) == 1

    # The effects of validate_and_fix_plan must be visible on the normalized step.
    only_step = steps[0]
    assert only_step["step_type"] == "research"
    assert only_step["need_search"] is True
def test_human_feedback_node_json_decode_error_second_iteration( def test_human_feedback_node_json_decode_error_second_iteration(
@@ -2823,3 +2914,115 @@ async def test_execute_agent_step_no_tool_calls_still_works():
# Verify step execution result is set # Verify step execution result is set
assert state["current_plan"].steps[0].execution_res == "Based on my knowledge, here is the answer without needing to search." assert state["current_plan"].steps[0].execution_res == "Based on my knowledge, here is the answer without needing to search."
class TestReporterNodeThinkTagStripping:
    """Tests for stripping <think> tags from reporter_node output (#781).

    Some models (e.g. DeepSeek-R1, QwQ via ollama) embed reasoning in
    content using <think>...</think> tags instead of the separate
    reasoning_content field.
    """

    def _make_mock_state(self):
        """Build the minimal state dict handed to reporter_node in these tests."""
        plan = MagicMock()
        plan.title = "Test Plan"
        plan.thought = "Test Thought"
        return {
            "current_plan": plan,
            "observations": [],
            "citations": [],
            "locale": "en-US",
        }

    def _run_reporter_node(self, response_content):
        """Run reporter_node with its collaborators patched and return its result.

        Args:
            response_content: Value exposed as ``.content`` on the mocked LLM
                response (a string, or any other object for pass-through tests).
        """
        state = self._make_mock_state()
        mock_response = MagicMock()
        mock_response.content = response_content

        mock_configurable = MagicMock()

        with (
            patch(
                "src.graph.nodes.Configuration.from_runnable_config",
                return_value=mock_configurable,
            ),
            patch(
                "src.graph.nodes.apply_prompt_template",
                return_value=[{"role": "user", "content": "test"}],
            ),
            patch("src.graph.nodes.get_llm_by_type") as mock_get_llm,
            patch("src.graph.nodes.get_llm_token_limit_by_type", return_value=4096),
            patch("src.graph.nodes.AGENT_LLM_MAP", {"reporter": "basic"}),
            patch("src.graph.nodes.ContextManager") as mock_ctx_mgr,
        ):
            mock_ctx_mgr.return_value.compress_messages.return_value = {"messages": []}
            mock_llm = MagicMock()
            mock_llm.invoke.return_value = mock_response
            mock_get_llm.return_value = mock_llm

            result = reporter_node(state, MagicMock())

        return result

    def test_strips_think_tag_at_beginning(self):
        result = self._run_reporter_node(
            "<think>\nLet me analyze...\n</think>\n\n# Report\n\nContent here."
        )
        assert "<think>" not in result["final_report"]
        assert "# Report" in result["final_report"]
        assert "Content here." in result["final_report"]

    def test_strips_multiple_think_blocks(self):
        result = self._run_reporter_node(
            "<think>First thought</think>\nParagraph 1.\n<think>Second thought</think>\nParagraph 2."
        )
        assert "<think>" not in result["final_report"]
        assert "Paragraph 1." in result["final_report"]
        assert "Paragraph 2." in result["final_report"]

    def test_preserves_content_without_think_tags(self):
        result = self._run_reporter_node("Normal content without think tags.")
        assert result["final_report"] == "Normal content without think tags."

    def test_empty_content_after_stripping(self):
        result = self._run_reporter_node(
            "<think>Only thinking, no real content</think>"
        )
        assert "<think>" not in result["final_report"]
        # The think block is removed and the remainder is stripped, so nothing is left.
        assert result["final_report"] == ""

    def test_non_string_content_passes_through(self):
        """Verify non-string content is not broken by the stripping logic."""
        # Simulate non-string content (e.g. list from multimodal model); reuse the
        # shared helper instead of duplicating the patch scaffolding.
        result = self._run_reporter_node(["some", "list"])
        # Non-string content should pass through unchanged
        assert result["final_report"] == ["some", "list"]
+51
View File
@@ -16,6 +16,7 @@ from langgraph.types import Command
from src.config.report_style import ReportStyle from src.config.report_style import ReportStyle
from src.server.app import ( from src.server.app import (
_astream_workflow_generator, _astream_workflow_generator,
_create_event_stream_message,
_create_interrupt_event, _create_interrupt_event,
_make_event, _make_event,
_stream_graph_events, _stream_graph_events,
@@ -1680,3 +1681,53 @@ class TestGlobalConnectionPoolUsage:
"""Helper to create an empty async generator.""" """Helper to create an empty async generator."""
if False: if False:
yield yield
class TestCreateEventStreamMessageThinkTagStripping:
    """Checks that _create_event_stream_message removes <think> blocks (#781).

    Certain models (e.g. DeepSeek-R1, QwQ via ollama) put their reasoning
    inline as <think>...</think> inside content rather than in the separate
    reasoning_content field.
    """

    def _make_mock_chunk(self, content):
        """Return an AIMessageChunk carrying ``content`` with a fixed id and empty metadata."""
        message = AIMessageChunk(content=content)
        message.id = "msg_test"
        message.response_metadata = {}
        return message

    def _emit(self, message, agent_name):
        """Build the stream event for ``message`` on the fixed test thread."""
        return _create_event_stream_message(message, {}, "thread-1", agent_name)

    def test_strips_think_tag_at_beginning(self):
        raw = "<think>\nLet me analyze...\n</think>\n\n# Report\n\nContent here."
        event = self._emit(self._make_mock_chunk(raw), "reporter")
        streamed = event["content"]
        assert "<think>" not in streamed
        assert "# Report" in streamed
        assert "Content here." in streamed

    def test_strips_multiple_think_blocks(self):
        raw = "<think>First thought</think>\nParagraph 1.\n<think>Second thought</think>\nParagraph 2."
        event = self._emit(self._make_mock_chunk(raw), "coordinator")
        streamed = event["content"]
        assert "<think>" not in streamed
        assert "Paragraph 1." in streamed
        assert "Paragraph 2." in streamed

    def test_preserves_content_without_think_tags(self):
        message = self._make_mock_chunk("Normal content without think tags.")
        event = self._emit(message, "planner")
        assert event["content"] == "Normal content without think tags."

    def test_empty_content_after_stripping(self):
        message = self._make_mock_chunk("<think>Only thinking, no real content</think>")
        event = self._emit(message, "reporter")
        assert "<think>" not in event["content"]

    def test_preserves_reasoning_content_field(self):
        message = self._make_mock_chunk("Actual content")
        # The dedicated reasoning field must be carried through alongside the content.
        message.additional_kwargs["reasoning_content"] = "This is reasoning"
        event = self._emit(message, "planner")
        assert event["content"] == "Actual content"
        assert event["reasoning_content"] == "This is reasoning"