Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-12-12 12:06:04 +00:00)

Merge 679e2d09c1 into 63422e5b36
This commit is contained in: commit 1665c6c4be
470 changed files with 17725 additions and 2810 deletions
@@ -84,9 +84,9 @@ from .persistence import AgentPersistence
 from .safety import SafetyException, ShieldRunnerMixin

 TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})")
-MEMORY_QUERY_TOOL = "knowledge_search"
+MEMORY_QUERY_TOOL = "file_search"
 WEB_SEARCH_TOOL = "web_search"
-RAG_TOOL_GROUP = "builtin::rag"
+RAG_TOOL_GROUP = "builtin::file_search"

 logger = get_logger(name=__name__, category="agents::meta_reference")

@@ -927,14 +927,14 @@ class ChatAgent(ShieldRunnerMixin):
         """Parse a toolgroup name into its components.

         Args:
-            toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search")
+            toolgroup_name: The toolgroup name to parse (e.g. "builtin::file_search/file_search")

         Returns:
             A tuple of (tool_type, tool_group, tool_name)
         """
         split_names = toolgroup_name_with_maybe_tool_name.split("/")
         if len(split_names) == 2:
-            # e.g. "builtin::rag"
+            # e.g. "builtin::file_search"
             tool_group, tool_name = split_names
         else:
             tool_group, tool_name = split_names[0], None

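A quick illustrative sketch of how the renamed toolgroup strings split; the standalone helper name and return shape are assumptions, the hunk above only shows the parsing body.

```python
# Illustrative only: the helper name and return shape are assumptions, not the diff's code.
def parse_toolgroup_name(toolgroup_name_with_maybe_tool_name: str) -> tuple[str, str | None]:
    split_names = toolgroup_name_with_maybe_tool_name.split("/")
    if len(split_names) == 2:
        # e.g. "builtin::file_search/file_search"
        tool_group, tool_name = split_names
    else:
        # e.g. "builtin::file_search"
        tool_group, tool_name = split_names[0], None
    return tool_group, tool_name


assert parse_toolgroup_name("builtin::file_search/file_search") == ("builtin::file_search", "file_search")
assert parse_toolgroup_name("builtin::file_search") == ("builtin::file_search", None)
```
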
@@ -677,7 +677,7 @@ class StreamingResponseOrchestrator:
             # Emit output_item.added event for the new function call
             self.sequence_number += 1
             is_mcp_tool = tool_call.function.name and tool_call.function.name in self.mcp_tool_to_server
-            if not is_mcp_tool and tool_call.function.name not in ["web_search", "knowledge_search"]:
+            if not is_mcp_tool and tool_call.function.name not in ["web_search", "file_search"]:
                 # for MCP tools (and even other non-function tools) we emit an output message item later
                 function_call_item = OpenAIResponseOutputMessageFunctionToolCall(
                     arguments="",  # Will be filled incrementally via delta events

@@ -902,7 +902,7 @@ class StreamingResponseOrchestrator:
                     id=matching_item_id,
                     status="in_progress",
                 )
-            elif tool_call.function.name == "knowledge_search":
+            elif tool_call.function.name == "file_search":
                 item = OpenAIResponseOutputMessageFileSearchToolCall(
                     id=matching_item_id,
                     status="in_progress",

@@ -1021,7 +1021,7 @@ class StreamingResponseOrchestrator:
                     raise ValueError(f"Tool {tool_name} not found")
                 self.ctx.chat_tools.append(make_openai_tool(tool_name, tool))
             elif input_tool.type == "file_search":
-                tool_name = "knowledge_search"
+                tool_name = "file_search"
                 tool = await self.tool_executor.tool_groups_api.get_tool(tool_name)
                 if not tool:
                     raise ValueError(f"Tool {tool_name} not found")

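With this hunk, an OpenAI-Responses-style `file_search` tool entry now resolves to the internal tool name `file_search` instead of `knowledge_search`. A hedged sketch of such a request entry; only the `type` value is taken from the diff, the other field follows the OpenAI-style shape and is an assumption.

```python
# Hypothetical request payload; "vector_store_ids" is assumed, not shown in the diff.
tools = [
    {
        "type": "file_search",           # resolved to tool_name = "file_search" above
        "vector_store_ids": ["vs_123"],  # assumed field name / placeholder id
    }
]
```
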
@@ -104,12 +104,12 @@ class ToolExecutor:
             citation_files=result.metadata.get("citation_files") if result and result.metadata else None,
         )

-    async def _execute_knowledge_search_via_vector_store(
+    async def _execute_file_search_via_vector_store(
         self,
         query: str,
         response_file_search_tool: OpenAIResponseInputToolFileSearch,
     ) -> ToolInvocationResult:
-        """Execute knowledge search using vector_stores.search API with filters support."""
+        """Execute file search using vector_stores.search API with filters support."""
         search_results = []

         # Create search tasks for all vector stores

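The hunk is truncated at the comment about creating search tasks; a minimal fan-out sketch of that pattern under stated assumptions (the `search_one` callable stands in for the actual `vector_stores.search` call, which is not shown here).

```python
import asyncio
from typing import Any, Awaitable, Callable


# Sketch only: concurrently query every configured vector store and flatten the hits.
# `search_one` is a stand-in for the real vector_stores.search call.
async def search_all_stores(
    vector_store_ids: list[str],
    query: str,
    search_one: Callable[[str, str], Awaitable[list[Any]]],
) -> list[Any]:
    tasks = [search_one(store_id, query) for store_id in vector_store_ids]
    per_store = await asyncio.gather(*tasks)
    return [hit for results in per_store for hit in results]
```
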
@@ -139,7 +139,7 @@ class ToolExecutor:
         content_items = []
         content_items.append(
             TextContentItem(
-                text=f"knowledge_search tool found {len(search_results)} chunks:\nBEGIN of knowledge_search tool results.\n"
+                text=f"file_search tool found {len(search_results)} chunks:\nBEGIN of file_search tool results.\n"
             )
         )

@@ -158,7 +158,7 @@ class ToolExecutor:
             content_items.append(TextContentItem(text=text_content))
             unique_files.add(file_id)

-        content_items.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+        content_items.append(TextContentItem(text="END of file_search tool results.\n"))

         citation_instruction = ""
         if unique_files:

@@ -224,7 +224,7 @@ class ToolExecutor:
                 output_index=output_index,
                 sequence_number=sequence_number,
             )
-        elif function_name == "knowledge_search":
+        elif function_name == "file_search":
             sequence_number += 1
             progress_event = OpenAIResponseObjectStreamResponseFileSearchCallInProgress(
                 item_id=item_id,

@@ -246,7 +246,7 @@ class ToolExecutor:
             yield ToolExecutionResult(stream_event=searching_event, sequence_number=sequence_number)

         # For file search, emit searching event
-        if function_name == "knowledge_search":
+        if function_name == "file_search":
             sequence_number += 1
             searching_event = OpenAIResponseObjectStreamResponseFileSearchCallSearching(
                 item_id=item_id,

@@ -283,17 +283,17 @@ class ToolExecutor:
                 tool_name=function_name,
                 kwargs=tool_kwargs,
             )
-        elif function_name == "knowledge_search":
+        elif function_name == "file_search":
             response_file_search_tool = next(
                 (t for t in ctx.response_tools if isinstance(t, OpenAIResponseInputToolFileSearch)),
                 None,
             )
             if response_file_search_tool:
-                # Use vector_stores.search API instead of knowledge_search tool
+                # Use vector_stores.search API instead of file_search tool
                 # to support filters and ranking_options
                 query = tool_kwargs.get("query", "")
-                async with tracing.span("knowledge_search", {}):
-                    result = await self._execute_knowledge_search_via_vector_store(
+                async with tracing.span("file_search", {}):
+                    result = await self._execute_file_search_via_vector_store(
                         query=query,
                         response_file_search_tool=response_file_search_tool,
                     )

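Condensed, the dispatch in this hunk prefers the vector-store search path whenever a `file_search` tool definition is present on the request, so its `filters` and `ranking_options` are honored. A sketch under stated assumptions; the two callables stand in for the methods above, and the fallback branch is an assumption about the surrounding (unchanged) code.

```python
from typing import Any, Awaitable, Callable


# Condensed sketch of the dispatch above; the fallback branch is assumed, not shown in the diff.
async def dispatch_file_search(
    tool_kwargs: dict[str, Any],
    response_file_search_tool: Any | None,
    search_via_vector_store: Callable[..., Awaitable[Any]],
    invoke_registered_tool: Callable[..., Awaitable[Any]],
) -> Any:
    if response_file_search_tool is not None:
        # filters / ranking_options are only honored on the vector_stores.search path
        return await search_via_vector_store(
            query=tool_kwargs.get("query", ""),
            response_file_search_tool=response_file_search_tool,
        )
    # otherwise fall back to invoking the registered file_search tool (assumed)
    return await invoke_registered_tool(tool_name="file_search", kwargs=tool_kwargs)
```
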
@@ -341,7 +341,7 @@ class ToolExecutor:
                 output_index=output_index,
                 sequence_number=sequence_number,
             )
-        elif function_name == "knowledge_search":
+        elif function_name == "file_search":
             sequence_number += 1
             completion_event = OpenAIResponseObjectStreamResponseFileSearchCallCompleted(
                 item_id=item_id,

@@ -395,7 +395,7 @@ class ToolExecutor:
             )
             if has_error:
                 message.status = "failed"
-        elif function.name == "knowledge_search":
+        elif function.name == "file_search":
             message = OpenAIResponseOutputMessageFileSearchToolCall(
                 id=item_id,
                 queries=[tool_kwargs.get("query", "")],

@@ -12,8 +12,8 @@ from .config import RagToolRuntimeConfig


 async def get_provider_impl(config: RagToolRuntimeConfig, deps: dict[Api, Any]):
-    from .memory import MemoryToolRuntimeImpl
+    from .file_search import FileSearchToolRuntimeImpl

-    impl = MemoryToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
+    impl = FileSearchToolRuntimeImpl(config, deps[Api.vector_io], deps[Api.inference], deps[Api.files])
     await impl.initialize()
     return impl

@@ -9,19 +9,19 @@ from jinja2 import Template

 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.inference import OpenAIChatCompletionRequestWithExtraBody, OpenAIUserMessageParam
-from llama_stack.apis.tools.rag_tool import (
-    DefaultRAGQueryGeneratorConfig,
-    LLMRAGQueryGeneratorConfig,
-    RAGQueryGenerator,
-    RAGQueryGeneratorConfig,
+from llama_stack.apis.tools.file_search_tool import (
+    DefaultFileSearchGeneratorConfig,
+    FileSearchGenerator,
+    FileSearchGeneratorConfig,
+    LLMFileSearchGeneratorConfig,
 )
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )


-async def generate_rag_query(
-    config: RAGQueryGeneratorConfig,
+async def generate_file_search_query(
+    config: FileSearchGeneratorConfig,
     content: InterleavedContent,
     **kwargs,
 ):

@@ -29,25 +29,25 @@ async def generate_rag_query(
     Generates a query that will be used for
     retrieving relevant information from the memory bank.
     """
-    if config.type == RAGQueryGenerator.default.value:
-        query = await default_rag_query_generator(config, content, **kwargs)
-    elif config.type == RAGQueryGenerator.llm.value:
-        query = await llm_rag_query_generator(config, content, **kwargs)
+    if config.type == FileSearchGenerator.default.value:
+        query = await default_file_search_query_generator(config, content, **kwargs)
+    elif config.type == FileSearchGenerator.llm.value:
+        query = await llm_file_search_query_generator(config, content, **kwargs)
     else:
         raise NotImplementedError(f"Unsupported memory query generator {config.type}")
     return query


-async def default_rag_query_generator(
-    config: DefaultRAGQueryGeneratorConfig,
+async def default_file_search_query_generator(
+    config: DefaultFileSearchGeneratorConfig,
     content: InterleavedContent,
     **kwargs,
 ):
     return interleaved_content_as_str(content, sep=config.separator)


-async def llm_rag_query_generator(
-    config: LLMRAGQueryGeneratorConfig,
+async def llm_file_search_query_generator(
+    config: LLMFileSearchGeneratorConfig,
     content: InterleavedContent,
     **kwargs,
 ):

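A hedged usage sketch of the renamed query-generation entry point; the import paths follow the module renames in this commit, and the `separator` field is carried over from the old DefaultRAGQueryGeneratorConfig, both assumptions here.

```python
# Illustrative usage; module paths and the `separator` field are assumptions
# based on the renames in this commit.
from llama_stack.apis.tools.file_search_tool import DefaultFileSearchGeneratorConfig
from llama_stack.providers.inline.tool_runtime.file_search.context_retriever import (
    generate_file_search_query,
)


async def build_query(user_text: str) -> str:
    config = DefaultFileSearchGeneratorConfig(separator=" ")
    # The default generator flattens the content into a plain query string;
    # the LLM generator would instead call inference to rewrite it.
    return await generate_file_search_query(config, user_text)
```
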
@@ -23,11 +23,11 @@ from llama_stack.apis.common.content_types import (
 from llama_stack.apis.files import Files, OpenAIFilePurpose
 from llama_stack.apis.inference import Inference
 from llama_stack.apis.tools import (
+    FileSearchConfig,
+    FileSearchResult,
+    FileSearchToolRuntime,
     ListToolDefsResponse,
     RAGDocument,
-    RAGQueryConfig,
-    RAGQueryResult,
-    RAGToolRuntime,
     ToolDef,
     ToolGroup,
     ToolInvocationResult,

@@ -45,7 +45,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 from llama_stack.providers.utils.memory.vector_store import parse_data_url

 from .config import RagToolRuntimeConfig
-from .context_retriever import generate_rag_query
+from .context_retriever import generate_file_search_query

 log = get_logger(name=__name__, category="tool_runtime")

@@ -91,7 +91,7 @@ async def raw_data_from_doc(doc: RAGDocument) -> tuple[bytes, str]:
     return content_str.encode("utf-8"), "text/plain"


-class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRuntime):
+class FileSearchToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, FileSearchToolRuntime):
     def __init__(
         self,
         config: RagToolRuntimeConfig,

@@ -177,15 +177,15 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         self,
         content: InterleavedContent,
         vector_db_ids: list[str],
-        query_config: RAGQueryConfig | None = None,
-    ) -> RAGQueryResult:
+        query_config: FileSearchConfig | None = None,
+    ) -> FileSearchResult:
         if not vector_db_ids:
             raise ValueError(
                 "No vector DBs were provided to the knowledge search tool. Please provide at least one vector DB ID."
             )

-        query_config = query_config or RAGQueryConfig()
-        query = await generate_rag_query(
+        query_config = query_config or FileSearchConfig()
+        query = await generate_file_search_query(
             query_config.query_generator_config,
             content,
             inference_api=self.inference_api,

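A minimal sketch of calling the renamed query path end to end; the `runtime` object stands in for a constructed FileSearchToolRuntimeImpl and the vector DB id is a placeholder.

```python
from llama_stack.apis.tools import FileSearchConfig  # import path per the hunks above


# Hypothetical call site; `runtime` is a constructed FileSearchToolRuntimeImpl
# and "my-docs" is a placeholder vector DB id.
async def search_docs(runtime) -> list[str]:
    result = await runtime.query(
        content="How do I rotate credentials?",
        vector_db_ids=["my-docs"],
        query_config=FileSearchConfig(max_chunks=5),
    )
    # FileSearchResult.metadata carries the matched document ids (per the hunks below)
    return result.metadata.get("document_ids", [])
```
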
@@ -218,7 +218,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             scores.append(score)

         if not chunks:
-            return RAGQueryResult(content=None)
+            return FileSearchResult(content=None)

         # sort by score
         chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False)  # type: ignore

@@ -226,9 +226,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti

         tokens = 0
         picked: list[InterleavedContentItem] = [
-            TextContentItem(
-                text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n"
-            )
+            TextContentItem(text=f"file_search tool found {len(chunks)} chunks:\nBEGIN of file_search tool results.\n")
         ]
         for i, chunk in enumerate(chunks):
             metadata = chunk.metadata

@@ -262,14 +260,14 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
             text_content = query_config.chunk_template.format(index=i + 1, chunk=chunk, metadata=metadata_for_context)
             picked.append(TextContentItem(text=text_content))

-        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
+        picked.append(TextContentItem(text="END of file_search tool results.\n"))
         picked.append(
             TextContentItem(
                 text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
             )
         )

-        return RAGQueryResult(
+        return FileSearchResult(
             content=picked,
             metadata={
                 "document_ids": [c.document_id for c in chunks[: len(picked)]],

@@ -292,7 +290,7 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
                 description="Insert documents into memory",
             ),
             ToolDef(
-                name="knowledge_search",
+                name="file_search",
                 description="Search for information in a database.",
                 input_schema={
                     "type": "object",

@@ -312,9 +310,9 @@ class MemoryToolRuntimeImpl(ToolGroupsProtocolPrivate, ToolRuntime, RAGToolRunti
         vector_db_ids = kwargs.get("vector_db_ids", [])
         query_config = kwargs.get("query_config")
         if query_config:
-            query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config)
+            query_config = TypeAdapter(FileSearchConfig).validate_python(query_config)
         else:
-            query_config = RAGQueryConfig()
+            query_config = FileSearchConfig()

         query = kwargs["query"]
         result = await self.query(

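For reference, the `TypeAdapter` pattern above coerces a raw `query_config` dict into the renamed model; a small sketch, where the `mode`/`max_chunks` fields come from the documentation hunk further down and the dict itself is illustrative.

```python
from pydantic import TypeAdapter

from llama_stack.apis.tools import FileSearchConfig  # import path per the doc example below

# A raw dict as it might arrive in kwargs["query_config"]; fall back to the
# default FileSearchConfig() when nothing was supplied.
raw = {"mode": "vector", "max_chunks": 5}
query_config = TypeAdapter(FileSearchConfig).validate_python(raw) if raw else FileSearchConfig()
```
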
@@ -18,7 +18,7 @@ def available_providers() -> list[ProviderSpec]:
     return [
         InlineProviderSpec(
             api=Api.tool_runtime,
-            provider_type="inline::rag-runtime",
+            provider_type="inline::file_search-runtime",
             pip_packages=DEFAULT_VECTOR_IO_DEPS
             + [
                 "tqdm",

@@ -29,8 +29,8 @@ def available_providers() -> list[ProviderSpec]:
                 "sentencepiece",
                 "transformers",
             ],
-            module="llama_stack.providers.inline.tool_runtime.rag",
-            config_class="llama_stack.providers.inline.tool_runtime.rag.config.RagToolRuntimeConfig",
+            module="llama_stack.providers.inline.tool_runtime.file_search",
+            config_class="llama_stack.providers.inline.tool_runtime.file_search.config.RagToolRuntimeConfig",
             api_dependencies=[Api.vector_io, Api.inference, Api.files],
             description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
         ),

@@ -241,33 +241,33 @@ Two ranker types are supported:
 - alpha=1: Only use vector scores
 - alpha=0.5: Equal weight to both (default)

-Example using RAGQueryConfig with different search modes:
+Example using FileSearchConfig with different search modes:

 ```python
-from llama_stack.apis.tools import RAGQueryConfig, RRFRanker, WeightedRanker
+from llama_stack.apis.tools import FileSearchConfig, RRFRanker, WeightedRanker

 # Vector search
-config = RAGQueryConfig(mode="vector", max_chunks=5)
+config = FileSearchConfig(mode="vector", max_chunks=5)

 # Keyword search
-config = RAGQueryConfig(mode="keyword", max_chunks=5)
+config = FileSearchConfig(mode="keyword", max_chunks=5)

 # Hybrid search with custom RRF ranker
-config = RAGQueryConfig(
+config = FileSearchConfig(
     mode="hybrid",
     max_chunks=5,
     ranker=RRFRanker(impact_factor=50.0),  # Custom impact factor
 )

 # Hybrid search with weighted ranker
-config = RAGQueryConfig(
+config = FileSearchConfig(
     mode="hybrid",
     max_chunks=5,
     ranker=WeightedRanker(alpha=0.7),  # 70% vector, 30% keyword
 )

 # Hybrid search with default RRF ranker
-config = RAGQueryConfig(
+config = FileSearchConfig(
     mode="hybrid", max_chunks=5
 )  # Will use RRF with impact_factor=60.0
 ```
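The ranker parameters documented above map to two simple scoring rules; a small illustrative sketch (the providers' exact normalization may differ).

```python
# Weighted ranker: alpha blends vector and keyword scores
# (alpha=1 -> vector only, alpha=0.5 -> equal weight, per the docs above).
def weighted_score(vector_score: float, keyword_score: float, alpha: float = 0.5) -> float:
    return alpha * vector_score + (1 - alpha) * keyword_score


# Reciprocal Rank Fusion: a chunk ranked `rank` (1-based) in a result list
# contributes 1 / (impact_factor + rank); contributions are summed across lists
# (default impact_factor=60.0, per the docs above).
def rrf_score(ranks: list[int], impact_factor: float = 60.0) -> float:
    return sum(1.0 / (impact_factor + r) for r in ranks)
```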