1cd6aa0ece
* feat: implement enhance prompt * add unit test * fix prompt * fix: fix eslint and compiling issues * feat: add border-beam animation * fix: fix importing issues --------- Co-authored-by: Henry Li <henry1943@163.com>
68 lines
2.2 KiB
Python
68 lines
2.2 KiB
Python
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
|
# SPDX-License-Identifier: MIT
|
|
|
|
import logging
|
|
|
|
from langchain.schema import HumanMessage, SystemMessage
|
|
|
|
from src.config.agents import AGENT_LLM_MAP
|
|
from src.llms.llm import get_llm_by_type
|
|
from src.prompts.template import env, apply_prompt_template
|
|
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def _flatten_message_content(content):
    """Return message content as plain text when it is a list of parts.

    String parts are kept as-is, dict parts contribute their ``"text"`` value
    when that value is a string, and every other part is skipped. Content that
    is not a list is returned unchanged, so the caller's existing handling of
    it (including its error handling) still applies.
    """
    if not isinstance(content, list):
        return content

    text_parts = []
    for part in content:
        if isinstance(part, str):
            text_parts.append(part)
        elif isinstance(part, dict) and isinstance(part.get("text"), str):
            text_parts.append(part["text"])
    return "".join(text_parts)


def prompt_enhancer_node(state: PromptEnhancerState):
    """Node that enhances user prompts using AI analysis.

    Args:
        state: Mapping-style state. ``state["prompt"]`` is the text to
            enhance; the optional ``context`` and ``report_style`` entries
            are read with ``state.get``.

    Returns:
        ``{"output": <enhanced prompt>}`` on success. If anything inside the
        enhancement step raises, the error is logged with its traceback and
        ``{"output": state["prompt"]}`` is returned instead.
    """
    logger.info("Enhancing user prompt...")

    # Resolved outside the try block: a failure here propagates to the caller
    # rather than falling back to the original prompt.
    model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])

    try:
        # Create messages with context if provided
        context_info = ""
        if state.get("context"):
            context_info = f"\n\nAdditional context: {state['context']}"

        original_prompt_message = HumanMessage(
            content=f"Please enhance this prompt:{context_info}\n\nOriginal prompt: {state['prompt']}"
        )

        messages = apply_prompt_template(
            "prompt_enhancer/prompt_enhancer",
            {
                "messages": [original_prompt_message],
                "report_style": state.get("report_style"),
            },
        )

        # Get the response from the model
        response = model.invoke(messages)

        # Message content may be a list of content parts rather than a plain
        # string; flatten it first so the cleanup below does not fail and
        # silently discard a valid response.
        enhanced_prompt = _flatten_message_content(response.content).strip()

        # Remove common prefixes that might be added by the model. None of
        # these is a prefix of another, so at most one can match.
        prefixes_to_remove = (
            "Enhanced Prompt:",
            "Enhanced prompt:",
            "Here's the enhanced prompt:",
            "Here is the enhanced prompt:",
            "**Enhanced Prompt**:",
            "**Enhanced prompt**:",
        )

        for prefix in prefixes_to_remove:
            if enhanced_prompt.startswith(prefix):
                enhanced_prompt = enhanced_prompt[len(prefix) :].strip()
                break

        logger.info("Prompt enhancement completed successfully")
        logger.debug(f"Enhanced prompt: {enhanced_prompt}")
        return {"output": enhanced_prompt}
    except Exception as e:
        # Best-effort node: keep the fallback, but record the traceback so the
        # swallowed failure can still be diagnosed.
        logger.error(f"Error in prompt enhancement: {str(e)}", exc_info=True)
        return {"output": state["prompt"]}
|