Files
deer-flow/src/prompt_enhancer/graph/enhancer_node.py
T
DanielWalnut 1cd6aa0ece feat: implement enhance prompt (#294)
* feat: implement enhance prompt

* add unit test

* fix prompt

* fix: fix eslint and compiling issues

* feat: add border-beam animation

* fix: fix importing issues

---------

Co-authored-by: Henry Li <henry1943@163.com>
2025-06-08 19:41:59 +08:00

68 lines
2.2 KiB
Python

# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import logging
from langchain.schema import HumanMessage, SystemMessage
from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prompts.template import env, apply_prompt_template
from src.prompt_enhancer.graph.state import PromptEnhancerState
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
def prompt_enhancer_node(state: PromptEnhancerState):
    """Enhance the user's prompt with the configured LLM.

    Reads ``state["prompt"]`` plus the optional ``state["context"]`` and
    ``state["report_style"]`` entries, renders the
    ``prompt_enhancer/prompt_enhancer`` template around a single human
    message, invokes the model, and strips one known label prefix (for
    example ``"Enhanced Prompt:"``) from the start of the reply.

    Args:
        state: Graph state holding the prompt and the optional settings.

    Returns:
        A dict ``{"output": <text>}``. ``<text>`` is the cleaned model reply.
        It is the original ``state["prompt"]`` instead when building the
        messages or invoking the model raises, or when the reply is blank
        after cleaning.
    """
    logger.info("Enhancing user prompt...")

    # The model lookup is deliberately outside the try block: a failure here
    # propagates to the caller rather than becoming the fallback output.
    model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])

    try:
        # Only mention context in the request when some was provided.
        context_info = ""
        if state.get("context"):
            context_info = f"\n\nAdditional context: {state['context']}"

        original_prompt_message = HumanMessage(
            content=f"Please enhance this prompt:{context_info}\n\nOriginal prompt: {state['prompt']}"
        )

        messages = apply_prompt_template(
            "prompt_enhancer/prompt_enhancer",
            {
                "messages": [original_prompt_message],
                "report_style": state.get("report_style"),
            },
        )

        # Get the response from the model and trim surrounding whitespace.
        response = model.invoke(messages)
        enhanced_prompt = response.content.strip()

        # Labels the model sometimes puts in front of its answer. Only the
        # first matching label is removed, and matching is case-sensitive.
        prefixes_to_remove = (
            "Enhanced Prompt:",
            "Enhanced prompt:",
            "Here's the enhanced prompt:",
            "Here is the enhanced prompt:",
            "**Enhanced Prompt**:",
            "**Enhanced prompt**:",
        )
        for prefix in prefixes_to_remove:
            if enhanced_prompt.startswith(prefix):
                enhanced_prompt = enhanced_prompt[len(prefix):].strip()
                break

        # A blank reply is not a usable enhancement; hand back the user's
        # own prompt instead of replacing it with an empty string.
        if not enhanced_prompt:
            logger.warning(
                "Prompt enhancement returned empty output; using original prompt"
            )
            return {"output": state["prompt"]}

        logger.info("Prompt enhancement completed successfully")
        logger.debug(f"Enhanced prompt: {enhanced_prompt}")
        return {"output": enhanced_prompt}
    except Exception as e:
        # Best-effort node: never fail the graph over an enhancement error.
        # exc_info=True keeps the traceback in the log for diagnosis.
        logger.error(f"Error in prompt enhancement: {str(e)}", exc_info=True)
        return {"output": state["prompt"]}