mirror of
https://github.com/bytedance/deer-flow.git
synced 2026-05-24 17:06:00 +00:00
style: apply ruff format to persistence and runtime files
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -616,7 +616,6 @@ async def get_thread_history(thread_id: str, body: ThreadHistoryRequest, request
|
|||||||
logger.warning("Failed to load messages from event store for thread %s", _sanitize_log_param(thread_id), exc_info=True)
|
logger.warning("Failed to load messages from event store for thread %s", _sanitize_log_param(thread_id), exc_info=True)
|
||||||
all_messages = []
|
all_messages = []
|
||||||
|
|
||||||
|
|
||||||
entries: list[HistoryEntry] = []
|
entries: list[HistoryEntry] = []
|
||||||
is_latest_checkpoint = True
|
is_latest_checkpoint = True
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ def _json_serializer(obj: object) -> str:
|
|||||||
"""JSON serializer with ensure_ascii=False for Chinese character support."""
|
"""JSON serializer with ensure_ascii=False for Chinese character support."""
|
||||||
return json.dumps(obj, ensure_ascii=False)
|
return json.dumps(obj, ensure_ascii=False)
|
||||||
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_engine: AsyncEngine | None = None
|
_engine: AsyncEngine | None = None
|
||||||
|
|||||||
@@ -21,10 +21,7 @@ try:
|
|||||||
import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
|
import deerflow.persistence.models # noqa: F401 — register ORM models with Base.metadata
|
||||||
except ImportError:
|
except ImportError:
|
||||||
# Models not available — migration will work with existing metadata only.
|
# Models not available — migration will work with existing metadata only.
|
||||||
logging.getLogger(__name__).warning(
|
logging.getLogger(__name__).warning("Could not import deerflow.persistence.models; Alembic may not detect all tables")
|
||||||
"Could not import deerflow.persistence.models; "
|
|
||||||
"Alembic may not detect all tables"
|
|
||||||
)
|
|
||||||
|
|
||||||
config = context.config
|
config = context.config
|
||||||
if config.config_file_name is not None:
|
if config.config_file_name is not None:
|
||||||
|
|||||||
@@ -99,11 +99,7 @@ class ThreadMetaRepository:
|
|||||||
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
async def update_display_name(self, thread_id: str, display_name: str) -> None:
|
||||||
"""Update the display_name (title) for a thread."""
|
"""Update the display_name (title) for a thread."""
|
||||||
async with self._sf() as session:
|
async with self._sf() as session:
|
||||||
await session.execute(
|
await session.execute(update(ThreadMetaRow).where(ThreadMetaRow.thread_id == thread_id).values(display_name=display_name, updated_at=datetime.now(UTC)))
|
||||||
update(ThreadMetaRow)
|
|
||||||
.where(ThreadMetaRow.thread_id == thread_id)
|
|
||||||
.values(display_name=display_name, updated_at=datetime.now(UTC))
|
|
||||||
)
|
|
||||||
await session.commit()
|
await session.commit()
|
||||||
|
|
||||||
async def update_status(self, thread_id: str, status: str) -> None:
|
async def update_status(self, thread_id: str, status: str) -> None:
|
||||||
|
|||||||
@@ -47,14 +47,16 @@ def langchain_to_openai_message(message: Any) -> dict:
|
|||||||
openai_tool_calls = []
|
openai_tool_calls = []
|
||||||
for tc in tool_calls:
|
for tc in tool_calls:
|
||||||
args = tc.get("args", {})
|
args = tc.get("args", {})
|
||||||
openai_tool_calls.append({
|
openai_tool_calls.append(
|
||||||
|
{
|
||||||
"id": tc.get("id", ""),
|
"id": tc.get("id", ""),
|
||||||
"type": "function",
|
"type": "function",
|
||||||
"function": {
|
"function": {
|
||||||
"name": tc.get("name", ""),
|
"name": tc.get("name", ""),
|
||||||
"arguments": json.dumps(args) if not isinstance(args, str) else args,
|
"arguments": json.dumps(args) if not isinstance(args, str) else args,
|
||||||
},
|
},
|
||||||
})
|
}
|
||||||
|
)
|
||||||
# If no text content, set content to null per OpenAI spec
|
# If no text content, set content to null per OpenAI spec
|
||||||
result["content"] = content if (isinstance(content, list) and content) or (isinstance(content, str) and content) else None
|
result["content"] = content if (isinstance(content, list) and content) or (isinstance(content, str) and content) else None
|
||||||
result["tool_calls"] = openai_tool_calls
|
result["tool_calls"] = openai_tool_calls
|
||||||
|
|||||||
@@ -65,11 +65,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
# Use FOR UPDATE to serialize seq assignment within a thread.
|
# Use FOR UPDATE to serialize seq assignment within a thread.
|
||||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||||
max_seq = await session.scalar(
|
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||||
select(func.max(RunEventRow.seq))
|
|
||||||
.where(RunEventRow.thread_id == thread_id)
|
|
||||||
.with_for_update()
|
|
||||||
)
|
|
||||||
seq = (max_seq or 0) + 1
|
seq = (max_seq or 0) + 1
|
||||||
row = RunEventRow(
|
row = RunEventRow(
|
||||||
thread_id=thread_id,
|
thread_id=thread_id,
|
||||||
@@ -93,11 +89,7 @@ class DbRunEventStore(RunEventStore):
|
|||||||
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
# NOTE: with_for_update() on aggregates is a no-op on SQLite;
|
||||||
# the UNIQUE(thread_id, seq) constraint catches races there.
|
# the UNIQUE(thread_id, seq) constraint catches races there.
|
||||||
thread_id = events[0]["thread_id"]
|
thread_id = events[0]["thread_id"]
|
||||||
max_seq = await session.scalar(
|
max_seq = await session.scalar(select(func.max(RunEventRow.seq)).where(RunEventRow.thread_id == thread_id).with_for_update())
|
||||||
select(func.max(RunEventRow.seq))
|
|
||||||
.where(RunEventRow.thread_id == thread_id)
|
|
||||||
.with_for_update()
|
|
||||||
)
|
|
||||||
seq = max_seq or 0
|
seq = max_seq or 0
|
||||||
rows = []
|
rows = []
|
||||||
for e in events:
|
for e in events:
|
||||||
|
|||||||
@@ -357,7 +357,8 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
# -- Internal methods --
|
# -- Internal methods --
|
||||||
|
|
||||||
def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
|
def _put(self, *, event_type: str, category: str, content: str | dict = "", metadata: dict | None = None) -> None:
|
||||||
self._buffer.append({
|
self._buffer.append(
|
||||||
|
{
|
||||||
"thread_id": self.thread_id,
|
"thread_id": self.thread_id,
|
||||||
"run_id": self.run_id,
|
"run_id": self.run_id,
|
||||||
"event_type": event_type,
|
"event_type": event_type,
|
||||||
@@ -365,7 +366,8 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
"content": content,
|
"content": content,
|
||||||
"metadata": metadata or {},
|
"metadata": metadata or {},
|
||||||
"created_at": datetime.now(UTC).isoformat(),
|
"created_at": datetime.now(UTC).isoformat(),
|
||||||
})
|
}
|
||||||
|
)
|
||||||
if len(self._buffer) >= self._flush_threshold:
|
if len(self._buffer) >= self._flush_threshold:
|
||||||
self._flush_sync()
|
self._flush_sync()
|
||||||
|
|
||||||
@@ -395,7 +397,9 @@ class RunJournal(BaseCallbackHandler):
|
|||||||
except Exception:
|
except Exception:
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"Failed to flush %d events for run %s — returning to buffer",
|
"Failed to flush %d events for run %s — returning to buffer",
|
||||||
len(batch), self.run_id, exc_info=True,
|
len(batch),
|
||||||
|
self.run_id,
|
||||||
|
exc_info=True,
|
||||||
)
|
)
|
||||||
# Return failed events to buffer for retry on next flush
|
# Return failed events to buffer for retry on next flush
|
||||||
self._buffer = batch + self._buffer
|
self._buffer = batch + self._buffer
|
||||||
|
|||||||
@@ -41,7 +41,10 @@ class TestFeedbackRepository:
|
|||||||
async def test_create_negative_with_comment(self, tmp_path):
|
async def test_create_negative_with_comment(self, tmp_path):
|
||||||
repo = await _make_feedback_repo(tmp_path)
|
repo = await _make_feedback_repo(tmp_path)
|
||||||
record = await repo.create(
|
record = await repo.create(
|
||||||
run_id="r1", thread_id="t1", rating=-1, comment="Response was inaccurate",
|
run_id="r1",
|
||||||
|
thread_id="t1",
|
||||||
|
rating=-1,
|
||||||
|
comment="Response was inaccurate",
|
||||||
)
|
)
|
||||||
assert record["rating"] == -1
|
assert record["rating"] == -1
|
||||||
assert record["comment"] == "Response was inaccurate"
|
assert record["comment"] == "Response was inaccurate"
|
||||||
|
|||||||
@@ -947,8 +947,10 @@ class TestFullRunSequence:
|
|||||||
# 1. Human message (written by worker, using model_dump format)
|
# 1. Human message (written by worker, using model_dump format)
|
||||||
human_msg = HumanMessage(content="Search for quantum computing")
|
human_msg = HumanMessage(content="Search for quantum computing")
|
||||||
await store.put(
|
await store.put(
|
||||||
thread_id="t1", run_id="r1",
|
thread_id="t1",
|
||||||
event_type="human_message", category="message",
|
run_id="r1",
|
||||||
|
event_type="human_message",
|
||||||
|
category="message",
|
||||||
content=human_msg.model_dump(),
|
content=human_msg.model_dump(),
|
||||||
)
|
)
|
||||||
j.set_first_human_message("Search for quantum computing")
|
j.set_first_human_message("Search for quantum computing")
|
||||||
|
|||||||
Reference in New Issue
Block a user