diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 3b3059285..329734f4c 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -803,7 +803,7 @@ } ], "source": [ - "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n", + "model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n", "\n", "model_id\n" ] @@ -1688,7 +1688,7 @@ " enable_session_persistence=False,\n", " toolgroups = [\n", " {\n", - " \"name\": \"builtin::rag\",\n", + " \"name\": \"builtin::rag/knowledge_search\",\n", " \"args\" : {\n", " \"vector_db_ids\": [vector_db_id],\n", " }\n", diff --git a/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb b/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb index 3979088c1..ae50b95a1 100644 --- a/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb +++ b/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb @@ -3,6 +3,8 @@ { "cell_type": "markdown", "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb)\n", + "\n", "# [Alpha] Llama Stack Post Training\n", "This notebook will use a real world problem (improve LLM as tax preparer) to walk through the main sets of APIs we offer with Llama stack for post training to improve the LLM performance for agentic apps (We support supervised finetune now, RLHF and knowledge distillation will come soon!).\n", "\n", @@ -64,7 +66,7 @@ "output_type": "stream", "name": "stdout", "text": [ - "Collecting git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer\n", + "Collecting git+https://github.com/meta-llama/llama-stack.git\n", " Cloning https://github.com/meta-llama/llama-stack.git (to revision hf_format_checkpointer) to /tmp/pip-req-build-j_1bxqzm\n", " Running command git clone --filter=blob:none --quiet https://github.com/meta-llama/llama-stack.git /tmp/pip-req-build-j_1bxqzm\n", " Running command git checkout -b hf_format_checkpointer --track origin/hf_format_checkpointer\n", @@ -76,7 +78,7 @@ } ], "source": [ - "!pip install git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer" + "!pip install git+https://github.com/meta-llama/llama-stack.git #TODO: update this after the next pkg release" ] }, { diff --git a/docs/source/building_applications/agent_execution_loop.md b/docs/source/building_applications/agent_execution_loop.md index 6b3f64423..0d212df7a 100644 --- a/docs/source/building_applications/agent_execution_loop.md +++ b/docs/source/building_applications/agent_execution_loop.md @@ -7,12 +7,12 @@ Each agent turn follows these key steps: 1. **Initial Safety Check**: The user's input is first screened through configured safety shields 2. **Context Retrieval**: - - If RAG is enabled, the agent queries relevant documents from memory banks - - For new documents, they are first inserted into the memory bank - - Retrieved context is augmented to the user's prompt + - If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent. + - For new documents, they are first inserted into the memory bank. + - Retrieved context is provided to the LLM as a tool response in the message history. 3. **Inference Loop**: The agent enters its main execution loop: - - The LLM receives the augmented prompt (with context and/or previous tool outputs) + - The LLM receives a user prompt (with previous tool outputs) - The LLM generates a response, potentially with tool calls - If tool calls are present: - Tool inputs are safety-checked @@ -40,19 +40,16 @@ sequenceDiagram S->>E: Input Safety Check deactivate S - E->>M: 2.1 Query Context - M-->>E: 2.2 Retrieved Documents - loop Inference Loop - E->>L: 3.1 Augment with Context - L-->>E: 3.2 Response (with/without tool calls) + E->>L: 2.1 Augment with Context + L-->>E: 2.2 Response (with/without tool calls) alt Has Tool Calls E->>S: Check Tool Input - S->>T: 4.1 Execute Tool - T-->>E: 4.2 Tool Response - E->>L: 5.1 Tool Response - L-->>E: 5.2 Synthesized Response + S->>T: 3.1 Execute Tool + T-->>E: 3.2 Tool Response + E->>L: 4.1 Tool Response + L-->>E: 4.2 Synthesized Response end opt Stop Conditions @@ -64,7 +61,7 @@ sequenceDiagram end E->>S: Output Safety Check - S->>U: 6. Final Response + S->>U: 5. Final Response ``` Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution: @@ -77,7 +74,10 @@ agent_config = AgentConfig( instructions="You are a helpful assistant", # Enable both RAG and tool usage toolgroups=[ - {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}}, + { + "name": "builtin::rag/knowledge_search", + "args": {"vector_db_ids": ["my_docs"]}, + }, "builtin::code_interpreter", ], # Configure safety diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index e6d628193..e2e5fd6b5 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -91,7 +91,7 @@ agent_config = AgentConfig( enable_session_persistence=False, toolgroups=[ { - "name": "builtin::rag", + "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], }, diff --git a/docs/source/conf.py b/docs/source/conf.py index fd105a6cf..44975c02c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,6 +13,13 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information from docutils import nodes +import tomli # Import tomli for TOML parsing +from pathlib import Path + +# Read version from pyproject.toml +with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f: + pyproject = tomli.load(f) + llama_stack_version = pyproject["project"]["version"] project = "llama-stack" copyright = "2025, Meta" @@ -66,6 +73,7 @@ myst_enable_extensions = [ myst_substitutions = { "docker_hub": "https://hub.docker.com/repository/docker/llamastack", + "llama_stack_version": llama_stack_version, } suppress_warnings = ['myst.header'] diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index 554f4354a..f017a9723 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -243,7 +243,7 @@ agent_config = AgentConfig( # Define tools available to the agent toolgroups=[ { - "name": "builtin::rag", + "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], }, diff --git a/docs/source/index.md b/docs/source/index.md index 48e1e7124..8234e1a9a 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -1,8 +1,7 @@ - ```{admonition} News :class: tip -Llama Stack 0.1.4 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.4) for more details. +Llama Stack {{ llama_stack_version }} is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v{{ llama_stack_version }}) for more details. ``` # Llama Stack diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index a7c0d63e5..b0cb50e42 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_ids: List[str], query_config: Optional[RAGQueryConfig] = None, ) -> RAGQueryResult: - return await self.routing_table.get_provider_impl("query_from_memory").query( + return await self.routing_table.get_provider_impl("knowledge_search").query( content, vector_db_ids, query_config ) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 854e5d5ae..ef770ff72 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -33,7 +33,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v7" +KEY_VERSION = "v8" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index d84418241..202c9322f 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -132,7 +132,7 @@ def rag_chat_page(): }, toolgroups=[ dict( - name="builtin::rag", + name="builtin::rag/knowledge_search", args={ "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs], }, diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 4a1421245..b17179463 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx -from pydantic import TypeAdapter from llama_stack.apis.agents import ( AgentConfig, @@ -62,7 +61,7 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import ( ToolParamDefinition, ) from llama_stack.providers.utils.kvstore import KVStore -from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing from .persistence import AgentPersistence @@ -84,7 +82,7 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") -MEMORY_QUERY_TOOL = "query_from_memory" +MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" @@ -499,111 +497,18 @@ class ChatAgent(ShieldRunnerMixin): # TODO: simplify all of this code, it can be simpler toolgroup_args = {} toolgroups = set() - for toolgroup in self.agent_config.toolgroups: + for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []): if isinstance(toolgroup, AgentToolGroupWithArgs): - toolgroups.add(toolgroup.name) - toolgroup_args[toolgroup.name] = toolgroup.args + tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name) + toolgroups.add(tool_group_name) + toolgroup_args[tool_group_name] = toolgroup.args else: toolgroups.add(toolgroup) - if toolgroups_for_turn: - for toolgroup in toolgroups_for_turn: - if isinstance(toolgroup, AgentToolGroupWithArgs): - toolgroups.add(toolgroup.name) - toolgroup_args[toolgroup.name] = toolgroup.args - else: - toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: await self.handle_documents(session_id, documents, input_messages, tool_defs) - if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0: - with tracing.span(MEMORY_QUERY_TOOL) as span: - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - ) - ) - ) - - args = toolgroup_args.get(RAG_TOOL_GROUP, {}) - vector_db_ids = args.get("vector_db_ids", []) - query_config = args.get("query_config") - if query_config: - query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) - else: - # handle someone passing an empty dict - query_config = RAGQueryConfig() - - session_info = await self.storage.get_session_info(session_id) - - # if the session has a memory bank id, let the memory tool use it - if session_info.vector_db_id: - vector_db_ids.append(session_info.vector_db_id) - - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - delta=ToolCallDelta( - parse_status=ToolCallParseStatus.succeeded, - tool_call=ToolCall( - call_id="", - tool_name=MEMORY_QUERY_TOOL, - arguments={}, - ), - ), - ) - ) - ) - result = await self.tool_runtime_api.rag_tool.query( - content=concat_interleaved_content([msg.content for msg in input_messages]), - vector_db_ids=vector_db_ids, - query_config=query_config, - ) - retrieved_context = result.content - - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - step_details=ToolExecutionStep( - step_id=step_id, - turn_id=turn_id, - tool_calls=[ - ToolCall( - call_id="", - tool_name=MEMORY_QUERY_TOOL, - arguments={}, - ) - ], - tool_responses=[ - ToolResponse( - call_id="", - tool_name=MEMORY_QUERY_TOOL, - content=retrieved_context or [], - metadata=result.metadata, - ) - ], - ), - ) - ) - ) - span.set_attribute("input", [m.model_dump_json() for m in input_messages]) - span.set_attribute("output", retrieved_context) - span.set_attribute("tool_name", MEMORY_QUERY_TOOL) - - # append retrieved_context to the last user message - for message in input_messages[::-1]: - if isinstance(message, UserMessage): - message.context = retrieved_context - break - output_attachments = [] n_iter = 0 @@ -631,9 +536,7 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=[ - tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP - ], + tools=tool_defs, tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, @@ -837,7 +740,7 @@ class ChatAgent(ShieldRunnerMixin): ) ], started_at=tool_execution_start_time, - completed_at=datetime.now(), + completed_at=datetime.now().astimezone().isoformat(), ), ) ) @@ -845,8 +748,9 @@ class ChatAgent(ShieldRunnerMixin): # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially - - if out_attachment := _interpret_content_as_attachment(result_message.content): + if (type(result_message.content) is str) and ( + out_attachment := _interpret_content_as_attachment(result_message.content) + ): # NOTE: when we push this message back to the model, the model may ignore the # attached file path etc. since the model is trained to only provide a user message # with the summary. We keep all generated attachments and then attach them to final message @@ -858,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin): async def _get_tool_defs( self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None - ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]: + ) -> Tuple[List[ToolDefinition], Dict[str, str]]: # Determine which tools to include agent_config_toolgroups = set( (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup) @@ -873,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin): } ) - tool_def_map = {} + tool_name_to_def = {} tool_to_group = {} for tool_def in self.agent_config.client_tools: - if tool_def_map.get(tool_def.name, None): + if tool_name_to_def.get(tool_def.name, None): raise ValueError(f"Tool {tool_def.name} already exists") - tool_def_map[tool_def.name] = ToolDefinition( + tool_name_to_def[tool_def.name] = ToolDefinition( tool_name=tool_def.name, description=tool_def.description, parameters={ @@ -893,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin): }, ) tool_to_group[tool_def.name] = "__client_tools__" - for toolgroup_name in agent_config_toolgroups: - if toolgroup_name not in toolgroups_for_turn_set: + for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: + if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set: continue + + toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) + if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data): + raise ValueError( + f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" + ) + for tool_def in tools.data: if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: tool_name = tool_def.identifier @@ -906,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin): else: built_in_type = BuiltinTool(tool_name) - if tool_def_map.get(built_in_type, None): + if tool_name_to_def.get(built_in_type, None): raise ValueError(f"Tool {built_in_type} already exists") - tool_def_map[built_in_type] = ToolDefinition( + tool_name_to_def[built_in_type] = ToolDefinition( tool_name=built_in_type, description=tool_def.description, parameters={ @@ -925,24 +836,42 @@ class ChatAgent(ShieldRunnerMixin): tool_to_group[built_in_type] = tool_def.toolgroup_id continue - if tool_def_map.get(tool_def.identifier, None): + if tool_name_to_def.get(tool_def.identifier, None): raise ValueError(f"Tool {tool_def.identifier} already exists") - tool_def_map[tool_def.identifier] = ToolDefinition( - tool_name=tool_def.identifier, - description=tool_def.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool_def.parameters - }, - ) - tool_to_group[tool_def.identifier] = tool_def.toolgroup_id + if tool_name in (None, tool_def.identifier): + tool_name_to_def[tool_def.identifier] = ToolDefinition( + tool_name=tool_def.identifier, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, + ) + tool_to_group[tool_def.identifier] = tool_def.toolgroup_id - return tool_def_map, tool_to_group + return list(tool_name_to_def.values()), tool_to_group + + def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: + """Parse a toolgroup name into its components. + + Args: + toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search") + + Returns: + A tuple of (tool_type, tool_group, tool_name) + """ + split_names = toolgroup_name_with_maybe_tool_name.split("/") + if len(split_names) == 2: + # e.g. "builtin::rag" + tool_group, tool_name = split_names + else: + tool_group, tool_name = split_names[0], None + return tool_group, tool_name async def handle_documents( self, @@ -951,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin): input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], ) -> None: - memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None) - code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) + memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs) + code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") @@ -1072,7 +1001,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa else: raise ValueError(f"Unsupported URL {url}") - content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n')) + content.append( + TextContentItem( + text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.' + ) + ) return ToolResponseMessage( call_id="", diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py index 3455c2236..168808bf8 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -7,6 +7,7 @@ import json import os import sqlite3 +import threading from datetime import datetime from opentelemetry.sdk.trace import SpanProcessor @@ -17,14 +18,18 @@ class SQLiteSpanProcessor(SpanProcessor): def __init__(self, conn_string): """Initialize the SQLite span processor with a connection string.""" self.conn_string = conn_string - self.conn = None + self._local = threading.local() # Thread-local storage for connections self.setup_database() - def _get_connection(self) -> sqlite3.Connection: - """Get the database connection.""" - if self.conn is None: - self.conn = sqlite3.connect(self.conn_string, check_same_thread=False) - return self.conn + def _get_connection(self): + """Get a thread-local database connection.""" + if not hasattr(self._local, "conn"): + try: + self._local.conn = sqlite3.connect(self.conn_string) + except Exception as e: + print(f"Error connecting to SQLite database: {e}") + raise e + return self._local.conn def setup_database(self): """Create the necessary tables if they don't exist.""" @@ -168,9 +173,14 @@ class SQLiteSpanProcessor(SpanProcessor): def shutdown(self): """Cleanup any resources.""" - if self.conn: - self.conn.close() - self.conn = None + # We can't access other threads' connections, so we just close our own + if hasattr(self._local, "conn"): + try: + self._local.conn.close() + except Exception as e: + print(f"Error closing SQLite connection: {e}") + finally: + del self._local.conn def force_flush(self, timeout_millis=30000): """Force export of spans.""" diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 306bd78a6..4b3f7d9e7 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -10,6 +10,8 @@ import secrets import string from typing import Any, Dict, List, Optional +from pydantic import TypeAdapter + from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -23,6 +25,7 @@ from llama_stack.apis.tools import ( RAGToolRuntime, ToolDef, ToolInvocationResult, + ToolParameter, ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO @@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): # sort by score chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) chunks = chunks[: query_config.max_chunks] + tokens = 0 - picked = [] - for c in chunks: + picked = [ + TextContentItem( + text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ] + for i, c in enumerate(chunks): metadata = c.metadata tokens += metadata["token_count"] if tokens > query_config.max_tokens_in_context: @@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): break picked.append( TextContentItem( - text=f"id:{metadata['document_id']}; content:{c.content}", + text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n", ) ) + picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) return RAGQueryResult( - content=[ - TextContentItem( - text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", - ), - *picked, - TextContentItem( - text="\n=== END-RETRIEVED-CONTEXT ===\n", - ), - ], + content=picked, metadata={ "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]], }, @@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): # by the LLM. The method is only implemented so things like /tools can list without # encountering fatals. return [ - ToolDef( - name="query_from_memory", - description="Retrieve context from memory", - ), ToolDef( name="insert_into_memory", description="Insert documents into memory", ), + ToolDef( + name="knowledge_search", + description="Search for information in a database.", + parameters=[ + ToolParameter( + name="query", + description="The query to search for. Can be a natural language sentence or keywords.", + parameter_type="string", + ), + ], + ), ] async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: - raise RuntimeError( - "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol" + vector_db_ids = kwargs.get("vector_db_ids", []) + query_config = kwargs.get("query_config") + if query_config: + query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) + else: + # handle someone passing an empty dict + query_config = RAGQueryConfig() + + query = kwargs["query"] + result = await self.query( + content=query, + vector_db_ids=vector_db_ids, + query_config=query_config, + ) + + return ToolInvocationResult( + content=result.content, + metadata=result.metadata, ) diff --git a/pyproject.toml b/pyproject.toml index 2ed2c4fa9..dc5659f06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ docs = [ "sphinxcontrib.redoc", "sphinxcontrib.video", "sphinxcontrib.mermaid", + "tomli", ] [project.urls] diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 876a9baf9..6e3dc0739 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id): sampling_params={ "strategy": { "type": "top_p", - "temperature": 1.0, + "temperature": 0.0001, "top_p": 0.9, }, }, @@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config): assert "get_boiling_point" in logs_str -def test_rag_agent(llama_stack_client, agent_config): +@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]) +def test_rag_agent(llama_stack_client, agent_config, rag_tool_name): urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] documents = [ Document( @@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config): **agent_config, "toolgroups": [ dict( - name="builtin::rag", + name=rag_tool_name, args={ "vector_db_ids": [vector_db_id], }, @@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config): "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", "grouped", ), - ( - "What `tune` command to use for getting access to Llama3-8B-Instruct ?", - "download", - ), ] for prompt, expected_kw in user_prompts: response = rag_agent.create_turn( @@ -496,23 +493,36 @@ def test_rag_agent(llama_stack_client, agent_config): ) # rag is called tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") - assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory" + assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search" # document ids are present in metadata - assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"] - assert expected_kw in response.output_message.content.lower() + assert all( + doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"] + ) + if expected_kw: + assert expected_kw in response.output_message.content.lower() def test_rag_and_code_agent(llama_stack_client, agent_config): - urls = ["chat.rst"] - documents = [ + documents = [] + documents.append( Document( - document_id=f"num-{i}", - content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", - mime_type="text/plain", + document_id="nba_wiki", + content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).", metadata={}, ) - for i, url in enumerate(urls) - ] + ) + documents.append( + Document( + document_id="perplexity_wiki", + content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning: + + Srinivas, the CEO, worked at OpenAI as an AI researcher. + Konwinski was among the founding team at Databricks. + Yarats, the CTO, was an AI research scientist at Meta. + Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""", + metadata={}, + ) + ) vector_db_id = f"test-vector-db-{uuid4()}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, @@ -528,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): **agent_config, "toolgroups": [ dict( - name="builtin::rag", + name="builtin::rag/knowledge_search", args={"vector_db_ids": [vector_db_id]}, ), "builtin::code_interpreter", @@ -546,24 +556,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): "Here is a csv file, can you describe it?", [inflation_doc], "code_interpreter", + "", ), ( - "What are the top 5 topics that were explained? Only list succinct bullet points.", + "when was Perplexity the company founded?", [], - "query_from_memory", + "knowledge_search", + "2022", + ), + ( + "when was the nba created?", + [], + "knowledge_search", + "1949", ), ] - for prompt, docs, tool_name in user_prompts: + for prompt, docs, tool_name, expected_kw in user_prompts: session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn( messages=[{"role": "user", "content": prompt}], session_id=session_id, documents=docs, + stream=False, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] - logs_str = "".join(logs) - assert f"Tool:{tool_name}" in logs_str + tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") + assert tool_execution_step.tool_calls[0].tool_name == tool_name + if expected_kw: + assert expected_kw in response.output_message.content.lower() def test_create_turn_response(llama_stack_client, agent_config): diff --git a/uv.lock b/uv.lock index c92a6e79a..80c250fcc 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -913,6 +912,7 @@ docs = [ { name = "sphinxcontrib-mermaid" }, { name = "sphinxcontrib-redoc" }, { name = "sphinxcontrib-video" }, + { name = "tomli" }, ] test = [ { name = "aiosqlite" }, @@ -971,13 +971,13 @@ requires-dist = [ { name = "sphinxcontrib-redoc", marker = "extra == 'docs'" }, { name = "sphinxcontrib-video", marker = "extra == 'docs'" }, { name = "termcolor" }, + { name = "tomli", marker = "extra == 'docs'" }, { name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" }, { name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" }, { name = "types-requests", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "uvicorn", marker = "extra == 'dev'" }, ] -provides-extras = ["dev", "test", "docs"] [[package]] name = "llama-stack-client"