Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-12-12 12:06:04 +00:00
chore: Rename RagTool to FileSearchTool
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent 4566eebe05
commit 2d9163529a

288 changed files with 16985 additions and 2071 deletions
@@ -86,7 +86,7 @@ from .safety import SafetyException, ShieldRunnerMixin
 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
 MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
-RAG_TOOL_GROUP = "builtin::rag"
+RAG_TOOL_GROUP = "builtin::file_search"

 logger = get_logger(name=__name__, category="agents::meta_reference")

@@ -927,14 +927,14 @@ class ChatAgent(ShieldRunnerMixin):
         """Parse a toolgroup name into its components.

         Args:
-            toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
+            toolgroup_name: The toolgroup name to parse (e.g. "builtin::file_search/knowledge_search")

         Returns:
             A tuple of (tool_type, tool_group, tool_name)
         """
         split_names = toolgroup_name_with_maybe_tool_name.split("/")
         if len(split_names) == 2:
-            # e.g. "builtin::rag"
+            # e.g. "builtin::file_search"
             tool_group, tool_name = split_names
         else:
             tool_group, tool_name = split_names[0], None
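As a hedged illustration of the split behavior this hunk touches (a standalone sketch of the same logic, not the actual ChatAgent method, which also returns a tool_type), the renamed toolgroup string is parsed roughly like this:

# Hypothetical standalone sketch; only the group/name split is reproduced.
def parse_toolgroup_name(toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
    split_names = toolgroup_name_with_maybe_tool_name.split("/")
    if len(split_names) == 2:
        # e.g. "builtin::file_search/knowledge_search" -> group plus explicit tool name
        tool_group, tool_name = split_names
    else:
        # e.g. "builtin::file_search" -> whole group, no specific tool
        tool_group, tool_name = split_names[0], None
    return tool_group, tool_name

assert parse_toolgroup_name("builtin::file_search/knowledge_search") == ("builtin::file_search", "knowledge_search")
assert parse_toolgroup_name("builtin::file_search") == ("builtin::file_search", None)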
@@ -12,8 +12,8 @@ from .config import RagToolRuntimeConfig


 async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
-    from .memory import MemoryToolRuntimeImpl
+    from .file_search import FileSearchToolRuntimeImpl

-    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
+    impl = FileSearchToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
     await impl.initialize()
     return impl
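A minimal sketch of the provider-factory pattern used in this hunk, assuming made-up stand-ins for the runtime class and dependency keys (the real factory imports FileSearchToolRuntimeImpl lazily from the provider package and keys deps by the Api enum):

import asyncio

class FakeFileSearchRuntime:
    # Stand-in for FileSearchToolRuntimeImpl; only models the init/initialize handshake.
    def __init__(self, config, vector_io, inference, files):
        self.config, self.vector_io, self.inference, self.files = config, vector_io, inference, files
        self.ready = False

    async def initialize(self):
        self.ready = True

async def get_provider_impl(config, deps: dict):
    # Construct the runtime with its dependencies, initialize it, and hand it back.
    impl = FakeFileSearchRuntime(config, deps["vector_io"], deps["inference"], deps["files"])
    await impl.initialize()
    return impl

impl = asyncio.run(get_provider_impl({"separator": " "}, {"vector_io": None, "inference": None, "files": None}))
print(impl.ready)  # True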
@@ -9,19 +9,19 @@ from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
-from llama_stack.apis.tools.rag_tool import (
-    DefaultRAGQueryGeneratorConfig,
-    LLMRAGQueryGeneratorConfig,
-    RAGQueryGenerator,
-    RAGQueryGeneratorConfig,
+from llama_stack.apis.tools.file_search_tool import (
+    DefaultFileSearchGeneratorConfig,
+    FileSearchGenerator,
+    FileSearchGeneratorConfig,
+    LLMFileSearchGeneratorConfig,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )


-async def generate_rag_query(
-    config: RAGQueryGeneratorConfig,
+async def generate_file_search_query(
+    config: FileSearchGeneratorConfig,
     content: InterleavedContent,
     **kwargs,
 ):
@@ -29,25 +29,25 @@ async def generate_rag_query
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == RAGQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, content, **kwargs)
-    elif config.type == RAGQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, content, **kwargs)
+    if config.type == FileSearchGenerator.default.value:
+        query = await default_file_search_query_generator(config, content, **kwargs)
+    elif config.type == FileSearchGenerator.llm.value:
+        query = await llm_file_search_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query


-async def default_rag_query_generator(
-    config: DefaultRAGQueryGeneratorConfig,
+async def default_file_search_query_generator(
+    config: DefaultFileSearchGeneratorConfig,
     content: InterleavedContent,
     **kwargs,
 ):
     return interleaved_content_as_str(content, sep=config.separator)


-async def llm_rag_query_generator(
-    config: LLMRAGQueryGeneratorConfig,
+async def llm_file_search_query_generator(
+    config: LLMFileSearchGeneratorConfig,
     content: InterleavedContent,
     **kwargs,
 ):
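The hunk above is a dispatch on config.type: the default generator flattens the content into a string, while the LLM generator rewrites it via an inference call. A minimal runnable sketch of the same pattern, using made-up dataclasses in place of the real config types from llama_stack.apis.tools.file_search_tool:

import asyncio
from dataclasses import dataclass

@dataclass
class DefaultConfig:
    # Stand-in for DefaultFileSearchGeneratorConfig.
    type: str = "default"
    separator: str = " "

@dataclass
class LLMConfig:
    # Stand-in for LLMFileSearchGeneratorConfig.
    type: str = "llm"
    model: str = "some-model"

async def generate_query(config, content: str) -> str:
    # Mirror the dispatch in generate_file_search_query above.
    if config.type == "default":
        # Default path: pass the content through essentially unchanged.
        return content
    elif config.type == "llm":
        # LLM path: the real code calls the inference API; stubbed out here.
        return f"<query rewritten by {config.model}>"
    raise NotImplementedError(f"Unsupported query generator {config.type}")

print(asyncio.run(generate_query(DefaultConfig(), "What is the refund policy?")))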
@@ -23,11 +23,11 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.files import Files, OpenAIFilePurpose
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    FileSearchConfig,
+    FileSearchResult,
+    FileSearchToolRuntime,
     ListToolDefsResponse,
     RAGDocument,
-    RAGQueryConfig,
-    RAGQueryResult,
-    RAGToolRuntime,
     ToolDef,
     ToolGroup,
     ToolInvocationResult,
@@ -45,7 +45,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 from llama_stack.providers.utils.memory.vector_store import parse_data_url

 from .config import RagToolRuntimeConfig
-from .context_retriever import generate_rag_query
+from .context_retriever import generate_file_search_query

 log = get_logger(name=__name__, category="tool_runtime")

@@ -91,7 +91,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
     return content_str.encode("utf-8"), "text/plain"


-class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+class FileSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, FileSearchToolRuntime):
     def __init__(
         self,
         config: RagToolRuntimeConfig,
@@ -177,15 +177,15 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         self,
         content: InterleavedContent,
         vector_db_ids: list[str],
-        query_config: RAGQueryConfig | None = None,
-    ) -> RAGQueryResult:
+        query_config: FileSearchConfig | None = None,
+    ) -> FileSearchResult:
         if not vector_db_ids:
             raise ValueError(
                 "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
             )

-        query_config = query_config or RAGQueryConfig()
-        query = await generate_rag_query(
+        query_config = query_config or FileSearchConfig()
+        query = await generate_file_search_query(
             query_config.query_generator_config,
             content,
             inference_api=self.inference_api,
@@ -218,7 +218,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             scores.append(score)

         if not chunks:
-            return RAGQueryResult(content=None)
+            return FileSearchResult(content=None)

         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore
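The zip/sorted idiom in the last line of that hunk pairs each chunk with its score, sorts the pairs by score descending, and unzips them back into two parallel tuples. A standalone sketch of the same pattern, with plain strings standing in for retrieved chunks:

chunks = ["chunk_a", "chunk_b", "chunk_c"]
scores = [0.2, 0.9, 0.5]

# Pair each chunk with its score, sort by the score element descending,
# then split the sorted pairs back into two tuples.
chunks, scores = zip(
    *sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True),
    strict=False,
)
print(chunks)  # ('chunk_b', 'chunk_c', 'chunk_a')
print(scores)  # (0.9, 0.5, 0.2)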
@@ -269,7 +269,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             )
         )

-        return RAGQueryResult(
+        return FileSearchResult(
             content=picked,
             metadata={
                 "document_ids": [c.document_id for c in chunks[: len(picked)]],
@@ -312,9 +312,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         vector_db_ids = kwargs.get("vector_db_ids", [])
         query_config = kwargs.get("query_config")
         if query_config:
-            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
+            query_config = TypeAdapter(FileSearchConfig).validate_python(query_config)
         else:
-            query_config = RAGQueryConfig()
+            query_config = FileSearchConfig()

         query = kwargs["query"]
         result = await self.query(
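For context on the TypeAdapter call being retargeted above: pydantic's TypeAdapter.validate_python coerces an untyped dict (here, the tool-invocation kwargs) into a typed config object, and the else branch falls back to the model's defaults. A hedged sketch using a made-up model in place of FileSearchConfig, whose real fields live in llama_stack.apis.tools:

from pydantic import BaseModel, TypeAdapter

class FakeFileSearchConfig(BaseModel):
    # Stand-in model; the real FileSearchConfig carries more fields,
    # including the query_generator_config used earlier in the diff.
    max_chunks: int = 5
    mode: str = "vector"

# What a tool invocation might pass as query_config.
raw_kwargs = {"max_chunks": 10}

config = TypeAdapter(FakeFileSearchConfig).validate_python(raw_kwargs)
print(config)  # max_chunks=10 mode='vector'

# When no query_config is supplied, the code constructs the default config instead.
default_config = FakeFileSearchConfig()
print(default_config.max_chunks)  # 5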