Francisco Arceo 2025-10-27 10:47:57 -07:00 committed by GitHub
commit 1665c6c4be
470 changed files with 17725 additions and 2810 deletions


@@ -84,9 +84,9 @@ from .persistence import AgentPersistence
from .safety import SafetyException, ShieldRunnerMixin
TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
- MEMORY_QUERY_TOOL = "knowledge_search"
+ MEMORY_QUERY_TOOL = "file_search"
WEB_SEARCH_TOOL = "web_search"
- RAG_TOOL_GROUP = "builtin::rag"
+ RAG_TOOL_GROUP = "builtin::file_search"
logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -927,14 +927,14 @@ class ChatAgent(ShieldRunnerMixin):
"""Parse a toolgroup name into its components.
Args:
- toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
+ toolgroup_name: The toolgroup name to parse (e.g. "builtin::file_search/file_search")
Returns:
A tuple of (tool_type, tool_group, tool_name)
"""
split_names = toolgroup_name_with_maybe_tool_name.split("/")
if len(split_names) == 2:
- # e.g. "builtin::rag"
+ # e.g. "builtin::file_search"
tool_group, tool_name = split_names
else:
tool_group, tool_name = split_names[0], None
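
For illustration, a minimal standalone sketch of how the renamed toolgroup identifiers split, independent of the repository code; the helper name and the two-element return shape are assumptions made for this example:

```python
# Illustrative sketch only; mirrors the split("/") logic shown in the hunk above.
def parse_toolgroup_name(toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
    split_names = toolgroup_name_with_maybe_tool_name.split("/")
    if len(split_names) == 2:
        # e.g. "builtin::file_search/file_search" -> ("builtin::file_search", "file_search")
        tool_group, tool_name = split_names
    else:
        # e.g. "builtin::file_search" -> ("builtin::file_search", None)
        tool_group, tool_name = split_names[0], None
    return tool_group, tool_name


assert parse_toolgroup_name("builtin::file_search/file_search") == ("builtin::file_search", "file_search")
assert parse_toolgroup_name("builtin::file_search") == ("builtin::file_search", None)
```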


@@ -677,7 +677,7 @@ class StreamingResponseOrchestrator:
# Emit output_item.added event for the new function call
self.sequence_number += 1
is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
- if not is_mcp_tool and tool_call.function.name not in ["web_search", "knowledge_search"]:
+ if not is_mcp_tool and tool_call.function.name not in ["web_search", "file_search"]:
# for MCP tools (and even other non-function tools) we emit an output message item later
function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
arguments="", # Will be filled incrementally via delta events
@@ -902,7 +902,7 @@ class StreamingResponseOrchestrator:
id=matching_item_id,
status="in_progress",
)
- elif tool_call.function.name == "knowledge_search":
+ elif tool_call.function.name == "file_search":
item = OpenAIResponseOutputMessageFileSearchToolCall(
id=matching_item_id,
status="in_progress",
@@ -1021,7 +1021,7 @@ class StreamingResponseOrchestrator:
raise ValueError(f"Tool {tool_name} not found")
self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
elif input_tool.type == "file_search":
- tool_name = "knowledge_search"
+ tool_name = "file_search"
tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
if not tool:
raise ValueError(f"Tool {tool_name} not found")


@@ -104,12 +104,12 @@ class ToolExecutor:
citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
)
- async def _execute_knowledge_search_via_vector_store(
+ async def _execute_file_search_via_vector_store(
self,
query: str,
response_file_search_tool: OpenAIResponseInputToolFileSearch,
) -> ToolInvocationResult:
- """Execute knowledge search using vector_stores.search API with filters support."""
+ """Execute file search using vector_stores.search API with filters support."""
search_results = []
# Create search tasks for all vector stores
@@ -139,7 +139,7 @@ class ToolExecutor:
content_items = []
content_items.append(
TextContentItem(
- text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
+ text=f"file_search tool found {len(search_results)} chunks:\nBEGIN of file_search tool results.\n"
)
)
@@ -158,7 +158,7 @@ class ToolExecutor:
content_items.append(TextContentItem(text=text_content))
unique_files.add(file_id)
- content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+ content_items.append(TextContentItem(text="END of file_search tool results.\n"))
citation_instruction = ""
if unique_files:
@@ -224,7 +224,7 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
- elif function_name == "knowledge_search":
+ elif function_name == "file_search":
sequence_number += 1
progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
item_id=item_id,
@@ -246,7 +246,7 @@ class ToolExecutor:
yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)
# For file search, emit searching event
- if function_name == "knowledge_search":
+ if function_name == "file_search":
sequence_number += 1
searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
item_id=item_id,
@@ -283,17 +283,17 @@ class ToolExecutor:
tool_name=function_name,
kwargs=tool_kwargs,
)
- elif function_name == "knowledge_search":
+ elif function_name == "file_search":
response_file_search_tool = next(
(t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
None,
)
if response_file_search_tool:
- # Use vector_stores.search API instead of knowledge_search tool
+ # Use vector_stores.search API instead of file_search tool
# to support filters and ranking_options
query = tool_kwargs.get("query", "")
- async with tracing.span("knowledge_search", {}):
- result = await self._execute_knowledge_search_via_vector_store(
+ async with tracing.span("file_search", {}):
+ result = await self._execute_file_search_via_vector_store(
query=query,
response_file_search_tool=response_file_search_tool,
)
@@ -341,7 +341,7 @@ class ToolExecutor:
output_index=output_index,
sequence_number=sequence_number,
)
- elif function_name == "knowledge_search":
+ elif function_name == "file_search":
sequence_number += 1
completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
item_id=item_id,
@@ -395,7 +395,7 @@ class ToolExecutor:
)
if has_error:
message.status = "failed"
- elif function.name == "knowledge_search":
+ elif function.name == "file_search":
message = OpenAIResponseOutputMessageFileSearchToolCall(
id=item_id,
queries=[tool_kwargs.get("query", "")],
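
To make the renamed flow concrete, a simplified, self-contained sketch of what `_execute_file_search_via_vector_store` does per the hunks above: fan the query out to each vector store, then wrap the collected chunks in the BEGIN/END `file_search` markers. The stubbed search function and plain-string results are assumptions made to keep the sketch runnable:

```python
# Illustrative sketch only; the real implementation calls the vector_stores.search API
# and returns TextContentItem objects rather than plain strings.
import asyncio


async def fake_vector_store_search(store_id: str, query: str) -> list[str]:
    # Stand-in for the vector store search call; results are fabricated for the sketch.
    return [f"[{store_id}] chunk matching {query!r}"]


async def file_search(query: str, vector_store_ids: list[str]) -> str:
    # One search task per vector store, run concurrently (as the real code does).
    per_store = await asyncio.gather(*(fake_vector_store_search(s, query) for s in vector_store_ids))
    chunks = [chunk for results in per_store for chunk in results]
    lines = [f"file_search tool found {len(chunks)} chunks:\nBEGIN of file_search tool results.\n"]
    lines += [f"Result {i + 1}: {chunk}\n" for i, chunk in enumerate(chunks)]
    lines.append("END of file_search tool results.\n")
    return "".join(lines)


print(asyncio.run(file_search("vacation policy", ["vs_hr", "vs_wiki"])))
```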


@@ -12,8 +12,8 @@ from .config import RagToolRuntimeConfig
async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
- from .memory import MemoryToolRuntimeImpl
+ from .file_search import FileSearchToolRuntimeImpl
- impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
+ impl = FileSearchToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
await impl.initialize()
return impl


@@ -9,19 +9,19 @@ from jinja2 import Template
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
- from llama_stack.apis.tools.rag_tool import (
- DefaultRAGQueryGeneratorConfig,
- LLMRAGQueryGeneratorConfig,
- RAGQueryGenerator,
- RAGQueryGeneratorConfig,
+ from llama_stack.apis.tools.file_search_tool import (
+ DefaultFileSearchGeneratorConfig,
+ FileSearchGenerator,
+ FileSearchGeneratorConfig,
+ LLMFileSearchGeneratorConfig,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
- async def generate_rag_query(
- config: RAGQueryGeneratorConfig,
+ async def generate_file_search_query(
+ config: FileSearchGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
@@ -29,25 +29,25 @@ async def generate_rag_query(
Generates a query that will be used for
retrieving relevant information from the memory bank.
"""
- if config.type == RAGQueryGenerator.default.value:
- query = await default_rag_query_generator(config, content, **kwargs)
- elif config.type == RAGQueryGenerator.llm.value:
- query = await llm_rag_query_generator(config, content, **kwargs)
+ if config.type == FileSearchGenerator.default.value:
+ query = await default_file_search_query_generator(config, content, **kwargs)
+ elif config.type == FileSearchGenerator.llm.value:
+ query = await llm_file_search_query_generator(config, content, **kwargs)
else:
raise NotImplementedError(f"Unsupported memory query generator {config.type}")
return query
- async def default_rag_query_generator(
- config: DefaultRAGQueryGeneratorConfig,
+ async def default_file_search_query_generator(
+ config: DefaultFileSearchGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
return interleaved_content_as_str(content, sep=config.separator)
- async def llm_rag_query_generator(
- config: LLMRAGQueryGeneratorConfig,
+ async def llm_file_search_query_generator(
+ config: LLMFileSearchGeneratorConfig,
content: InterleavedContent,
**kwargs,
):
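
As a rough guide to the renamed generators, a self-contained sketch of the dispatch: a "default" config joins the content with a separator, while an "llm" config rewrites the query with a model (stubbed here). The dataclasses below are simplified stand-ins for DefaultFileSearchGeneratorConfig and LLMFileSearchGeneratorConfig, not the real config classes:

```python
# Simplified stand-ins; not the real config classes or inference call.
import asyncio
from dataclasses import dataclass


@dataclass
class DefaultConfig:
    type: str = "default"
    separator: str = " "


@dataclass
class LLMConfig:
    type: str = "llm"
    model: str = "placeholder-model"


async def generate_file_search_query(config, content: list[str]) -> str:
    if config.type == "default":
        # Mirrors interleaved_content_as_str(content, sep=config.separator)
        return config.separator.join(content)
    if config.type == "llm":
        # The real code renders a Jinja2 template and calls the inference API; stubbed here.
        return f"(rewritten by {config.model}) " + " ".join(content)
    raise NotImplementedError(f"Unsupported memory query generator {config.type}")


print(asyncio.run(generate_file_search_query(DefaultConfig(), ["refund", "policy"])))
```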


@@ -23,11 +23,11 @@ from llama_stack.apis.common.content_types import (
from llama_stack.apis.files import Files, OpenAIFilePurpose
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
+ FileSearchConfig,
+ FileSearchResult,
+ FileSearchToolRuntime,
ListToolDefsResponse,
RAGDocument,
- RAGQueryConfig,
- RAGQueryResult,
- RAGToolRuntime,
ToolDef,
ToolGroup,
ToolInvocationResult,
@@ -45,7 +45,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import RagToolRuntimeConfig
- from .context_retriever import generate_rag_query
+ from .context_retriever import generate_file_search_query
log = get_logger(name=__name__, category="tool_runtime")
@@ -91,7 +91,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
return content_str.encode("utf-8"), "text/plain"
- class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+ class FileSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, FileSearchToolRuntime):
def __init__(
self,
config: RagToolRuntimeConfig,
@@ -177,15 +177,15 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
self,
content: InterleavedContent,
vector_db_ids: list[str],
- query_config: RAGQueryConfig | None = None,
- ) -> RAGQueryResult:
+ query_config: FileSearchConfig | None = None,
+ ) -> FileSearchResult:
if not vector_db_ids:
raise ValueError(
"No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
)
- query_config = query_config or RAGQueryConfig()
- query = await generate_rag_query(
+ query_config = query_config or FileSearchConfig()
+ query = await generate_file_search_query(
query_config.query_generator_config,
content,
inference_api=self.inference_api,
@@ -218,7 +218,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
scores.append(score)
if not chunks:
- return RAGQueryResult(content=None)
+ return FileSearchResult(content=None)
# sort by score
chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) # type: ignore
@@ -226,9 +226,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
tokens = 0
picked: list[InterleavedContentItem] = [
- TextContentItem(
- text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
- )
+ TextContentItem(text=f"file_search tool found {len(chunks)} chunks:\nBEGIN of file_search tool results.\n")
]
for i, chunk in enumerate(chunks):
metadata = chunk.metadata
@@ -262,14 +260,14 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
picked.append(TextContentItem(text=text_content))
- picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+ picked.append(TextContentItem(text="END of file_search tool results.\n"))
picked.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
)
)
- return RAGQueryResult(
+ return FileSearchResult(
content=picked,
metadata={
"document_ids": [c.document_id for c in chunks[: len(picked)]],
@@ -292,7 +290,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
description="Insert documents into memory",
),
ToolDef(
- name="knowledge_search",
+ name="file_search",
description="Search for information in a database.",
input_schema={
"type": "object",
@@ -312,9 +310,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
vector_db_ids = kwargs.get("vector_db_ids", [])
query_config = kwargs.get("query_config")
if query_config:
- query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
+ query_config = TypeAdapter(FileSearchConfig).validate_python(query_config)
else:
- query_config = RAGQueryConfig()
+ query_config = FileSearchConfig()
query = kwargs["query"]
result = await self.query(
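
A hedged usage fragment showing how the renamed tool would be invoked with the kwargs unpacked above; "impl" stands for an already-initialized FileSearchToolRuntimeImpl and is not constructed here, and the query and vector DB values are placeholders:

```python
# Hypothetical call site; assumes a configured FileSearchToolRuntimeImpl ("impl").
async def run_file_search(impl):
    return await impl.invoke_tool(
        tool_name="file_search",  # was "knowledge_search" before this change
        kwargs={
            "query": "expense report process",  # placeholder query
            "vector_db_ids": ["docs-vector-db"],  # placeholder vector DB id
            "query_config": {"mode": "vector", "max_chunks": 5},  # validated into FileSearchConfig above
        },
    )
```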


@@ -18,7 +18,7 @@ def available_providers() -> list[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.tool_runtime,
- provider_type="inline::rag-runtime",
+ provider_type="inline::file_search-runtime",
pip_packages=DEFAULT_VECTOR_IO_DEPS
+ [
"tqdm",
@@ -29,8 +29,8 @@ def available_providers() -> list[ProviderSpec]:
"sentencepiece",
"transformers",
],
- module="llama_stack.providers.inline.tool_runtime.rag",
- config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
+ module="llama_stack.providers.inline.tool_runtime.file_search",
+ config_class="llama_stack.providers.inline.tool_runtime.file_search.config.RagToolRuntimeConfig",
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),


@@ -241,33 +241,33 @@ Two ranker types are supported:
- alpha=1: Only use vector scores
- alpha=0.5: Equal weight to both (default)
- Example using RAGQueryConfig with different search modes:
+ Example using FileSearchConfig with different search modes:
```python
- from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker
+ from llama_stack.apis.tools import FileSearchConfig, RRFRanker, WeightedRanker
# Vector search
- config = RAGQueryConfig(mode="vector", max_chunks=5)
+ config = FileSearchConfig(mode="vector", max_chunks=5)
# Keyword search
- config = RAGQueryConfig(mode="keyword", max_chunks=5)
+ config = FileSearchConfig(mode="keyword", max_chunks=5)
# Hybrid search with custom RRF ranker
- config = RAGQueryConfig(
+ config = FileSearchConfig(
mode="hybrid",
max_chunks=5,
ranker=RRFRanker(impact_factor=50.0), # Custom impact factor
)
# Hybrid search with weighted ranker
- config = RAGQueryConfig(
+ config = FileSearchConfig(
mode="hybrid",
max_chunks=5,
ranker=WeightedRanker(alpha=0.7), # 70% vector, 30% keyword
)
# Hybrid search with default RRF ranker
- config = RAGQueryConfig(
+ config = FileSearchConfig(
mode="hybrid", max_chunks=5
) # Will use RRF with impact_factor=60.0
```
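
To ground the ranker options above, a small worked example (not library code) of how the two ranker types combine scores: WeightedRanker blends vector and keyword scores by alpha, and RRFRanker fuses ranks with an impact factor (default 60.0). The internal score normalization is not reproduced here:

```python
# Illustrative scoring only; assumes scores are already normalized to [0, 1].
def weighted_score(vector_score: float, keyword_score: float, alpha: float = 0.5) -> float:
    # alpha=1.0 -> vector only, alpha=0.0 -> keyword only, alpha=0.5 -> equal weight
    return alpha * vector_score + (1 - alpha) * keyword_score


def rrf_score(vector_rank: int, keyword_rank: int, impact_factor: float = 60.0) -> float:
    # Reciprocal Rank Fusion: each list contributes 1 / (impact_factor + rank); better (lower) ranks score higher.
    return 1.0 / (impact_factor + vector_rank) + 1.0 / (impact_factor + keyword_rank)


print(weighted_score(0.9, 0.4, alpha=0.7))  # 70% vector, 30% keyword -> 0.75
print(rrf_score(vector_rank=1, keyword_rank=3, impact_factor=60.0))
```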