feat: implement enhance prompt (#294)

* feat: implement enhance prompt

* add unit test

* fix prompt

* fix: fix eslint and compiling issues

* feat: add border-beam animation

* fix: fix importing issues

---------

Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
DanielWalnut
2025-06-08 19:41:59 +08:00
committed by GitHub
parent 8081a14c21
commit 1cd6aa0ece
19 changed files with 1100 additions and 4 deletions

View File

@@ -0,0 +1,4 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
"""Prompt enhancer module for improving user prompts."""

View File

@@ -0,0 +1,25 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from langgraph.graph import StateGraph
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
from src.prompt_enhancer.graph.state import PromptEnhancerState
def build_graph():
"""Build and return the prompt enhancer workflow graph."""
# Build state graph
builder = StateGraph(PromptEnhancerState)
# Add the enhancer node
builder.add_node("enhancer", prompt_enhancer_node)
# Set entry point
builder.set_entry_point("enhancer")
# Set finish point
builder.set_finish_point("enhancer")
# Compile and return the graph
return builder.compile()

View File

@@ -0,0 +1,67 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import logging
from langchain.schema import HumanMessage, SystemMessage
from src.config.agents import AGENT_LLM_MAP
from src.llms.llm import get_llm_by_type
from src.prompts.template import env, apply_prompt_template
from src.prompt_enhancer.graph.state import PromptEnhancerState
logger = logging.getLogger(__name__)
def prompt_enhancer_node(state: PromptEnhancerState):
"""Node that enhances user prompts using AI analysis."""
logger.info("Enhancing user prompt...")
model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])
try:
# Create messages with context if provided
context_info = ""
if state.get("context"):
context_info = f"\n\nAdditional context: {state['context']}"
original_prompt_message = HumanMessage(
content=f"Please enhance this prompt:{context_info}\n\nOriginal prompt: {state['prompt']}"
)
messages = apply_prompt_template(
"prompt_enhancer/prompt_enhancer",
{
"messages": [original_prompt_message],
"report_style": state.get("report_style"),
},
)
# Get the response from the model
response = model.invoke(messages)
# Clean up the response - remove any extra formatting or comments
enhanced_prompt = response.content.strip()
# Remove common prefixes that might be added by the model
prefixes_to_remove = [
"Enhanced Prompt:",
"Enhanced prompt:",
"Here's the enhanced prompt:",
"Here is the enhanced prompt:",
"**Enhanced Prompt**:",
"**Enhanced prompt**:",
]
for prefix in prefixes_to_remove:
if enhanced_prompt.startswith(prefix):
enhanced_prompt = enhanced_prompt[len(prefix) :].strip()
break
logger.info("Prompt enhancement completed successfully")
logger.debug(f"Enhanced prompt: {enhanced_prompt}")
return {"output": enhanced_prompt}
except Exception as e:
logger.error(f"Error in prompt enhancement: {str(e)}")
return {"output": state["prompt"]}

View File

@@ -0,0 +1,14 @@
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
from typing import TypedDict, Optional
from src.config.report_style import ReportStyle
class PromptEnhancerState(TypedDict):
"""State for the prompt enhancer workflow."""
prompt: str # Original prompt to enhance
context: Optional[str] # Additional context
report_style: Optional[ReportStyle] # Report style preference
output: Optional[str] # Enhanced prompt result