mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-04 06:32:13 +08:00
feat: implement enhance prompt (#294)
* feat: implement enhance prompt * add unit test * fix prompt * fix: fix eslint and compiling issues * feat: add border-beam animation * fix: fix importing issues --------- Co-authored-by: Henry Li <henry1943@163.com>
This commit is contained in:
@@ -16,4 +16,5 @@ AGENT_LLM_MAP: dict[str, LLMType] = {
|
||||
"podcast_script_writer": "basic",
|
||||
"ppt_composer": "basic",
|
||||
"prose_writer": "basic",
|
||||
"prompt_enhancer": "basic",
|
||||
}
|
||||
|
||||
4
src/prompt_enhancer/__init__.py
Normal file
4
src/prompt_enhancer/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
"""Prompt enhancer module for improving user prompts."""
|
||||
25
src/prompt_enhancer/graph/builder.py
Normal file
25
src/prompt_enhancer/graph/builder.py
Normal file
@@ -0,0 +1,25 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from langgraph.graph import StateGraph
|
||||
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
|
||||
def build_graph():
    """Construct and compile the single-node prompt enhancer workflow.

    Returns:
        A compiled LangGraph workflow whose sole node rewrites the
        incoming prompt via the prompt enhancer LLM.
    """
    workflow = StateGraph(PromptEnhancerState)

    # Register the lone enhancement step, then mark it as both the
    # beginning and the end of the workflow.
    workflow.add_node("enhancer", prompt_enhancer_node)
    workflow.set_entry_point("enhancer")
    workflow.set_finish_point("enhancer")

    return workflow.compile()
|
||||
67
src/prompt_enhancer/graph/enhancer_node.py
Normal file
67
src/prompt_enhancer/graph/enhancer_node.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import logging
|
||||
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.config.agents import AGENT_LLM_MAP
|
||||
from src.llms.llm import get_llm_by_type
|
||||
from src.prompts.template import env, apply_prompt_template
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _strip_known_prefixes(text: str) -> str:
    """Remove a leading "Enhanced Prompt:"-style label the model may emit.

    The prompt template instructs the model not to add such labels, but
    models sometimes do anyway.  Only the first matching prefix is
    stripped; the remainder is returned with surrounding whitespace
    removed.  Unmatched text is returned unchanged.
    """
    prefixes_to_remove = [
        "Enhanced Prompt:",
        "Enhanced prompt:",
        "Here's the enhanced prompt:",
        "Here is the enhanced prompt:",
        "**Enhanced Prompt**:",
        "**Enhanced prompt**:",
    ]
    for prefix in prefixes_to_remove:
        if text.startswith(prefix):
            return text[len(prefix) :].strip()
    return text


def prompt_enhancer_node(state: PromptEnhancerState):
    """Node that enhances user prompts using AI analysis.

    Reads ``prompt`` (required) plus optional ``context`` and
    ``report_style`` from *state*, asks the configured LLM to rewrite the
    prompt, and returns ``{"output": <enhanced prompt>}``.  On any
    failure the original prompt is returned unchanged so the workflow
    never fails the caller.
    """
    logger.info("Enhancing user prompt...")

    model = get_llm_by_type(AGENT_LLM_MAP["prompt_enhancer"])

    try:
        # Fold optional caller-supplied context into the human message.
        context_info = ""
        if state.get("context"):
            context_info = f"\n\nAdditional context: {state['context']}"

        original_prompt_message = HumanMessage(
            content=f"Please enhance this prompt:{context_info}\n\nOriginal prompt: {state['prompt']}"
        )

        messages = apply_prompt_template(
            "prompt_enhancer/prompt_enhancer",
            {
                "messages": [original_prompt_message],
                "report_style": state.get("report_style"),
            },
        )

        response = model.invoke(messages)

        # Drop any label the model echoed despite the template's
        # instructions, so callers receive only the enhanced prompt text.
        enhanced_prompt = _strip_known_prefixes(response.content.strip())

        logger.info("Prompt enhancement completed successfully")
        logger.debug(f"Enhanced prompt: {enhanced_prompt}")
        return {"output": enhanced_prompt}
    except Exception as e:
        # Deliberate best-effort fallback: surface the original prompt
        # rather than breaking the whole workflow.
        logger.error(f"Error in prompt enhancement: {str(e)}")
        return {"output": state["prompt"]}
|
||||
14
src/prompt_enhancer/graph/state.py
Normal file
14
src/prompt_enhancer/graph/state.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
from typing import TypedDict, Optional
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
class _PromptEnhancerStateOptional(TypedDict, total=False):
    """Workflow-state keys that may be absent at construction time."""

    # Additional context supplied by the caller, if any.
    context: Optional[str]
    # Preferred report style; forward-referenced so the annotation is not
    # evaluated at class-creation time (resolved lazily from module globals).
    report_style: Optional["ReportStyle"]
    # Enhanced prompt result, populated by the enhancer node.
    output: Optional[str]


class PromptEnhancerState(_PromptEnhancerStateOptional):
    """State for the prompt enhancer workflow.

    Only ``prompt`` is required.  The remaining keys are read with
    ``state.get(...)`` by the enhancer node and are frequently omitted by
    callers, so they live in a ``total=False`` base class to make that
    optionality explicit to type checkers.
    """

    # Original prompt to enhance (required).
    prompt: str
|
||||
104
src/prompts/prompt_enhancer/prompt_enhancer.md
Normal file
104
src/prompts/prompt_enhancer/prompt_enhancer.md
Normal file
@@ -0,0 +1,104 @@
|
||||
---
|
||||
CURRENT_TIME: {{ CURRENT_TIME }}
|
||||
---
|
||||
|
||||
You are an expert prompt engineer. Your task is to enhance user prompts to make them more effective, specific, and likely to produce high-quality results from AI systems.
|
||||
|
||||
# Your Role
|
||||
- Analyze the original prompt for clarity, specificity, and completeness
|
||||
- Enhance the prompt by adding relevant details, context, and structure
|
||||
- Make the prompt more actionable and results-oriented
|
||||
- Preserve the user's original intent while improving effectiveness
|
||||
|
||||
{% if report_style == "academic" %}
|
||||
# Enhancement Guidelines for Academic Style
|
||||
1. **Add methodological rigor**: Include research methodology, scope, and analytical framework
|
||||
2. **Specify academic structure**: Organize with clear thesis, literature review, analysis, and conclusions
|
||||
3. **Clarify scholarly expectations**: Specify citation requirements, evidence standards, and academic tone
|
||||
4. **Add theoretical context**: Include relevant theoretical frameworks and disciplinary perspectives
|
||||
5. **Ensure precision**: Use precise terminology and avoid ambiguous language
|
||||
6. **Include limitations**: Acknowledge scope limitations and potential biases
|
||||
{% elif report_style == "popular_science" %}
|
||||
# Enhancement Guidelines for Popular Science Style
|
||||
1. **Add accessibility**: Transform technical concepts into relatable analogies and examples
|
||||
2. **Improve narrative structure**: Organize as an engaging story with clear beginning, middle, and end
|
||||
3. **Clarify audience expectations**: Specify general audience level and engagement goals
|
||||
4. **Add human context**: Include real-world applications and human interest elements
|
||||
5. **Make it compelling**: Ensure the prompt guides toward fascinating and wonder-inspiring content
|
||||
6. **Include visual elements**: Suggest use of metaphors and descriptive language for complex concepts
|
||||
{% elif report_style == "news" %}
|
||||
# Enhancement Guidelines for News Style
|
||||
1. **Add journalistic rigor**: Include fact-checking requirements, source verification, and objectivity standards
|
||||
2. **Improve news structure**: Organize with inverted pyramid structure (most important information first)
|
||||
3. **Clarify reporting expectations**: Specify timeliness, accuracy, and balanced perspective requirements
|
||||
4. **Add contextual background**: Include relevant background information and broader implications
|
||||
5. **Make it newsworthy**: Ensure the prompt focuses on current relevance and public interest
|
||||
6. **Include attribution**: Specify source requirements and quote standards
|
||||
{% elif report_style == "social_media" %}
|
||||
# Enhancement Guidelines for Social Media Style
|
||||
1. **Add engagement focus**: Include attention-grabbing elements, hooks, and shareability factors
|
||||
2. **Improve platform structure**: Organize for specific platform requirements (character limits, hashtags, etc.)
|
||||
3. **Clarify audience expectations**: Specify target demographic and engagement goals
|
||||
4. **Add viral elements**: Include trending topics, relatable content, and interactive elements
|
||||
5. **Make it shareable**: Ensure the prompt guides toward content that encourages sharing and discussion
|
||||
6. **Include visual considerations**: Suggest emoji usage, formatting, and visual appeal elements
|
||||
{% else %}
|
||||
# General Enhancement Guidelines
|
||||
1. **Add specificity**: Include relevant details, scope, and constraints
|
||||
2. **Improve structure**: Organize the request logically with clear sections if needed
|
||||
3. **Clarify expectations**: Specify desired output format, length, or style
|
||||
4. **Add context**: Include background information that would help generate better results
|
||||
5. **Make it actionable**: Ensure the prompt guides toward concrete, useful outputs
|
||||
{% endif %}
|
||||
|
||||
# Output Requirements
|
||||
- Output ONLY the enhanced prompt
|
||||
- Do NOT include any explanations, comments, or meta-text
|
||||
- Do NOT use phrases like "Enhanced Prompt:" or "Here's the enhanced version:"
|
||||
- The output should be ready to use directly as a prompt
|
||||
|
||||
{% if report_style == "academic" %}
|
||||
# Academic Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Conduct a comprehensive academic analysis of artificial intelligence applications across three key sectors: healthcare, education, and business. Employ a systematic literature review methodology to examine peer-reviewed sources from the past five years. Structure your analysis with: (1) theoretical framework defining AI and its taxonomies, (2) sector-specific case studies with quantitative performance metrics, (3) critical evaluation of implementation challenges and ethical considerations, (4) comparative analysis across sectors, and (5) evidence-based recommendations for future research directions. Maintain academic rigor with proper citations, acknowledge methodological limitations, and present findings with appropriate hedging language. Target length: 3000-4000 words with APA formatting."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Provide a rigorous academic examination of anthropogenic climate change, synthesizing current scientific consensus and recent research developments. Structure your analysis as follows: (1) theoretical foundations of greenhouse effect and radiative forcing mechanisms, (2) systematic review of empirical evidence from paleoclimatic, observational, and modeling studies, (3) critical analysis of attribution studies linking human activities to observed warming, (4) evaluation of climate sensitivity estimates and uncertainty ranges, (5) assessment of projected impacts under different emission scenarios, and (6) discussion of research gaps and methodological limitations. Include quantitative data, statistical significance levels, and confidence intervals where appropriate. Cite peer-reviewed sources extensively and maintain objective, third-person academic voice throughout."
|
||||
|
||||
{% elif report_style == "popular_science" %}
|
||||
# Popular Science Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Tell the fascinating story of how artificial intelligence is quietly revolutionizing our daily lives in ways most people never realize. Take readers on an engaging journey through three surprising realms: the hospital where AI helps doctors spot diseases faster than ever before, the classroom where intelligent tutors adapt to each student's learning style, and the boardroom where algorithms are making million-dollar decisions. Use vivid analogies (like comparing neural networks to how our brains work) and real-world examples that readers can relate to. Include 'wow factor' moments that showcase AI's incredible capabilities, but also honest discussions about current limitations. Write with infectious enthusiasm while maintaining scientific accuracy, and conclude with exciting possibilities that await us in the near future. Aim for 1500-2000 words that feel like a captivating conversation with a brilliant friend."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Craft a compelling narrative that transforms the complex science of climate change into an accessible and engaging story for curious readers. Begin with a relatable scenario (like why your hometown weather feels different than when you were a kid) and use this as a gateway to explore the fascinating science behind our changing planet. Employ vivid analogies - compare Earth's atmosphere to a blanket, greenhouse gases to invisible heat-trapping molecules, and climate feedback loops to a snowball rolling downhill. Include surprising facts and 'aha moments' that will make readers think differently about the world around them. Weave in human stories of scientists making discoveries, communities adapting to change, and innovative solutions being developed. Balance the serious implications with hope and actionable insights, concluding with empowering steps readers can take. Write with wonder and curiosity, making complex concepts feel approachable and personally relevant."
|
||||
|
||||
{% elif report_style == "news" %}
|
||||
# News Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Report on the current state and immediate impact of artificial intelligence across three critical sectors: healthcare, education, and business. Lead with the most newsworthy developments and recent breakthroughs that are affecting people today. Structure using inverted pyramid format: start with key findings and immediate implications, then provide essential background context, followed by detailed analysis and expert perspectives. Include specific, verifiable data points, recent statistics, and quotes from credible sources including industry leaders, researchers, and affected stakeholders. Address both benefits and concerns with balanced reporting, fact-check all claims, and provide proper attribution for all information. Focus on timeliness and relevance to current events, highlighting what's happening now and what readers need to know. Maintain journalistic objectivity while making the significance clear to a general news audience. Target 800-1200 words following AP style guidelines."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Provide comprehensive news coverage of climate change that explains the current scientific understanding and immediate implications for readers. Lead with the most recent and significant developments in climate science, policy, or impacts that are making headlines today. Structure the report with: breaking developments first, essential background for understanding the issue, current scientific consensus with specific data and timeframes, real-world impacts already being observed, policy responses and debates, and what experts say comes next. Include quotes from credible climate scientists, policy makers, and affected communities. Present information objectively while clearly communicating the scientific consensus, fact-check all claims, and provide proper source attribution. Address common misconceptions with factual corrections. Focus on what's happening now, why it matters to readers, and what they can expect in the near future. Follow journalistic standards for accuracy, balance, and timeliness."
|
||||
|
||||
{% elif report_style == "social_media" %}
|
||||
# Social Media Style Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Create engaging social media content about AI that will stop the scroll and spark conversations! Start with an attention-grabbing hook like 'You won't believe what AI just did in hospitals this week 🤯' and structure as a compelling thread or post series. Include surprising facts, relatable examples (like AI helping doctors spot diseases or personalizing your Netflix recommendations), and interactive elements that encourage sharing and comments. Use strategic hashtags (#AI #Technology #Future), incorporate relevant emojis for visual appeal, and include questions that prompt audience engagement ('Have you noticed AI in your daily life? Drop examples below! 👇'). Make complex concepts digestible with bite-sized explanations, trending analogies, and shareable quotes. Include a clear call-to-action and optimize for the specific platform (Twitter threads, Instagram carousel, LinkedIn professional insights, or TikTok-style quick facts). Aim for high shareability with content that feels both informative and entertaining."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Develop viral-worthy social media content that makes climate change accessible and shareable without being preachy. Open with a scroll-stopping hook like 'The weather app on your phone is telling a bigger story than you think 📱🌡️' and break down complex science into digestible, engaging chunks. Use relatable comparisons (Earth's fever, atmosphere as a blanket), trending formats (before/after visuals, myth-busting series, quick facts), and interactive elements (polls, questions, challenges). Include strategic hashtags (#ClimateChange #Science #Environment), eye-catching emojis, and shareable graphics or infographics. Address common questions and misconceptions with clear, factual responses. Create content that encourages positive action rather than climate anxiety, ending with empowering steps followers can take. Optimize for platform-specific features (Instagram Stories, TikTok trends, Twitter threads) and include calls-to-action that drive engagement and sharing."
|
||||
|
||||
{% else %}
|
||||
# General Examples
|
||||
|
||||
**Original**: "Write about AI"
|
||||
**Enhanced**: "Write a comprehensive 1000-word analysis of artificial intelligence's current applications in healthcare, education, and business. Include specific examples of AI tools being used in each sector, discuss both benefits and challenges, and provide insights into future trends. Structure the response with clear sections for each industry and conclude with key takeaways."
|
||||
|
||||
**Original**: "Explain climate change"
|
||||
**Enhanced**: "Provide a detailed explanation of climate change suitable for a general audience. Cover the scientific mechanisms behind global warming, major causes including greenhouse gas emissions, observable effects we're seeing today, and projected future impacts. Include specific data and examples, and explain the difference between weather and climate. Organize the response with clear headings and conclude with actionable steps individuals can take."
|
||||
{% endif %}
|
||||
@@ -20,11 +20,13 @@ from src.graph.builder import build_graph_with_memory
|
||||
from src.podcast.graph.builder import build_graph as build_podcast_graph
|
||||
from src.ppt.graph.builder import build_graph as build_ppt_graph
|
||||
from src.prose.graph.builder import build_graph as build_prose_graph
|
||||
from src.prompt_enhancer.graph.builder import build_graph as build_prompt_enhancer_graph
|
||||
from src.rag.builder import build_retriever
|
||||
from src.rag.retriever import Resource
|
||||
from src.server.chat_request import (
|
||||
ChatMessage,
|
||||
ChatRequest,
|
||||
EnhancePromptRequest,
|
||||
GeneratePodcastRequest,
|
||||
GeneratePPTRequest,
|
||||
GenerateProseRequest,
|
||||
@@ -300,6 +302,50 @@ async def generate_prose(request: GenerateProseRequest):
|
||||
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
||||
|
||||
|
||||
@app.post("/api/prompt/enhance")
|
||||
async def enhance_prompt(request: EnhancePromptRequest):
|
||||
try:
|
||||
sanitized_prompt = request.prompt.replace("\r\n", "").replace("\n", "")
|
||||
logger.info(f"Enhancing prompt: {sanitized_prompt}")
|
||||
|
||||
# Convert string report_style to ReportStyle enum
|
||||
report_style = None
|
||||
if request.report_style:
|
||||
try:
|
||||
# Handle both uppercase and lowercase input
|
||||
style_mapping = {
|
||||
"ACADEMIC": ReportStyle.ACADEMIC,
|
||||
"POPULAR_SCIENCE": ReportStyle.POPULAR_SCIENCE,
|
||||
"NEWS": ReportStyle.NEWS,
|
||||
"SOCIAL_MEDIA": ReportStyle.SOCIAL_MEDIA,
|
||||
"academic": ReportStyle.ACADEMIC,
|
||||
"popular_science": ReportStyle.POPULAR_SCIENCE,
|
||||
"news": ReportStyle.NEWS,
|
||||
"social_media": ReportStyle.SOCIAL_MEDIA,
|
||||
}
|
||||
report_style = style_mapping.get(
|
||||
request.report_style, ReportStyle.ACADEMIC
|
||||
)
|
||||
except Exception:
|
||||
# If invalid style, default to ACADEMIC
|
||||
report_style = ReportStyle.ACADEMIC
|
||||
else:
|
||||
report_style = ReportStyle.ACADEMIC
|
||||
|
||||
workflow = build_prompt_enhancer_graph()
|
||||
final_state = workflow.invoke(
|
||||
{
|
||||
"prompt": request.prompt,
|
||||
"context": request.context,
|
||||
"report_style": report_style,
|
||||
}
|
||||
)
|
||||
return {"result": final_state["output"]}
|
||||
except Exception as e:
|
||||
logger.exception(f"Error occurred during prompt enhancement: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail=INTERNAL_SERVER_ERROR_DETAIL)
|
||||
|
||||
|
||||
@app.post("/api/mcp/server/metadata", response_model=MCPServerMetadataResponse)
|
||||
async def mcp_server_metadata(request: MCPServerMetadataRequest):
|
||||
"""Get information about an MCP server."""
|
||||
|
||||
@@ -94,3 +94,13 @@ class GenerateProseRequest(BaseModel):
|
||||
command: Optional[str] = Field(
|
||||
"", description="The user custom command of the prose writer"
|
||||
)
|
||||
|
||||
|
||||
class EnhancePromptRequest(BaseModel):
    """Request payload for the ``/api/prompt/enhance`` endpoint."""

    prompt: str = Field(..., description="The original prompt to enhance")
    context: Optional[str] = Field(
        "", description="Additional context about the intended use"
    )
    report_style: Optional[str] = Field(
        "academic", description="The style of the report"
    )
|
||||
|
||||
2
tests/unit/prompt_enhancer/__init__.py
Normal file
2
tests/unit/prompt_enhancer/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
2
tests/unit/prompt_enhancer/graph/__init__.py
Normal file
2
tests/unit/prompt_enhancer/graph/__init__.py
Normal file
@@ -0,0 +1,2 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
156
tests/unit/prompt_enhancer/graph/test_builder.py
Normal file
156
tests/unit/prompt_enhancer/graph/test_builder.py
Normal file
@@ -0,0 +1,156 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from src.prompt_enhancer.graph.builder import build_graph
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
class TestBuildGraph:
    """Test cases for build_graph function."""

    @staticmethod
    def _wire_mock_builder(mock_state_graph):
        """Point a mocked StateGraph class at a builder/compiled-graph pair.

        Returns:
            Tuple ``(mock_builder, mock_compiled_graph)`` where calling the
            patched StateGraph yields ``mock_builder``, whose ``compile()``
            yields ``mock_compiled_graph``.
        """
        mock_builder = MagicMock()
        mock_compiled_graph = MagicMock()
        mock_state_graph.return_value = mock_builder
        mock_builder.compile.return_value = mock_compiled_graph
        return mock_builder, mock_compiled_graph

    @patch("src.prompt_enhancer.graph.builder.StateGraph")
    def test_build_graph_structure(self, mock_state_graph):
        """Test that build_graph creates the correct graph structure."""
        mock_builder, mock_compiled_graph = self._wire_mock_builder(mock_state_graph)

        result = build_graph()

        # StateGraph must be parameterized with the enhancer state schema.
        mock_state_graph.assert_called_once_with(PromptEnhancerState)
        mock_builder.set_entry_point.assert_called_once_with("enhancer")
        mock_builder.set_finish_point.assert_called_once_with("enhancer")
        mock_builder.compile.assert_called_once()
        assert result == mock_compiled_graph

    @patch("src.prompt_enhancer.graph.builder.StateGraph")
    @patch("src.prompt_enhancer.graph.builder.prompt_enhancer_node")
    def test_build_graph_node_function(self, mock_enhancer_node, mock_state_graph):
        """Test that the correct node function is added to the graph."""
        mock_builder, _ = self._wire_mock_builder(mock_state_graph)

        build_graph()

        mock_builder.add_node.assert_called_once_with("enhancer", mock_enhancer_node)

    def test_build_graph_returns_compiled_graph(self):
        """Test that build_graph returns a compiled graph object."""
        with patch("src.prompt_enhancer.graph.builder.StateGraph") as mock_state_graph:
            _, mock_compiled_graph = self._wire_mock_builder(mock_state_graph)

            result = build_graph()

            assert result is mock_compiled_graph

    @patch("src.prompt_enhancer.graph.builder.StateGraph")
    def test_build_graph_call_sequence(self, mock_state_graph):
        """Test that build_graph calls methods in the correct sequence."""
        mock_builder, _ = self._wire_mock_builder(mock_state_graph)

        build_graph()

        # mock.method_calls records child-mock invocations in order, so the
        # sequence can be asserted without hand-rolled side_effect trackers.
        observed = [name for name, _args, _kwargs in mock_builder.method_calls]
        expected = ["add_node", "set_entry_point", "set_finish_point", "compile"]
        assert observed == expected

    def test_build_graph_integration(self):
        """Integration test to verify the graph can be built without mocking."""
        # This test verifies that all imports and dependencies are correct.
        try:
            graph = build_graph()
            assert graph is not None
            # The graph should be a compiled LangGraph object.
            assert hasattr(graph, "invoke") or hasattr(graph, "stream")
        except ImportError as e:
            pytest.skip(f"Skipping integration test due to missing dependencies: {e}")
        except Exception as e:
            # Configuration issues (like missing LLM config) are not failures
            # of the graph structure itself, so skip rather than fail.
            if "LLM" in str(e) or "configuration" in str(e).lower():
                pytest.skip(
                    f"Skipping integration test due to configuration issues: {e}"
                )
            else:
                raise

    @patch("src.prompt_enhancer.graph.builder.StateGraph")
    def test_build_graph_single_node_workflow(self, mock_state_graph):
        """Test that the graph is configured as a single-node workflow."""
        mock_builder, _ = self._wire_mock_builder(mock_state_graph)

        build_graph()

        # Exactly one node is added, and it is both entry and finish point.
        assert mock_builder.add_node.call_count == 1
        mock_builder.set_entry_point.assert_called_once_with("enhancer")
        mock_builder.set_finish_point.assert_called_once_with("enhancer")

    @patch("src.prompt_enhancer.graph.builder.StateGraph")
    def test_build_graph_state_type(self, mock_state_graph):
        """Test that the graph is initialized with the correct state type."""
        self._wire_mock_builder(mock_state_graph)

        build_graph()

        # StateGraph was initialized with PromptEnhancerState.
        args, _kwargs = mock_state_graph.call_args
        assert args[0] == PromptEnhancerState
|
||||
219
tests/unit/prompt_enhancer/graph/test_enhancer_node.py
Normal file
219
tests/unit/prompt_enhancer/graph/test_enhancer_node.py
Normal file
@@ -0,0 +1,219 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from unittest.mock import patch, MagicMock
|
||||
from langchain.schema import HumanMessage, SystemMessage
|
||||
|
||||
from src.prompt_enhancer.graph.enhancer_node import prompt_enhancer_node
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
@pytest.fixture
def mock_llm():
    """Fixture: stand-in LLM whose invoke() yields a fixed test response."""
    fake_response = MagicMock(content="Enhanced test prompt")
    fake_llm = MagicMock()
    fake_llm.invoke.return_value = fake_response
    return fake_llm
|
||||
|
||||
|
||||
@pytest.fixture
def mock_messages():
    """Fixture: message list that apply_prompt_template is mocked to return."""
    system = SystemMessage(content="System prompt template")
    human = HumanMessage(content="Test human message")
    return [system, human]
|
||||
|
||||
|
||||
class TestPromptEnhancerNode:
    """Test cases for prompt_enhancer_node function.

    Every test patches the LLM factory (``get_llm_by_type``) and the template
    helper (``apply_prompt_template``) so the node runs without a real model,
    and patches ``AGENT_LLM_MAP`` so ``prompt_enhancer`` maps to the ``basic``
    LLM type.
    """

    @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
    @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
    @patch(
        "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
        {"prompt_enhancer": "basic"},
    )
    def test_basic_prompt_enhancement(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test basic prompt enhancement without context or report style."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(prompt="Write about AI")

        result = prompt_enhancer_node(state)

        # The node must resolve the "basic" LLM and invoke it with the
        # templated messages.
        mock_get_llm.assert_called_once_with("basic")
        mock_llm.invoke.assert_called_once_with(mock_messages)

        # The template must be rendered from the prompt_enhancer prompt and
        # receive both the messages and the report style in its variables.
        mock_apply_template.assert_called_once()
        call_args = mock_apply_template.call_args
        assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
        assert "messages" in call_args[0][1]
        assert "report_style" in call_args[0][1]

        assert result == {"output": "Enhanced test prompt"}

    @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
    @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
    @patch(
        "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
        {"prompt_enhancer": "basic"},
    )
    def test_prompt_enhancement_with_report_style(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test prompt enhancement with report style."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(
            prompt="Write about AI", report_style=ReportStyle.ACADEMIC
        )

        result = prompt_enhancer_node(state)

        # The chosen report style must be forwarded to the template renderer.
        mock_apply_template.assert_called_once()
        call_args = mock_apply_template.call_args
        assert call_args[0][0] == "prompt_enhancer/prompt_enhancer"
        assert call_args[0][1]["report_style"] == ReportStyle.ACADEMIC

        assert result == {"output": "Enhanced test prompt"}

    @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
    @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
    @patch(
        "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
        {"prompt_enhancer": "basic"},
    )
    def test_prompt_enhancement_with_context(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test prompt enhancement with additional context."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        state = PromptEnhancerState(
            prompt="Write about AI", context="Focus on machine learning applications"
        )

        result = prompt_enhancer_node(state)

        mock_apply_template.assert_called_once()
        call_args = mock_apply_template.call_args

        # The context must be folded into the single human message passed to
        # the template, not sent as a separate message.
        messages_arg = call_args[0][1]["messages"]
        assert len(messages_arg) == 1
        human_message = messages_arg[0]
        assert isinstance(human_message, HumanMessage)
        assert "Focus on machine learning applications" in human_message.content

        assert result == {"output": "Enhanced test prompt"}

    @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
    @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
    @patch(
        "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
        {"prompt_enhancer": "basic"},
    )
    def test_error_handling(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test error handling when LLM call fails."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Simulate an LLM failure during invocation.
        mock_llm.invoke.side_effect = Exception("LLM error")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        # On failure the node must degrade gracefully to the original prompt.
        assert result == {"output": "Test prompt"}

    @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
    @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
    @patch(
        "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
        {"prompt_enhancer": "basic"},
    )
    def test_template_error_handling(self, mock_get_llm, mock_apply_template, mock_llm):
        """Test error handling when template application fails."""
        # NOTE: the mock_messages fixture is intentionally not requested here —
        # the template call fails before any messages would be used.
        mock_get_llm.return_value = mock_llm

        # Simulate a failure while rendering the prompt template.
        mock_apply_template.side_effect = Exception("Template error")

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        # On failure the node must degrade gracefully to the original prompt.
        assert result == {"output": "Test prompt"}

    @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
    @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
    @patch(
        "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
        {"prompt_enhancer": "basic"},
    )
    def test_prefix_removal(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test that common prefixes are removed from LLM response."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # Each conversational prefix the node is expected to strip.
        prefixed_responses = [
            "Enhanced Prompt: This is the enhanced prompt",
            "Enhanced prompt: This is the enhanced prompt",
            "Here's the enhanced prompt: This is the enhanced prompt",
            "Here is the enhanced prompt: This is the enhanced prompt",
            "**Enhanced Prompt**: This is the enhanced prompt",
            "**Enhanced prompt**: This is the enhanced prompt",
        ]

        for response_with_prefix in prefixed_responses:
            mock_llm.invoke.return_value = MagicMock(content=response_with_prefix)

            state = PromptEnhancerState(prompt="Test prompt")
            result = prompt_enhancer_node(state)

            # Include the offending input so a failure identifies its case.
            assert result == {"output": "This is the enhanced prompt"}, (
                f"prefix not stripped from: {response_with_prefix!r}"
            )

    @patch("src.prompt_enhancer.graph.enhancer_node.apply_prompt_template")
    @patch("src.prompt_enhancer.graph.enhancer_node.get_llm_by_type")
    @patch(
        "src.prompt_enhancer.graph.enhancer_node.AGENT_LLM_MAP",
        {"prompt_enhancer": "basic"},
    )
    def test_whitespace_handling(
        self, mock_get_llm, mock_apply_template, mock_llm, mock_messages
    ):
        """Test that whitespace is properly stripped from LLM response."""
        mock_get_llm.return_value = mock_llm
        mock_apply_template.return_value = mock_messages

        # LLM response padded with leading/trailing whitespace and newlines.
        mock_llm.invoke.return_value = MagicMock(
            content="  \n\n  Enhanced prompt  \n\n  "
        )

        state = PromptEnhancerState(prompt="Test prompt")
        result = prompt_enhancer_node(state)

        assert result == {"output": "Enhanced prompt"}
||||
108
tests/unit/prompt_enhancer/graph/test_state.py
Normal file
108
tests/unit/prompt_enhancer/graph/test_state.py
Normal file
@@ -0,0 +1,108 @@
|
||||
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
# SPDX-License-Identifier: MIT
|
||||
|
||||
import pytest
|
||||
from src.prompt_enhancer.graph.state import PromptEnhancerState
|
||||
from src.config.report_style import ReportStyle
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_creation():
    """PromptEnhancerState stores explicit None values for optional fields."""
    state = PromptEnhancerState(
        prompt="Test prompt", context=None, report_style=None, output=None
    )

    assert state["prompt"] == "Test prompt"
    # Each optional field was supplied as None and must be stored as such.
    for optional_key in ("context", "report_style", "output"):
        assert state[optional_key] is None
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_all_fields():
    """Every field round-trips unchanged when all are populated."""
    fields = {
        "prompt": "Write about AI",
        "context": "Additional context about AI research",
        "report_style": ReportStyle.ACADEMIC,
        "output": "Enhanced prompt about AI research",
    }
    state = PromptEnhancerState(**fields)

    for key, expected in fields.items():
        assert state[key] == expected
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_minimal():
    """Only the required prompt key is present when the others are omitted."""
    state = PromptEnhancerState(prompt="Minimal prompt")

    assert state["prompt"] == "Minimal prompt"
    # Omitted optional keys must be absent entirely (not stored as None).
    for optional_key in ("context", "report_style", "output"):
        assert optional_key not in state
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_with_different_report_styles():
    """Each ReportStyle variant is stored unchanged in the state."""
    for style in (
        ReportStyle.ACADEMIC,
        ReportStyle.POPULAR_SCIENCE,
        ReportStyle.NEWS,
        ReportStyle.SOCIAL_MEDIA,
    ):
        state = PromptEnhancerState(prompt="Test prompt", report_style=style)
        assert state["report_style"] == style
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_update():
    """dict.update() adds new keys without disturbing existing ones."""
    state = PromptEnhancerState(prompt="Original prompt")

    new_fields = {
        "context": "New context",
        "report_style": ReportStyle.NEWS,
        "output": "Enhanced output",
    }
    state.update(new_fields)

    # The pre-existing key is untouched; every updated key is now present.
    assert state["prompt"] == "Original prompt"
    for key, expected in new_fields.items():
        assert state[key] == expected
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_get_method():
    """get() handles present keys, absent keys, and explicit defaults."""
    state = PromptEnhancerState(prompt="Test prompt", report_style=ReportStyle.ACADEMIC)

    # Present keys return their stored values.
    assert state.get("prompt") == "Test prompt"
    assert state.get("report_style") == ReportStyle.ACADEMIC

    # Absent keys fall back to None, or to the supplied default.
    assert state.get("context") is None
    assert state.get("output") is None
    assert state.get("nonexistent", "default") == "default"
|
||||
|
||||
|
||||
def test_prompt_enhancer_state_type_annotations():
    """Stored values carry the runtime types the TypedDict declares."""
    state = PromptEnhancerState(
        prompt="Test prompt",
        context="Test context",
        report_style=ReportStyle.POPULAR_SCIENCE,
        output="Test output",
    )

    expected_types = {
        "prompt": str,
        "context": str,
        "report_style": ReportStyle,
        "output": str,
    }
    for key, expected_type in expected_types.items():
        assert isinstance(state[key], expected_type)
||||
@@ -1,9 +1,10 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import { MagicWandIcon } from "@radix-ui/react-icons";
|
||||
import { AnimatePresence, motion } from "framer-motion";
|
||||
import { ArrowUp, X } from "lucide-react";
|
||||
import { useCallback, useRef } from "react";
|
||||
import { useCallback, useRef, useState } from "react";
|
||||
|
||||
import { Detective } from "~/components/deer-flow/icons/detective";
|
||||
import MessageInput, {
|
||||
@@ -11,7 +12,9 @@ import MessageInput, {
|
||||
} from "~/components/deer-flow/message-input";
|
||||
import { ReportStyleDialog } from "~/components/deer-flow/report-style-dialog";
|
||||
import { Tooltip } from "~/components/deer-flow/tooltip";
|
||||
import { BorderBeam } from "~/components/magicui/border-beam";
|
||||
import { Button } from "~/components/ui/button";
|
||||
import { enhancePrompt } from "~/core/api";
|
||||
import type { Option, Resource } from "~/core/messages";
|
||||
import {
|
||||
setEnableBackgroundInvestigation,
|
||||
@@ -44,10 +47,16 @@ export function InputBox({
|
||||
const backgroundInvestigation = useSettingsStore(
|
||||
(state) => state.general.enableBackgroundInvestigation,
|
||||
);
|
||||
const reportStyle = useSettingsStore((state) => state.general.reportStyle);
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const inputRef = useRef<MessageInputRef>(null);
|
||||
const feedbackRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Enhancement state
|
||||
const [isEnhancing, setIsEnhancing] = useState(false);
|
||||
const [isEnhanceAnimating, setIsEnhanceAnimating] = useState(false);
|
||||
const [currentPrompt, setCurrentPrompt] = useState("");
|
||||
|
||||
const handleSendMessage = useCallback(
|
||||
(message: string, resources: Array<Resource>) => {
|
||||
if (responding) {
|
||||
@@ -62,12 +71,50 @@ export function InputBox({
|
||||
resources,
|
||||
});
|
||||
onRemoveFeedback?.();
|
||||
// Clear enhancement animation after sending
|
||||
setIsEnhanceAnimating(false);
|
||||
}
|
||||
}
|
||||
},
|
||||
[responding, onCancel, onSend, feedback, onRemoveFeedback],
|
||||
);
|
||||
|
||||
const handleEnhancePrompt = useCallback(async () => {
|
||||
if (currentPrompt.trim() === "" || isEnhancing) {
|
||||
return;
|
||||
}
|
||||
|
||||
setIsEnhancing(true);
|
||||
setIsEnhanceAnimating(true);
|
||||
|
||||
try {
|
||||
const enhancedPrompt = await enhancePrompt({
|
||||
prompt: currentPrompt,
|
||||
report_style: reportStyle.toUpperCase(),
|
||||
});
|
||||
|
||||
// Add a small delay for better UX
|
||||
await new Promise((resolve) => setTimeout(resolve, 500));
|
||||
|
||||
// Update the input with the enhanced prompt with animation
|
||||
if (inputRef.current) {
|
||||
inputRef.current.setContent(enhancedPrompt);
|
||||
setCurrentPrompt(enhancedPrompt);
|
||||
}
|
||||
|
||||
// Keep animation for a bit longer to show the effect
|
||||
setTimeout(() => {
|
||||
setIsEnhanceAnimating(false);
|
||||
}, 1000);
|
||||
} catch (error) {
|
||||
console.error("Failed to enhance prompt:", error);
|
||||
setIsEnhanceAnimating(false);
|
||||
// Could add toast notification here
|
||||
} finally {
|
||||
setIsEnhancing(false);
|
||||
}
|
||||
}, [currentPrompt, isEnhancing, reportStyle]);
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
@@ -97,11 +144,61 @@ export function InputBox({
|
||||
/>
|
||||
</motion.div>
|
||||
)}
|
||||
{isEnhanceAnimating && (
|
||||
<motion.div
|
||||
className="pointer-events-none absolute inset-0 z-20"
|
||||
initial={{ opacity: 0 }}
|
||||
animate={{ opacity: 1 }}
|
||||
exit={{ opacity: 0 }}
|
||||
transition={{ duration: 0.3 }}
|
||||
>
|
||||
<div className="relative h-full w-full">
|
||||
{/* Sparkle effect overlay */}
|
||||
<motion.div
|
||||
className="absolute inset-0 rounded-[24px] bg-gradient-to-r from-blue-500/10 via-purple-500/10 to-blue-500/10"
|
||||
animate={{
|
||||
background: [
|
||||
"linear-gradient(45deg, rgba(59, 130, 246, 0.1), rgba(147, 51, 234, 0.1), rgba(59, 130, 246, 0.1))",
|
||||
"linear-gradient(225deg, rgba(147, 51, 234, 0.1), rgba(59, 130, 246, 0.1), rgba(147, 51, 234, 0.1))",
|
||||
"linear-gradient(45deg, rgba(59, 130, 246, 0.1), rgba(147, 51, 234, 0.1), rgba(59, 130, 246, 0.1))",
|
||||
],
|
||||
}}
|
||||
transition={{ duration: 2, repeat: Infinity }}
|
||||
/>
|
||||
{/* Floating sparkles */}
|
||||
{[...Array(6)].map((_, i) => (
|
||||
<motion.div
|
||||
key={i}
|
||||
className="absolute h-2 w-2 rounded-full bg-blue-400"
|
||||
style={{
|
||||
left: `${20 + i * 12}%`,
|
||||
top: `${30 + (i % 2) * 40}%`,
|
||||
}}
|
||||
animate={{
|
||||
y: [-10, -20, -10],
|
||||
opacity: [0, 1, 0],
|
||||
scale: [0.5, 1, 0.5],
|
||||
}}
|
||||
transition={{
|
||||
duration: 1.5,
|
||||
repeat: Infinity,
|
||||
delay: i * 0.2,
|
||||
}}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
</motion.div>
|
||||
)}
|
||||
</AnimatePresence>
|
||||
<MessageInput
|
||||
className={cn("h-24 px-4 pt-5", feedback && "pt-9")}
|
||||
className={cn(
|
||||
"h-24 px-4 pt-5",
|
||||
feedback && "pt-9",
|
||||
isEnhanceAnimating && "transition-all duration-500",
|
||||
)}
|
||||
ref={inputRef}
|
||||
onEnter={handleSendMessage}
|
||||
onChange={setCurrentPrompt}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center px-4 py-2">
|
||||
@@ -137,6 +234,26 @@ export function InputBox({
|
||||
<ReportStyleDialog />
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center gap-2">
|
||||
<Tooltip title="Enhance prompt with AI">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="icon"
|
||||
className={cn(
|
||||
"hover:bg-accent h-10 w-10",
|
||||
isEnhancing && "animate-pulse",
|
||||
)}
|
||||
onClick={handleEnhancePrompt}
|
||||
disabled={isEnhancing || currentPrompt.trim() === ""}
|
||||
>
|
||||
{isEnhancing ? (
|
||||
<div className="flex h-10 w-10 items-center justify-center">
|
||||
<div className="bg-foreground h-3 w-3 animate-bounce rounded-full opacity-70" />
|
||||
</div>
|
||||
) : (
|
||||
<MagicWandIcon className="text-brand" />
|
||||
)}
|
||||
</Button>
|
||||
</Tooltip>
|
||||
<Tooltip title={responding ? "Stop" : "Send"}>
|
||||
<Button
|
||||
variant="outline"
|
||||
@@ -155,6 +272,21 @@ export function InputBox({
|
||||
</Tooltip>
|
||||
</div>
|
||||
</div>
|
||||
{isEnhancing && (
|
||||
<>
|
||||
<BorderBeam
|
||||
duration={5}
|
||||
size={250}
|
||||
className="from-transparent via-red-500 to-transparent"
|
||||
/>
|
||||
<BorderBeam
|
||||
duration={5}
|
||||
delay={3}
|
||||
size={250}
|
||||
className="from-transparent via-blue-500 to-transparent"
|
||||
/>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
42
web/src/components/deer-flow/icons/enhance.tsx
Normal file
42
web/src/components/deer-flow/icons/enhance.tsx
Normal file
@@ -0,0 +1,42 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import type { SVGProps } from "react";
|
||||
|
||||
export function Enhance(props: SVGProps<SVGSVGElement>) {
|
||||
return (
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<path
|
||||
d="M12 2L13.09 8.26L20 9L13.09 9.74L12 16L10.91 9.74L4 9L10.91 8.26L12 2Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="2"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
fill="none"
|
||||
/>
|
||||
<path
|
||||
d="M19 14L19.5 16.5L22 17L19.5 17.5L19 20L18.5 17.5L16 17L18.5 16.5L19 14Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
fill="none"
|
||||
/>
|
||||
<path
|
||||
d="M5 6L5.5 7.5L7 8L5.5 8.5L5 10L4.5 8.5L3 8L4.5 7.5L5 6Z"
|
||||
stroke="currentColor"
|
||||
strokeWidth="1.5"
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
fill="none"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
}
|
||||
@@ -26,6 +26,7 @@ import { LoadingOutlined } from "@ant-design/icons";
|
||||
export interface MessageInputRef {
|
||||
focus: () => void;
|
||||
submit: () => void;
|
||||
setContent: (content: string) => void;
|
||||
}
|
||||
|
||||
export interface MessageInputProps {
|
||||
@@ -82,8 +83,9 @@ const MessageInput = forwardRef<MessageInputRef, MessageInputProps>(
|
||||
const debouncedUpdates = useDebouncedCallback(
|
||||
async (editor: EditorInstance) => {
|
||||
if (onChange) {
|
||||
const markdown = editor.storage.markdown.getMarkdown();
|
||||
onChange(markdown);
|
||||
// Get the plain text content for prompt enhancement
|
||||
const { text } = formatMessage(editor.getJSON() ?? []);
|
||||
onChange(text);
|
||||
}
|
||||
},
|
||||
200,
|
||||
@@ -101,6 +103,11 @@ const MessageInput = forwardRef<MessageInputRef, MessageInputProps>(
|
||||
onEnter(text, resources);
|
||||
}
|
||||
},
|
||||
setContent: (content: string) => {
|
||||
if (editorRef.current) {
|
||||
editorRef.current.commands.setContent(content);
|
||||
}
|
||||
},
|
||||
}));
|
||||
|
||||
useEffect(() => {
|
||||
|
||||
94
web/src/components/magicui/border-beam.tsx
Normal file
94
web/src/components/magicui/border-beam.tsx
Normal file
@@ -0,0 +1,94 @@
|
||||
"use client";
|
||||
|
||||
import { cn } from "~/lib/utils";
|
||||
import { motion, type MotionStyle, type Transition } from "motion/react";
|
||||
|
||||
interface BorderBeamProps {
|
||||
/**
|
||||
* The size of the border beam.
|
||||
*/
|
||||
size?: number;
|
||||
/**
|
||||
* The duration of the border beam.
|
||||
*/
|
||||
duration?: number;
|
||||
/**
|
||||
* The delay of the border beam.
|
||||
*/
|
||||
delay?: number;
|
||||
/**
|
||||
* The color of the border beam from.
|
||||
*/
|
||||
colorFrom?: string;
|
||||
/**
|
||||
* The color of the border beam to.
|
||||
*/
|
||||
colorTo?: string;
|
||||
/**
|
||||
* The motion transition of the border beam.
|
||||
*/
|
||||
transition?: Transition;
|
||||
/**
|
||||
* The class name of the border beam.
|
||||
*/
|
||||
className?: string;
|
||||
/**
|
||||
* The style of the border beam.
|
||||
*/
|
||||
style?: React.CSSProperties;
|
||||
/**
|
||||
* Whether to reverse the animation direction.
|
||||
*/
|
||||
reverse?: boolean;
|
||||
/**
|
||||
* The initial offset position (0-100).
|
||||
*/
|
||||
initialOffset?: number;
|
||||
}
|
||||
|
||||
export const BorderBeam = ({
|
||||
className,
|
||||
size = 50,
|
||||
delay = 0,
|
||||
duration = 6,
|
||||
colorFrom = "#ffaa40",
|
||||
colorTo = "#9c40ff",
|
||||
transition,
|
||||
style,
|
||||
reverse = false,
|
||||
initialOffset = 0,
|
||||
}: BorderBeamProps) => {
|
||||
return (
|
||||
<div className="pointer-events-none absolute inset-0 rounded-[inherit] border border-transparent [mask-image:linear-gradient(transparent,transparent),linear-gradient(#000,#000)] [mask-composite:intersect] [mask-clip:padding-box,border-box]">
|
||||
<motion.div
|
||||
className={cn(
|
||||
"absolute aspect-square",
|
||||
"bg-gradient-to-l from-[var(--color-from)] via-[var(--color-to)] to-transparent",
|
||||
className,
|
||||
)}
|
||||
style={
|
||||
{
|
||||
width: size,
|
||||
offsetPath: `rect(0 auto auto 0 round ${size}px)`,
|
||||
"--color-from": colorFrom,
|
||||
"--color-to": colorTo,
|
||||
...style,
|
||||
} as MotionStyle
|
||||
}
|
||||
initial={{ offsetDistance: `${initialOffset}%` }}
|
||||
animate={{
|
||||
offsetDistance: reverse
|
||||
? [`${100 - initialOffset}%`, `${-initialOffset}%`]
|
||||
: [`${initialOffset}%`, `${100 + initialOffset}%`],
|
||||
}}
|
||||
transition={{
|
||||
repeat: Infinity,
|
||||
ease: "linear",
|
||||
duration,
|
||||
delay: -delay,
|
||||
...transition,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -4,4 +4,5 @@
|
||||
export * from "./chat";
|
||||
export * from "./mcp";
|
||||
export * from "./podcast";
|
||||
export * from "./prompt-enhancer";
|
||||
export * from "./types";
|
||||
|
||||
62
web/src/core/api/prompt-enhancer.ts
Normal file
62
web/src/core/api/prompt-enhancer.ts
Normal file
@@ -0,0 +1,62 @@
|
||||
// Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
|
||||
// SPDX-License-Identifier: MIT
|
||||
|
||||
import { resolveServiceURL } from "./resolve-service-url";
|
||||
|
||||
export interface EnhancePromptRequest {
|
||||
prompt: string;
|
||||
context?: string;
|
||||
report_style?: string;
|
||||
}
|
||||
|
||||
export interface EnhancePromptResponse {
|
||||
enhanced_prompt: string;
|
||||
}
|
||||
|
||||
export async function enhancePrompt(
|
||||
request: EnhancePromptRequest,
|
||||
): Promise<string> {
|
||||
const response = await fetch(resolveServiceURL("prompt/enhance"), {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(request),
|
||||
});
|
||||
|
||||
if (!response.ok) {
|
||||
throw new Error(`HTTP error! status: ${response.status}`);
|
||||
}
|
||||
|
||||
const data = await response.json();
|
||||
console.log("Raw API response:", data); // Debug log
|
||||
|
||||
// The backend now returns the enhanced prompt directly in the result field
|
||||
let enhancedPrompt = data.result;
|
||||
|
||||
// If the result is somehow still a JSON object, extract the enhanced_prompt
|
||||
if (typeof enhancedPrompt === "object" && enhancedPrompt.enhanced_prompt) {
|
||||
enhancedPrompt = enhancedPrompt.enhanced_prompt;
|
||||
}
|
||||
|
||||
// If the result is a JSON string, try to parse it
|
||||
if (typeof enhancedPrompt === "string") {
|
||||
try {
|
||||
const parsed = JSON.parse(enhancedPrompt);
|
||||
if (parsed.enhanced_prompt) {
|
||||
enhancedPrompt = parsed.enhanced_prompt;
|
||||
}
|
||||
} catch {
|
||||
// If parsing fails, use the string as-is (which is what we want)
|
||||
console.log("Using enhanced prompt as-is:", enhancedPrompt);
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to original prompt if something went wrong
|
||||
if (!enhancedPrompt || enhancedPrompt.trim() === "") {
|
||||
console.warn("No enhanced prompt received, using original");
|
||||
enhancedPrompt = request.prompt;
|
||||
}
|
||||
|
||||
return enhancedPrompt;
|
||||
}
|
||||
Reference in New Issue
Block a user