2025-05-29 19:52:34 +08:00
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
2025-05-28 14:13:46 +08:00
import logging
from typing import List , Optional , Type
from langchain_core . tools import BaseTool
from langchain_core . callbacks import (
AsyncCallbackManagerForToolRun ,
CallbackManagerForToolRun ,
)
from pydantic import BaseModel , Field
from src . config . tools import SELECTED_RAG_PROVIDER
from src . rag import Document , Retriever , Resource , build_retriever
# Module-level logger named after this module; used by the tool and its factory below.
logger = logging.getLogger(__name__)
# Argument schema for the retriever tool: a single free-text `keywords` field.
# Documented with comments (not a docstring) so the generated schema is unchanged.
class RetrieverInput(BaseModel):
    # Query text handed to the retriever.
    keywords: str = Field(
        description=" search keywords to look up ",
    )
class RetrieverTool(BaseTool):
    """Tool that queries the configured retriever with search keywords.

    Each call forwards the keywords together with ``self.resources`` to
    ``self.retriever.query_relevant_documents`` and returns the matches.
    """

    # NOTE(review): the string literals in this class carry leading/trailing
    # spaces. They are kept as-is here; tool-calling APIs commonly reject
    # whitespace in tool names, so confirm whether the padding is intended.
    name: str = " local_search_tool "
    description: str = " Useful for retrieving information from the file with `rag://` uri prefix, it should be higher priority than the web search or writing code. Input should be a search keywords. "
    args_schema: Type[BaseModel] = RetrieverInput

    # Backend that executes the query.
    retriever: Retriever = Field(default_factory=Retriever)
    # Resources passed to the retriever on every query.
    resources: list[Resource] = Field(default_factory=list)

    def _run(
        self,
        keywords: str,
        run_manager: Optional[CallbackManagerForToolRun] = None,
    ) -> str | list:
        """Query the retriever synchronously.

        Args:
            keywords: Search text forwarded to the retriever.
            run_manager: Optional callback manager; not used by this method.

        Returns:
            A fixed "no results" message string when the retriever yields
            nothing, otherwise a list with ``doc.to_dict()`` for every
            returned document.
        """
        logger.info(
            f" Retriever tool query: { keywords } ", extra={" resources ": self.resources}
        )
        documents = self.retriever.query_relevant_documents(keywords, self.resources)
        if not documents:
            return " No results found from the local knowledge base. "
        return [doc.to_dict() for doc in documents]

    async def _arun(
        self,
        keywords: str,
        run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
    ) -> str | list:
        """Async entry point; delegates to the synchronous ``_run``.

        Args:
            keywords: Search text forwarded to the retriever.
            run_manager: Optional async callback manager. When provided, its
                synchronous counterpart is handed to ``_run``.

        Returns:
            Whatever ``_run`` returns for the same keywords.
        """
        # run_manager defaults to None, so guard before calling get_sync();
        # calling it unconditionally raised AttributeError for direct callers.
        sync_manager = run_manager.get_sync() if run_manager is not None else None
        return self._run(keywords, sync_manager)
def get_retriever_tool(resources: List[Resource]) -> RetrieverTool | None:
    """Create a ``RetrieverTool`` bound to ``resources``.

    Returns ``None`` when ``resources`` is empty or when ``build_retriever()``
    yields a falsy value; otherwise returns a tool wrapping the built
    retriever and the given resources.
    """
    if resources:
        logger.info(f" create retriever tool: { SELECTED_RAG_PROVIDER } ")
        backend = build_retriever()
        if backend:
            return RetrieverTool(retriever=backend, resources=resources)
    # Either nothing to search in, or no retriever could be built.
    return None
# Manual smoke test: build the tool for one sample resource and run a query.
# The guard previously compared against a space-padded string and never fired.
if __name__ == "__main__":
    # NOTE(review): the sample strings below carry leading/trailing spaces and
    # are kept as-is; confirm the padded URI is accepted by the retriever.
    resources = [
        Resource(
            uri=" rag://dataset/1c7e2ea4362911f09a41c290d4b6a7f0 ",
            title=" 西游记 ",
            description=" 西游记是中国古代四大名著之一,讲述了唐僧师徒四人西天取经的故事。 ",
        )
    ]
    retriever_tool = get_retriever_tool(resources)
    # get_retriever_tool returns None when no retriever can be built; exit
    # with a clear message instead of failing on attribute access below.
    if retriever_tool is None:
        raise SystemExit("No retriever tool could be created; check the RAG provider configuration.")
    print(retriever_tool.name)
    print(retriever_tool.description)
    print(retriever_tool.args)
    print(retriever_tool.invoke(" 三打白骨精 "))