mirror of
https://gitee.com/wanwujie/deer-flow
synced 2026-04-19 04:14:46 +08:00
feat(agent):Supports custom agent and chat experience with refactoring (#957)
* feat: add agent management functionality with creation, editing, and deletion * feat: enhance agent creation and chat experience - Added AgentWelcome component to display agent description on new thread creation. - Improved agent name validation with availability check during agent creation. - Updated NewAgentPage to handle agent creation flow more effectively, including enhanced error handling and user feedback. - Refactored chat components to streamline message handling and improve user experience. - Introduced new bootstrap skill for personalized onboarding conversations, including detailed conversation phases and a structured SOUL.md template. - Updated localization files to reflect new features and error messages. - General code cleanup and optimizations across various components and hooks. * Refactor workspace layout and agent management components - Updated WorkspaceLayout to use useLayoutEffect for sidebar state initialization. - Removed unused AgentFormDialog and related edit functionality from AgentCard. - Introduced ArtifactTrigger component to manage artifact visibility. - Enhanced ChatBox to handle artifact selection and display. - Improved message list rendering logic to avoid loading states. - Updated localization files to remove deprecated keys and add new translations. - Refined hooks for local settings and thread management to improve performance and clarity. - Added temporal awareness guidelines to deep research skill documentation. * feat: refactor chat components and introduce thread management hooks * feat: improve artifact file detail preview logic and clean up console logs * feat: refactor lead agent creation logic and improve logging details * feat: validate agent name format and enhance error handling in agent setup * feat: simplify thread search query by removing unnecessary metadata * feat: update query key in useDeleteThread and useRenameThread for consistency * feat: add isMock parameter to thread and artifact handling for improved testing * fix: reorder import of setup_agent for consistency in builtins module * feat: append mock parameter to thread links in CaseStudySection for testing purposes * fix: update load_agent_soul calls to use cfg.name for improved clarity * fix: update date format in apply_prompt_template for consistency * feat: integrate isMock parameter into artifact content loading for enhanced testing * docs: add license section to SKILL.md for clarity and attribution * feat(agent): enhance model resolution and agent configuration handling * chore: remove unused import of _resolve_model_name from agents * feat(agent): remove unused field * fix(agent): set default value for requested_model_name in _resolve_model_name function * feat(agent): update get_available_tools call to handle optional agent_config and improve middleware function signature --------- Co-authored-by: Willem Jiang <willem.jiang@gmail.com>
This commit is contained in:
@@ -14,6 +14,7 @@ from src.agents.middlewares.title_middleware import TitleMiddleware
|
||||
from src.agents.middlewares.uploads_middleware import UploadsMiddleware
|
||||
from src.agents.middlewares.view_image_middleware import ViewImageMiddleware
|
||||
from src.agents.thread_state import ThreadState
|
||||
from src.config.agents_config import load_agent_config
|
||||
from src.config.app_config import get_app_config
|
||||
from src.config.summarization_config import get_summarization_config
|
||||
from src.models import create_chat_model
|
||||
@@ -22,14 +23,12 @@ from src.sandbox.middleware import SandboxMiddleware
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _resolve_model_name(requested_model_name: str | None) -> str:
|
||||
def _resolve_model_name(requested_model_name: str | None = None) -> str:
|
||||
"""Resolve a runtime model name safely, falling back to default if invalid. Returns None if no models are configured."""
|
||||
app_config = get_app_config()
|
||||
default_model_name = app_config.models[0].name if app_config.models else None
|
||||
if default_model_name is None:
|
||||
raise ValueError(
|
||||
"No chat models are configured. Please configure at least one model in config.yaml."
|
||||
)
|
||||
raise ValueError("No chat models are configured. Please configure at least one model in config.yaml.")
|
||||
|
||||
if requested_model_name and app_config.get_model_config(requested_model_name):
|
||||
return requested_model_name
|
||||
@@ -205,11 +204,12 @@ Being proactive with task management demonstrates thoroughness and ensures all r
|
||||
# MemoryMiddleware queues conversation for memory update (after TitleMiddleware)
|
||||
# ViewImageMiddleware should be before ClarificationMiddleware to inject image details before LLM
|
||||
# ClarificationMiddleware should be last to intercept clarification requests after model calls
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None):
|
||||
def _build_middlewares(config: RunnableConfig, model_name: str | None, agent_name: str | None = None):
|
||||
"""Build middleware chain based on runtime configuration.
|
||||
|
||||
Args:
|
||||
config: Runtime configuration containing configurable options like is_plan_mode.
|
||||
agent_name: If provided, MemoryMiddleware will use per-agent memory storage.
|
||||
|
||||
Returns:
|
||||
List of middleware instances.
|
||||
@@ -231,7 +231,7 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None):
|
||||
middlewares.append(TitleMiddleware())
|
||||
|
||||
# Add MemoryMiddleware (after TitleMiddleware)
|
||||
middlewares.append(MemoryMiddleware())
|
||||
middlewares.append(MemoryMiddleware(agent_name=agent_name))
|
||||
|
||||
# Add ViewImageMiddleware only if the current model supports vision.
|
||||
# Use the resolved runtime model_name from make_lead_agent to avoid stale config values.
|
||||
@@ -254,28 +254,36 @@ def _build_middlewares(config: RunnableConfig, model_name: str | None):
|
||||
def make_lead_agent(config: RunnableConfig):
|
||||
# Lazy import to avoid circular dependency
|
||||
from src.tools import get_available_tools
|
||||
from src.tools.builtins import setup_agent
|
||||
|
||||
thinking_enabled = config.get("configurable", {}).get("thinking_enabled", True)
|
||||
reasoning_effort = config.get("configurable", {}).get("reasoning_effort", None)
|
||||
requested_model_name = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
|
||||
model_name = _resolve_model_name(requested_model_name)
|
||||
if model_name is None:
|
||||
raise ValueError(
|
||||
"No chat model could be resolved. Please configure at least one model in "
|
||||
"config.yaml or provide a valid 'model_name'/'model' in the request."
|
||||
)
|
||||
requested_model_name: str | None = config.get("configurable", {}).get("model_name") or config.get("configurable", {}).get("model")
|
||||
is_plan_mode = config.get("configurable", {}).get("is_plan_mode", False)
|
||||
subagent_enabled = config.get("configurable", {}).get("subagent_enabled", False)
|
||||
max_concurrent_subagents = config.get("configurable", {}).get("max_concurrent_subagents", 3)
|
||||
is_bootstrap = config.get("configurable", {}).get("is_bootstrap", False)
|
||||
agent_name = config.get("configurable", {}).get("agent_name")
|
||||
|
||||
agent_config = load_agent_config(agent_name)
|
||||
# Custom agent model or fallback to global/default model resolution
|
||||
agent_model_name = agent_config.model if agent_config and agent_config.model else _resolve_model_name()
|
||||
|
||||
# Final model name resolution with request override, then agent config, then global default
|
||||
model_name = requested_model_name or agent_model_name
|
||||
|
||||
app_config = get_app_config()
|
||||
model_config = app_config.get_model_config(model_name) if model_name else None
|
||||
if thinking_enabled and model_config is not None and not model_config.supports_thinking:
|
||||
|
||||
if model_config is None:
|
||||
raise ValueError("No chat model could be resolved. Please configure at least one model in config.yaml or provide a valid 'model_name'/'model' in the request.")
|
||||
if thinking_enabled and not model_config.supports_thinking:
|
||||
logger.warning(f"Thinking mode is enabled but model '{model_name}' does not support it; fallback to non-thinking mode.")
|
||||
thinking_enabled = False
|
||||
|
||||
logger.info(
|
||||
"thinking_enabled: %s, reasoning_effort: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
|
||||
"Create Agent(%s) -> thinking_enabled: %s, reasoning_effort: %s, model_name: %s, is_plan_mode: %s, subagent_enabled: %s, max_concurrent_subagents: %s",
|
||||
agent_name or "default",
|
||||
thinking_enabled,
|
||||
reasoning_effort,
|
||||
model_name,
|
||||
@@ -287,8 +295,10 @@ def make_lead_agent(config: RunnableConfig):
|
||||
# Inject run metadata for LangSmith trace tagging
|
||||
if "metadata" not in config:
|
||||
config["metadata"] = {}
|
||||
|
||||
config["metadata"].update(
|
||||
{
|
||||
"agent_name": agent_name or "default",
|
||||
"model_name": model_name or "default",
|
||||
"thinking_enabled": thinking_enabled,
|
||||
"reasoning_effort": reasoning_effort,
|
||||
@@ -297,10 +307,23 @@ def make_lead_agent(config: RunnableConfig):
|
||||
}
|
||||
)
|
||||
|
||||
if is_bootstrap:
|
||||
# Special bootstrap agent with minimal prompt for initial custom agent creation flow
|
||||
system_prompt = apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, available_skills=set(["bootstrap"]))
|
||||
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled) + [setup_agent],
|
||||
middleware=_build_middlewares(config, model_name=model_name),
|
||||
system_prompt=system_prompt,
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
# Default lead agent (unchanged behavior)
|
||||
return create_agent(
|
||||
model=create_chat_model(name=model_name, thinking_enabled=thinking_enabled, reasoning_effort=reasoning_effort),
|
||||
tools=get_available_tools(model_name=model_name, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents),
|
||||
tools=get_available_tools(model_name=model_name, groups=agent_config.tool_groups if agent_config else None, subagent_enabled=subagent_enabled),
|
||||
middleware=_build_middlewares(config, model_name=model_name, agent_name=agent_name),
|
||||
system_prompt=apply_prompt_template(subagent_enabled=subagent_enabled, max_concurrent_subagents=max_concurrent_subagents, agent_name=agent_name),
|
||||
state_schema=ThreadState,
|
||||
)
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
from src.config.agents_config import load_agent_soul
|
||||
from src.skills import load_skills
|
||||
|
||||
|
||||
@@ -148,9 +149,10 @@ bash("npm test") # Direct execution, not task()
|
||||
|
||||
SYSTEM_PROMPT_TEMPLATE = """
|
||||
<role>
|
||||
You are DeerFlow 2.0, an open-source super agent.
|
||||
You are {agent_name}, an open-source super agent.
|
||||
</role>
|
||||
|
||||
{soul}
|
||||
{memory_context}
|
||||
|
||||
<thinking_style>
|
||||
@@ -280,9 +282,12 @@ Recent breakthroughs in language models have also accelerated progress
|
||||
"""
|
||||
|
||||
|
||||
def _get_memory_context() -> str:
|
||||
def _get_memory_context(agent_name: str | None = None) -> str:
|
||||
"""Get memory context for injection into system prompt.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
||||
|
||||
Returns:
|
||||
Formatted memory context string wrapped in XML tags, or empty string if disabled.
|
||||
"""
|
||||
@@ -294,7 +299,7 @@ def _get_memory_context() -> str:
|
||||
if not config.enabled or not config.injection_enabled:
|
||||
return ""
|
||||
|
||||
memory_data = get_memory_data()
|
||||
memory_data = get_memory_data(agent_name)
|
||||
memory_content = format_memory_for_injection(memory_data, max_tokens=config.max_injection_tokens)
|
||||
|
||||
if not memory_content.strip():
|
||||
@@ -309,7 +314,7 @@ def _get_memory_context() -> str:
|
||||
return ""
|
||||
|
||||
|
||||
def get_skills_prompt_section() -> str:
|
||||
def get_skills_prompt_section(available_skills: set[str] | None = None) -> str:
|
||||
"""Generate the skills prompt section with available skills list.
|
||||
|
||||
Returns the <skill_system>...</skill_system> block listing all enabled skills,
|
||||
@@ -328,6 +333,9 @@ def get_skills_prompt_section() -> str:
|
||||
if not skills:
|
||||
return ""
|
||||
|
||||
if available_skills is not None:
|
||||
skills = [skill for skill in skills if skill.name in available_skills]
|
||||
|
||||
skill_items = "\n".join(
|
||||
f" <skill>\n <name>{skill.name}</name>\n <description>{skill.description}</description>\n <location>{skill.get_container_file_path(container_base_path)}</location>\n </skill>" for skill in skills
|
||||
)
|
||||
@@ -350,9 +358,17 @@ You have access to skills that provide optimized workflows for specific tasks. E
|
||||
</skill_system>"""
|
||||
|
||||
|
||||
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3) -> str:
|
||||
def get_agent_soul(agent_name: str | None) -> str:
|
||||
# Append SOUL.md (agent personality) if present
|
||||
soul = load_agent_soul(agent_name)
|
||||
if soul:
|
||||
return f"<soul>\n{soul}\n</soul>\n" if soul else ""
|
||||
return ""
|
||||
|
||||
|
||||
def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagents: int = 3, *, agent_name: str | None = None, available_skills: set[str] | None = None) -> str:
|
||||
# Get memory context
|
||||
memory_context = _get_memory_context()
|
||||
memory_context = _get_memory_context(agent_name)
|
||||
|
||||
# Include subagent section only if enabled (from runtime parameter)
|
||||
n = max_concurrent_subagents
|
||||
@@ -377,10 +393,12 @@ def apply_prompt_template(subagent_enabled: bool = False, max_concurrent_subagen
|
||||
)
|
||||
|
||||
# Get skills section
|
||||
skills_section = get_skills_prompt_section()
|
||||
skills_section = get_skills_prompt_section(available_skills)
|
||||
|
||||
# Format the prompt with dynamic skills and memory
|
||||
prompt = SYSTEM_PROMPT_TEMPLATE.format(
|
||||
agent_name=agent_name or "DeerFlow 2.0",
|
||||
soul=get_agent_soul(agent_name),
|
||||
skills_section=skills_section,
|
||||
memory_context=memory_context,
|
||||
subagent_section=subagent_section,
|
||||
|
||||
@@ -16,6 +16,7 @@ class ConversationContext:
|
||||
thread_id: str
|
||||
messages: list[Any]
|
||||
timestamp: datetime = field(default_factory=datetime.utcnow)
|
||||
agent_name: str | None = None
|
||||
|
||||
|
||||
class MemoryUpdateQueue:
|
||||
@@ -33,12 +34,13 @@ class MemoryUpdateQueue:
|
||||
self._timer: threading.Timer | None = None
|
||||
self._processing = False
|
||||
|
||||
def add(self, thread_id: str, messages: list[Any]) -> None:
|
||||
def add(self, thread_id: str, messages: list[Any], agent_name: str | None = None) -> None:
|
||||
"""Add a conversation to the update queue.
|
||||
|
||||
Args:
|
||||
thread_id: The thread ID.
|
||||
messages: The conversation messages.
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
"""
|
||||
config = get_memory_config()
|
||||
if not config.enabled:
|
||||
@@ -47,6 +49,7 @@ class MemoryUpdateQueue:
|
||||
context = ConversationContext(
|
||||
thread_id=thread_id,
|
||||
messages=messages,
|
||||
agent_name=agent_name,
|
||||
)
|
||||
|
||||
with self._lock:
|
||||
@@ -108,6 +111,7 @@ class MemoryUpdateQueue:
|
||||
success = updater.update_memory(
|
||||
messages=context.messages,
|
||||
thread_id=context.thread_id,
|
||||
agent_name=context.agent_name,
|
||||
)
|
||||
if success:
|
||||
print(f"Memory updated successfully for thread {context.thread_id}")
|
||||
|
||||
@@ -15,8 +15,19 @@ from src.config.paths import get_paths
|
||||
from src.models import create_chat_model
|
||||
|
||||
|
||||
def _get_memory_file_path() -> Path:
|
||||
"""Get the path to the memory file."""
|
||||
def _get_memory_file_path(agent_name: str | None = None) -> Path:
|
||||
"""Get the path to the memory file.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, returns the per-agent memory file path.
|
||||
If None, returns the global memory file path.
|
||||
|
||||
Returns:
|
||||
Path to the memory file.
|
||||
"""
|
||||
if agent_name is not None:
|
||||
return get_paths().agent_memory_file(agent_name)
|
||||
|
||||
config = get_memory_config()
|
||||
if config.storage_path:
|
||||
p = Path(config.storage_path)
|
||||
@@ -44,24 +55,24 @@ def _create_empty_memory() -> dict[str, Any]:
|
||||
}
|
||||
|
||||
|
||||
# Global memory data cache
|
||||
_memory_data: dict[str, Any] | None = None
|
||||
# Track file modification time for cache invalidation
|
||||
_memory_file_mtime: float | None = None
|
||||
# Per-agent memory cache: keyed by agent_name (None = global)
|
||||
# Value: (memory_data, file_mtime)
|
||||
_memory_cache: dict[str | None, tuple[dict[str, Any], float | None]] = {}
|
||||
|
||||
|
||||
def get_memory_data() -> dict[str, Any]:
|
||||
def get_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Get the current memory data (cached with file modification time check).
|
||||
|
||||
The cache is automatically invalidated if the memory file has been modified
|
||||
since the last load, ensuring fresh data is always returned.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory. If None, loads global memory.
|
||||
|
||||
Returns:
|
||||
The memory data dictionary.
|
||||
"""
|
||||
global _memory_data, _memory_file_mtime
|
||||
|
||||
file_path = _get_memory_file_path()
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
|
||||
# Get current file modification time
|
||||
try:
|
||||
@@ -69,41 +80,48 @@ def get_memory_data() -> dict[str, Any]:
|
||||
except OSError:
|
||||
current_mtime = None
|
||||
|
||||
cached = _memory_cache.get(agent_name)
|
||||
|
||||
# Invalidate cache if file has been modified or doesn't exist
|
||||
if _memory_data is None or _memory_file_mtime != current_mtime:
|
||||
_memory_data = _load_memory_from_file()
|
||||
_memory_file_mtime = current_mtime
|
||||
if cached is None or cached[1] != current_mtime:
|
||||
memory_data = _load_memory_from_file(agent_name)
|
||||
_memory_cache[agent_name] = (memory_data, current_mtime)
|
||||
return memory_data
|
||||
|
||||
return _memory_data
|
||||
return cached[0]
|
||||
|
||||
|
||||
def reload_memory_data() -> dict[str, Any]:
|
||||
def reload_memory_data(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Reload memory data from file, forcing cache invalidation.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, reloads per-agent memory. If None, reloads global memory.
|
||||
|
||||
Returns:
|
||||
The reloaded memory data dictionary.
|
||||
"""
|
||||
global _memory_data, _memory_file_mtime
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
memory_data = _load_memory_from_file(agent_name)
|
||||
|
||||
file_path = _get_memory_file_path()
|
||||
_memory_data = _load_memory_from_file()
|
||||
|
||||
# Update file modification time after reload
|
||||
try:
|
||||
_memory_file_mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
mtime = file_path.stat().st_mtime if file_path.exists() else None
|
||||
except OSError:
|
||||
_memory_file_mtime = None
|
||||
mtime = None
|
||||
|
||||
return _memory_data
|
||||
_memory_cache[agent_name] = (memory_data, mtime)
|
||||
return memory_data
|
||||
|
||||
|
||||
def _load_memory_from_file() -> dict[str, Any]:
|
||||
def _load_memory_from_file(agent_name: str | None = None) -> dict[str, Any]:
|
||||
"""Load memory data from file.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, loads per-agent memory file. If None, loads global.
|
||||
|
||||
Returns:
|
||||
The memory data dictionary.
|
||||
"""
|
||||
file_path = _get_memory_file_path()
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
|
||||
if not file_path.exists():
|
||||
return _create_empty_memory()
|
||||
@@ -117,17 +135,17 @@ def _load_memory_from_file() -> dict[str, Any]:
|
||||
return _create_empty_memory()
|
||||
|
||||
|
||||
def _save_memory_to_file(memory_data: dict[str, Any]) -> bool:
|
||||
def _save_memory_to_file(memory_data: dict[str, Any], agent_name: str | None = None) -> bool:
|
||||
"""Save memory data to file and update cache.
|
||||
|
||||
Args:
|
||||
memory_data: The memory data to save.
|
||||
agent_name: If provided, saves to per-agent memory file. If None, saves to global.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
global _memory_data, _memory_file_mtime
|
||||
file_path = _get_memory_file_path()
|
||||
file_path = _get_memory_file_path(agent_name)
|
||||
|
||||
try:
|
||||
# Ensure directory exists
|
||||
@@ -145,11 +163,12 @@ def _save_memory_to_file(memory_data: dict[str, Any]) -> bool:
|
||||
temp_path.replace(file_path)
|
||||
|
||||
# Update cache and file modification time
|
||||
_memory_data = memory_data
|
||||
try:
|
||||
_memory_file_mtime = file_path.stat().st_mtime
|
||||
mtime = file_path.stat().st_mtime
|
||||
except OSError:
|
||||
_memory_file_mtime = None
|
||||
mtime = None
|
||||
|
||||
_memory_cache[agent_name] = (memory_data, mtime)
|
||||
|
||||
print(f"Memory saved to {file_path}")
|
||||
return True
|
||||
@@ -175,12 +194,13 @@ class MemoryUpdater:
|
||||
model_name = self._model_name or config.model_name
|
||||
return create_chat_model(name=model_name, thinking_enabled=False)
|
||||
|
||||
def update_memory(self, messages: list[Any], thread_id: str | None = None) -> bool:
|
||||
def update_memory(self, messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
"""Update memory based on conversation messages.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID for tracking source.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
|
||||
Returns:
|
||||
True if update was successful, False otherwise.
|
||||
@@ -194,7 +214,7 @@ class MemoryUpdater:
|
||||
|
||||
try:
|
||||
# Get current memory
|
||||
current_memory = get_memory_data()
|
||||
current_memory = get_memory_data(agent_name)
|
||||
|
||||
# Format conversation for prompt
|
||||
conversation_text = format_conversation_for_update(messages)
|
||||
@@ -225,7 +245,7 @@ class MemoryUpdater:
|
||||
updated_memory = self._apply_updates(current_memory, update_data, thread_id)
|
||||
|
||||
# Save
|
||||
return _save_memory_to_file(updated_memory)
|
||||
return _save_memory_to_file(updated_memory, agent_name)
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Failed to parse LLM response for memory update: {e}")
|
||||
@@ -305,15 +325,16 @@ class MemoryUpdater:
|
||||
return current_memory
|
||||
|
||||
|
||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None) -> bool:
|
||||
def update_memory_from_conversation(messages: list[Any], thread_id: str | None = None, agent_name: str | None = None) -> bool:
|
||||
"""Convenience function to update memory from a conversation.
|
||||
|
||||
Args:
|
||||
messages: List of conversation messages.
|
||||
thread_id: Optional thread ID.
|
||||
agent_name: If provided, updates per-agent memory. If None, updates global memory.
|
||||
|
||||
Returns:
|
||||
True if successful, False otherwise.
|
||||
"""
|
||||
updater = MemoryUpdater()
|
||||
return updater.update_memory(messages, thread_id)
|
||||
return updater.update_memory(messages, thread_id, agent_name)
|
||||
|
||||
@@ -62,6 +62,15 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
|
||||
state_schema = MemoryMiddlewareState
|
||||
|
||||
def __init__(self, agent_name: str | None = None):
|
||||
"""Initialize the MemoryMiddleware.
|
||||
|
||||
Args:
|
||||
agent_name: If provided, memory is stored per-agent. If None, uses global memory.
|
||||
"""
|
||||
super().__init__()
|
||||
self._agent_name = agent_name
|
||||
|
||||
@override
|
||||
def after_agent(self, state: MemoryMiddlewareState, runtime: Runtime) -> dict | None:
|
||||
"""Queue conversation for memory update after agent completes.
|
||||
@@ -102,6 +111,6 @@ class MemoryMiddleware(AgentMiddleware[MemoryMiddlewareState]):
|
||||
|
||||
# Queue the filtered conversation for memory update
|
||||
queue = get_memory_queue()
|
||||
queue.add(thread_id=thread_id, messages=filtered_messages)
|
||||
queue.add(thread_id=thread_id, messages=filtered_messages, agent_name=self._agent_name)
|
||||
|
||||
return None
|
||||
|
||||
Reference in New Issue
Block a user