2025-05-29 19:52:34 +08:00
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
2025-05-28 14:13:46 +08:00
import logging
from typing import List , Optional , Type
from langchain_core . tools import BaseTool
from langchain_core . callbacks import (
AsyncCallbackManagerForToolRun ,
CallbackManagerForToolRun ,
)
from pydantic import BaseModel , Field
from src . config . tools import SELECTED_RAG_PROVIDER
from src . rag import Document , Retriever , Resource , build_retriever
logger = logging . getLogger ( __name__ )
class RetrieverInput(BaseModel):
    # Argument schema of the retriever tool: a single keyword string.
    # A `#` comment is used instead of a docstring so that the schema pydantic
    # generates for this model stays exactly as it was.
    keywords: str = Field(
        description=" search keywords to look up ",
    )
class RetrieverTool(BaseTool):
    # Tool that forwards search keywords, together with the resources it was
    # built with, to `retriever.query_relevant_documents` and returns the hits.
    #
    # NOTE(review): the padding spaces inside the string literals below look
    # like a formatting artifact; they are kept unchanged because they are
    # runtime values -- confirm against the upstream file.
    name: str = " local_search_tool "
    description: str = (
        " Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords. "
    )
    args_schema: Type[BaseModel] = RetrieverInput

    # Backend that answers the queries.
    retriever: Retriever = Field(default_factory=Retriever)
    # Resources passed to the retriever on every query.
    resources: list[Resource] = Field(default_factory=list)

    def _run(
        self,
        keywords: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> list | str:
        """Query the retriever for documents matching ``keywords``.

        Args:
            keywords: Search keywords handed to the retriever.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            A list with ``doc.to_dict()`` for every document found, or a
            fixed "no results" message string when the retriever returns
            nothing.
        """
        logger.info(
            f" Retriever tool query: { keywords } ", extra={" resources ": self.resources}
        )
        documents = self.retriever.query_relevant_documents(keywords, self.resources)
        if not documents:
            return " No results found from the local knowledge base. "
        return [doc.to_dict() for doc in documents]

    async def _arun(
        self,
        keywords: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> list | str:
        """Async entry point; delegates to the synchronous ``_run``.

        Args:
            keywords: Search keywords handed to the retriever.
            run_manager: Optional async callback manager. When given, its
                synchronous counterpart is passed on to ``_run``.

        Returns:
            Whatever ``_run`` returns for the same keywords.
        """
        # `run_manager` defaults to None, so it must not be dereferenced
        # unconditionally (the previous code raised AttributeError here).
        sync_manager = run_manager.get_sync() if run_manager is not None else None
        # The lookup itself still runs synchronously inside this coroutine.
        return self._run(keywords, sync_manager)
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
    """Build a ``RetrieverTool`` bound to ``resources``.

    Args:
        resources: Resources the tool should search in.

    Returns:
        A ``RetrieverTool`` wrapping the retriever from ``build_retriever()``,
        or ``None`` when ``resources`` is empty or ``build_retriever()``
        returns a falsy value.
    """
    if resources:
        # Logged before the retriever is built, so the message also appears
        # when no retriever turns out to be available.
        logger.info(f" create retriever tool: { SELECTED_RAG_PROVIDER } ")
        backend = build_retriever()
        if backend:
            return RetrieverTool(retriever=backend, resources=resources)
    return None