mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-28 08:14:48 +08:00
feat: implement enhance prompt (#294)
* feat: implement enhance prompt * add unit test * fix prompt * fix: fix eslint and compiling issues * feat: add border-beam animation * fix: fix importing issues --------- Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
4
src/prompt_enhancer/__init__.py
Normal file
4
src/prompt_enhancer/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Prompt enhancer module for improving user prompts."""
|
||||
25
src/prompt_enhancer/graph/builder.py
Normal file
25
src/prompt_enhancer/graph/builder.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
|
||||
def build_graph():
|
||||
"""Build and return the prompt enhancer workflow graph."""
|
||||
# Build state graph
|
||||
builder = StateGraph(PromptEnhancerState)
|
||||
|
||||
# Add the enhancer node
|
||||
builder.add_node("enhancer", prompt_enhancer_node)
|
||||
|
||||
# Set entry point
|
||||
builder.set_entry_point("enhancer")
|
||||
|
||||
# Set finish point
|
||||
builder.set_finish_point("enhancer")
|
||||
|
||||
# Compile and return the graph
|
||||
return builder.compile()
|
||||
67
src/prompt_enhancer/graph/enhancer_node.py
Normal file
67
src/prompt_enhancer/graph/enhancer_node.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prompts.template import env, apply_prompt_template
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def prompt_enhancer_node(state: PromptEnhancerState):
|
||||
"""Node that enhances user prompts using AI analysis."""
|
||||
logger.info("Enhancing user prompt...")
|
||||
|
||||
model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])
|
||||
|
||||
try:
|
||||
|
||||
# Create messages with context if provided
|
||||
context_info = ""
|
||||
if state.get("context"):
|
||||
context_info = f"\n\nAdditional context: {state['context']}"
|
||||
|
||||
original_prompt_message = HumanMessage(
|
||||
content=f"Please enhance this prompt:{context_info}\n\nOriginal prompt: {state['prompt']}"
|
||||
)
|
||||
|
||||
messages = apply_prompt_template(
|
||||
"prompt_enhancer/prompt_enhancer",
|
||||
{
|
||||
"messages": [original_prompt_message],
|
||||
"report_style": state.get("report_style"),
|
||||
},
|
||||
)
|
||||
|
||||
# Get the response from the model
|
||||
response = model.invoke(messages)
|
||||
|
||||
# Clean up the response - remove any extra formatting or comments
|
||||
enhanced_prompt = response.content.strip()
|
||||
|
||||
# Remove common prefixes that might be added by the model
|
||||
prefixes_to_remove = [
|
||||
"Enhanced Prompt:",
|
||||
"Enhanced prompt:",
|
||||
"Here's the enhanced prompt:",
|
||||
"Here is the enhanced prompt:",
|
||||
"**Enhanced Prompt**:",
|
||||
"**Enhanced prompt**:",
|
||||
]
|
||||
|
||||
for prefix in prefixes_to_remove:
|
||||
if enhanced_prompt.startswith(prefix):
|
||||
enhanced_prompt = enhanced_prompt[len(prefix) :].strip()
|
||||
break
|
||||
|
||||
logger.info("Prompt enhancement completed successfully")
|
||||
logger.debug(f"Enhanced prompt: {enhanced_prompt}")
|
||||
return {"output": enhanced_prompt}
|
||||
except Exception as e:
|
||||
logger.error(f"Error in prompt enhancement: {str(e)}")
|
||||
return {"output": state["prompt"]}
|
||||
14
src/prompt_enhancer/graph/state.py
Normal file
14
src/prompt_enhancer/graph/state.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import TypedDict, Optional
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
class PromptEnhancerState(TypedDict):
|
||||
"""State for the prompt enhancer workflow."""
|
||||
|
||||
prompt: str # Original prompt to enhance
|
||||
context: Optional[str] # Additional context
|
||||
report_style: Optional[ReportStyle] # Report style preference
|
||||
output: Optional[str] # Enhanced prompt result
|
||||
Reference in New Issue
Block a user