From bb2690f176f56b760770a0921fe71cd94715b3ac Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 26 Feb 2025 13:04:52 -0800 Subject: [PATCH 1/8] feat: remove special handling of builtin::rag tool (#1015) Summary: Lets the model decide which tool it needs to call to respond to a query. Test Plan: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/ --safety-shield meta-llama/Llama-Guard-3-8B ``` Also evaluated on a small benchmark with 20 questions from HotpotQA. With this PR and some prompting, the performance is 77% recall compared to 50% currently. --- [//]: # (BEGIN SAPLING FOOTER) Stack created with [Sapling](https://sapling-scm.com). Best reviewed with [ReviewStack](https://reviewstack.dev/meta-llama/llama-stack/pull/1015). * #1268 * #1239 * __->__ #1015 --- llama_stack/distribution/routers/routers.py | 2 +- .../agents/meta_reference/agent_instance.py | 108 ++---------------- .../inline/tool_runtime/rag/memory.py | 60 +++++++--- tests/client-sdk/agents/test_agents.py | 57 ++++++--- 4 files changed, 94 insertions(+), 133 deletions(-) diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index a7c0d63e5..b0cb50e42 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -441,7 +441,7 @@ class ToolRuntimeRouter(ToolRuntime): vector_db_ids: List[str], query_config: Optional[RAGQueryConfig] = None, ) -> RAGQueryResult: - return await self.routing_table.get_provider_impl("query_from_memory").query( + return await self.routing_table.get_provider_impl("knowledge_search").query( content, vector_db_ids, query_config ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 4a1421245..64cd41636 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -17,7 +17,6 @@ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple from urllib.parse import urlparse import httpx -from pydantic import TypeAdapter from llama_stack.apis.agents import ( AgentConfig, @@ -62,7 +61,7 @@ from llama_stack.apis.inference import ( UserMessage, ) from llama_stack.apis.safety import Safety -from llama_stack.apis.tools import RAGDocument, RAGQueryConfig, ToolGroups, ToolInvocationResult, ToolRuntime +from llama_stack.apis.tools import RAGDocument, ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.vector_io import VectorIO from llama_stack.models.llama.datatypes import ( BuiltinTool, @@ -70,7 +69,6 @@ from llama_stack.models.llama.datatypes import ( ToolParamDefinition, ) from llama_stack.providers.utils.kvstore import KVStore -from llama_stack.providers.utils.memory.vector_store import concat_interleaved_content from llama_stack.providers.utils.telemetry import tracing from .persistence import AgentPersistence @@ -84,7 +82,7 @@ def make_random_string(length: int = 8): TOOLS_ATTACHMENT_KEY_REGEX = re.compile(r"__tools_attachment__=(\{.*?\})") -MEMORY_QUERY_TOOL = "query_from_memory" +MEMORY_QUERY_TOOL = "knowledge_search" WEB_SEARCH_TOOL = "web_search" RAG_TOOL_GROUP = "builtin::rag" @@ -517,93 +515,6 @@ class ChatAgent(ShieldRunnerMixin): if documents: await self.handle_documents(session_id, documents, input_messages, tool_defs) - if RAG_TOOL_GROUP in toolgroups and len(input_messages) > 0: - with tracing.span(MEMORY_QUERY_TOOL) as span: - step_id = str(uuid.uuid4()) - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepStartPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - ) - ) - ) - - args = toolgroup_args.get(RAG_TOOL_GROUP, {}) - vector_db_ids = args.get("vector_db_ids", []) - query_config = args.get("query_config") - if query_config: - query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) - else: - # handle someone passing an empty dict - query_config = RAGQueryConfig() - - session_info = await self.storage.get_session_info(session_id) - - # if the session has a memory bank id, let the memory tool use it - if session_info.vector_db_id: - vector_db_ids.append(session_info.vector_db_id) - - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepProgressPayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - delta=ToolCallDelta( - parse_status=ToolCallParseStatus.succeeded, - tool_call=ToolCall( - call_id="", - tool_name=MEMORY_QUERY_TOOL, - arguments={}, - ), - ), - ) - ) - ) - result = await self.tool_runtime_api.rag_tool.query( - content=concat_interleaved_content([msg.content for msg in input_messages]), - vector_db_ids=vector_db_ids, - query_config=query_config, - ) - retrieved_context = result.content - - yield AgentTurnResponseStreamChunk( - event=AgentTurnResponseEvent( - payload=AgentTurnResponseStepCompletePayload( - step_type=StepType.tool_execution.value, - step_id=step_id, - step_details=ToolExecutionStep( - step_id=step_id, - turn_id=turn_id, - tool_calls=[ - ToolCall( - call_id="", - tool_name=MEMORY_QUERY_TOOL, - arguments={}, - ) - ], - tool_responses=[ - ToolResponse( - call_id="", - tool_name=MEMORY_QUERY_TOOL, - content=retrieved_context or [], - metadata=result.metadata, - ) - ], - ), - ) - ) - ) - span.set_attribute("input", [m.model_dump_json() for m in input_messages]) - span.set_attribute("output", retrieved_context) - span.set_attribute("tool_name", MEMORY_QUERY_TOOL) - - # append retrieved_context to the last user message - for message in input_messages[::-1]: - if isinstance(message, UserMessage): - message.context = retrieved_context - break - output_attachments = [] n_iter = 0 @@ -631,9 +542,7 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=[ - tool for tool in tool_defs.values() if tool_to_group.get(tool.tool_name, None) != RAG_TOOL_GROUP - ], + tools=[tool for tool in tool_defs.values()], tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, @@ -845,8 +754,9 @@ class ChatAgent(ShieldRunnerMixin): # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially - - if out_attachment := _interpret_content_as_attachment(result_message.content): + if (type(result_message.content) is str) and ( + out_attachment := _interpret_content_as_attachment(result_message.content) + ): # NOTE: when we push this message back to the model, the model may ignore the # attached file path etc. since the model is trained to only provide a user message # with the summary. We keep all generated attachments and then attach them to final message @@ -1072,7 +982,11 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa else: raise ValueError(f"Unsupported URL {url}") - content.append(TextContentItem(text=f'# There is a file accessible to you at "{filepath}"\n')) + content.append( + TextContentItem( + text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.' + ) + ) return ToolResponseMessage( call_id="", diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 306bd78a6..4b3f7d9e7 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -10,6 +10,8 @@ import secrets import string from typing import Any, Dict, List, Optional +from pydantic import TypeAdapter + from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -23,6 +25,7 @@ from llama_stack.apis.tools import ( RAGToolRuntime, ToolDef, ToolInvocationResult, + ToolParameter, ToolRuntime, ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO @@ -120,9 +123,14 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): # sort by score chunks, scores = zip(*sorted(zip(chunks, scores, strict=False), key=lambda x: x[1], reverse=True), strict=False) chunks = chunks[: query_config.max_chunks] + tokens = 0 - picked = [] - for c in chunks: + picked = [ + TextContentItem( + text=f"knowledge_search tool found {len(chunks)} chunks:\nBEGIN of knowledge_search tool results.\n" + ) + ] + for i, c in enumerate(chunks): metadata = c.metadata tokens += metadata["token_count"] if tokens > query_config.max_tokens_in_context: @@ -132,20 +140,13 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): break picked.append( TextContentItem( - text=f"id:{metadata['document_id']}; content:{c.content}", + text=f"Result {i + 1}:\nDocument_id:{metadata['document_id'][:5]}\nContent: {c.content}\n", ) ) + picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) return RAGQueryResult( - content=[ - TextContentItem( - text="Here are the retrieved documents for relevant context:\n=== START-RETRIEVED-CONTEXT ===\n", - ), - *picked, - TextContentItem( - text="\n=== END-RETRIEVED-CONTEXT ===\n", - ), - ], + content=picked, metadata={ "document_ids": [c.metadata["document_id"] for c in chunks[: len(picked)]], }, @@ -158,17 +159,40 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): # by the LLM. The method is only implemented so things like /tools can list without # encountering fatals. return [ - ToolDef( - name="query_from_memory", - description="Retrieve context from memory", - ), ToolDef( name="insert_into_memory", description="Insert documents into memory", ), + ToolDef( + name="knowledge_search", + description="Search for information in a database.", + parameters=[ + ToolParameter( + name="query", + description="The query to search for. Can be a natural language sentence or keywords.", + parameter_type="string", + ), + ], + ), ] async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult: - raise RuntimeError( - "This toolgroup should not be called generically but only through specific methods of the RAGToolRuntime protocol" + vector_db_ids = kwargs.get("vector_db_ids", []) + query_config = kwargs.get("query_config") + if query_config: + query_config = TypeAdapter(RAGQueryConfig).validate_python(query_config) + else: + # handle someone passing an empty dict + query_config = RAGQueryConfig() + + query = kwargs["query"] + result = await self.query( + content=query, + vector_db_ids=vector_db_ids, + query_config=query_config, + ) + + return ToolInvocationResult( + content=result.content, + metadata=result.metadata, ) diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 876a9baf9..8e2c793e6 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -96,7 +96,7 @@ def agent_config(llama_stack_client, text_model_id): sampling_params={ "strategy": { "type": "top_p", - "temperature": 1.0, + "temperature": 0.0001, "top_p": 0.9, }, }, @@ -496,23 +496,36 @@ def test_rag_agent(llama_stack_client, agent_config): ) # rag is called tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") - assert tool_execution_step.tool_calls[0].tool_name == "query_from_memory" + assert tool_execution_step.tool_calls[0].tool_name == "knowledge_search" # document ids are present in metadata - assert "num-0" in tool_execution_step.tool_responses[0].metadata["document_ids"] - assert expected_kw in response.output_message.content.lower() + assert all( + doc_id.startswith("num-") for doc_id in tool_execution_step.tool_responses[0].metadata["document_ids"] + ) + if expected_kw: + assert expected_kw in response.output_message.content.lower() def test_rag_and_code_agent(llama_stack_client, agent_config): - urls = ["chat.rst"] - documents = [ + documents = [] + documents.append( Document( - document_id=f"num-{i}", - content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}", - mime_type="text/plain", + document_id="nba_wiki", + content="The NBA was created on August 3, 1949, with the merger of the Basketball Association of America (BAA) and the National Basketball League (NBL).", metadata={}, ) - for i, url in enumerate(urls) - ] + ) + documents.append( + Document( + document_id="perplexity_wiki", + content="""Perplexity the company was founded in 2022 by Aravind Srinivas, Andy Konwinski, Denis Yarats and Johnny Ho, engineers with backgrounds in back-end systems, artificial intelligence (AI) and machine learning: + + Srinivas, the CEO, worked at OpenAI as an AI researcher. + Konwinski was among the founding team at Databricks. + Yarats, the CTO, was an AI research scientist at Meta. + Ho, the CSO, worked as an engineer at Quora, then as a quantitative trader on Wall Street.[5]""", + metadata={}, + ) + ) vector_db_id = f"test-vector-db-{uuid4()}" llama_stack_client.vector_dbs.register( vector_db_id=vector_db_id, @@ -546,24 +559,34 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): "Here is a csv file, can you describe it?", [inflation_doc], "code_interpreter", + "", ), ( - "What are the top 5 topics that were explained? Only list succinct bullet points.", + "when was Perplexity the company founded?", [], - "query_from_memory", + "knowledge_search", + "2022", + ), + ( + "when was the nba created?", + [], + "knowledge_search", + "1949", ), ] - for prompt, docs, tool_name in user_prompts: + for prompt, docs, tool_name, expected_kw in user_prompts: session_id = agent.create_session(f"test-session-{uuid4()}") response = agent.create_turn( messages=[{"role": "user", "content": prompt}], session_id=session_id, documents=docs, + stream=False, ) - logs = [str(log) for log in EventLogger().log(response) if log is not None] - logs_str = "".join(logs) - assert f"Tool:{tool_name}" in logs_str + tool_execution_step = next(step for step in response.steps if step.step_type == "tool_execution") + assert tool_execution_step.tool_calls[0].tool_name == tool_name + if expected_kw: + assert expected_kw in response.output_message.content.lower() def test_create_turn_response(llama_stack_client, agent_config): From 9a3db9a290a5922dbf2cc9f88d4a320e149b11ee Mon Sep 17 00:00:00 2001 From: Botao Chen Date: Wed, 26 Feb 2025 13:39:16 -0800 Subject: [PATCH 2/8] feat: update the post training notebook (#1280) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ## What does this PR do? - add 'open in colab' icon that links to the notebook - update the pip install llama-stack pkg part ## test preview Screenshot 2025-02-26 at 1 25 34 PM Screenshot 2025-02-26 at 1 25 38 PM --- docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb b/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb index 3979088c1..ae50b95a1 100644 --- a/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb +++ b/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb @@ -3,6 +3,8 @@ { "cell_type": "markdown", "source": [ + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/notebooks/Alpha_Llama_Stack_Post_Training.ipynb)\n", + "\n", "# [Alpha] Llama Stack Post Training\n", "This notebook will use a real world problem (improve LLM as tax preparer) to walk through the main sets of APIs we offer with Llama stack for post training to improve the LLM performance for agentic apps (We support supervised finetune now, RLHF and knowledge distillation will come soon!).\n", "\n", @@ -64,7 +66,7 @@ "output_type": "stream", "name": "stdout", "text": [ - "Collecting git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer\n", + "Collecting git+https://github.com/meta-llama/llama-stack.git\n", " Cloning https://github.com/meta-llama/llama-stack.git (to revision hf_format_checkpointer) to /tmp/pip-req-build-j_1bxqzm\n", " Running command git clone --filter=blob:none --quiet https://github.com/meta-llama/llama-stack.git /tmp/pip-req-build-j_1bxqzm\n", " Running command git checkout -b hf_format_checkpointer --track origin/hf_format_checkpointer\n", @@ -76,7 +78,7 @@ } ], "source": [ - "!pip install git+https://github.com/meta-llama/llama-stack.git@hf_format_checkpointer" + "!pip install git+https://github.com/meta-llama/llama-stack.git #TODO: update this after the next pkg release" ] }, { From 6b075e50754f933b94d37903447b4342e1155d21 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 26 Feb 2025 13:41:54 -0800 Subject: [PATCH 3/8] feat: automatically update documentation version based on pyproject.toml source of truth --- docs/source/conf.py | 8 ++++++++ docs/source/index.md | 3 +-- pyproject.toml | 1 + 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/source/conf.py b/docs/source/conf.py index fd105a6cf..44975c02c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -13,6 +13,13 @@ # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information from docutils import nodes +import tomli # Import tomli for TOML parsing +from pathlib import Path + +# Read version from pyproject.toml +with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f: + pyproject = tomli.load(f) + llama_stack_version = pyproject["project"]["version"] project = "llama-stack" copyright = "2025, Meta" @@ -66,6 +73,7 @@ myst_enable_extensions = [ myst_substitutions = { "docker_hub": "https://hub.docker.com/repository/docker/llamastack", + "llama_stack_version": llama_stack_version, } suppress_warnings = ['myst.header'] diff --git a/docs/source/index.md b/docs/source/index.md index 48e1e7124..8234e1a9a 100644 --- a/docs/source/index.md +++ b/docs/source/index.md @@ -1,8 +1,7 @@ - ```{admonition} News :class: tip -Llama Stack 0.1.4 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.4) for more details. +Llama Stack {{ llama_stack_version }} is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v{{ llama_stack_version }}) for more details. ``` # Llama Stack diff --git a/pyproject.toml b/pyproject.toml index 2ed2c4fa9..dc5659f06 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ docs = [ "sphinxcontrib.redoc", "sphinxcontrib.video", "sphinxcontrib.mermaid", + "tomli", ] [project.urls] From fca84db5b0c97bce60a9fe9f161f1ac7d40d3e46 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 26 Feb 2025 13:51:33 -0800 Subject: [PATCH 4/8] fix: time logging format (#1281) Summary: missed in last PR Test Plan: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py::test_create_turn_response --safety-shield meta-llama/Llama-Guard-3-8B ``` --- .../providers/inline/agents/meta_reference/agent_instance.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 64cd41636..c910598b1 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -746,7 +746,7 @@ class ChatAgent(ShieldRunnerMixin): ) ], started_at=tool_execution_start_time, - completed_at=datetime.now(), + completed_at=datetime.now().astimezone().isoformat(), ), ) ) From 3f0b8c25aa113aca517def3b5a40529bfb6ab85e Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 26 Feb 2025 13:53:57 -0800 Subject: [PATCH 5/8] fix: run uv-sync manually. locally pre-commit is not triggering --- uv.lock | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/uv.lock b/uv.lock index c92a6e79a..80c250fcc 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,4 @@ version = 1 -revision = 1 requires-python = ">=3.10" resolution-markers = [ "(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')", @@ -913,6 +912,7 @@ docs = [ { name = "sphinxcontrib-mermaid" }, { name = "sphinxcontrib-redoc" }, { name = "sphinxcontrib-video" }, + { name = "tomli" }, ] test = [ { name = "aiosqlite" }, @@ -971,13 +971,13 @@ requires-dist = [ { name = "sphinxcontrib-redoc", marker = "extra == 'docs'" }, { name = "sphinxcontrib-video", marker = "extra == 'docs'" }, { name = "termcolor" }, + { name = "tomli", marker = "extra == 'docs'" }, { name = "torch", marker = "extra == 'test'", specifier = ">=2.6.0", index = "https://download.pytorch.org/whl/cpu" }, { name = "torchvision", marker = "extra == 'test'", specifier = ">=0.21.0", index = "https://download.pytorch.org/whl/cpu" }, { name = "types-requests", marker = "extra == 'dev'" }, { name = "types-setuptools", marker = "extra == 'dev'" }, { name = "uvicorn", marker = "extra == 'dev'" }, ] -provides-extras = ["dev", "test", "docs"] [[package]] name = "llama-stack-client" From 657efc67bc33e5f3cef846bdbdaf090c09386188 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Wed, 26 Feb 2025 13:58:03 -0800 Subject: [PATCH 6/8] fix: bump up registry key version to clear off stale entries in dbs --- llama_stack/distribution/store/registry.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index 854e5d5ae..ef770ff72 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -33,7 +33,7 @@ class DistributionRegistry(Protocol): REGISTER_PREFIX = "distributions:registry" -KEY_VERSION = "v7" +KEY_VERSION = "v8" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" From c8a20b8ed0e0100ada7dfb8b3eec5065f454005c Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 26 Feb 2025 14:07:05 -0800 Subject: [PATCH 7/8] feat: allow specifying specific tool within toolgroup (#1239) Summary: E.g. `builtin::rag::knowledge_search` Test Plan: ``` LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/ --safety-shield meta-llama/Llama-Guard-3-8B ``` --- docs/getting_started.ipynb | 4 +- .../agent_execution_loop.md | 30 +++--- docs/source/building_applications/rag.md | 2 +- docs/source/getting_started/index.md | 2 +- .../distribution/ui/page/playground/rag.py | 2 +- .../agents/meta_reference/agent_instance.py | 93 +++++++++++-------- tests/client-sdk/agents/test_agents.py | 11 +-- 7 files changed, 80 insertions(+), 64 deletions(-) diff --git a/docs/getting_started.ipynb b/docs/getting_started.ipynb index 3b3059285..329734f4c 100644 --- a/docs/getting_started.ipynb +++ b/docs/getting_started.ipynb @@ -803,7 +803,7 @@ } ], "source": [ - "model_id = \"meta-llama/Llama-3.1-70B-Instruct\"\n", + "model_id = \"meta-llama/Llama-3.3-70B-Instruct\"\n", "\n", "model_id\n" ] @@ -1688,7 +1688,7 @@ " enable_session_persistence=False,\n", " toolgroups = [\n", " {\n", - " \"name\": \"builtin::rag\",\n", + " \"name\": \"builtin::rag/knowledge_search\",\n", " \"args\" : {\n", " \"vector_db_ids\": [vector_db_id],\n", " }\n", diff --git a/docs/source/building_applications/agent_execution_loop.md b/docs/source/building_applications/agent_execution_loop.md index 6b3f64423..0d212df7a 100644 --- a/docs/source/building_applications/agent_execution_loop.md +++ b/docs/source/building_applications/agent_execution_loop.md @@ -7,12 +7,12 @@ Each agent turn follows these key steps: 1. **Initial Safety Check**: The user's input is first screened through configured safety shields 2. **Context Retrieval**: - - If RAG is enabled, the agent queries relevant documents from memory banks - - For new documents, they are first inserted into the memory bank - - Retrieved context is augmented to the user's prompt + - If RAG is enabled, the agent can choose to query relevant documents from memory banks. You can use the `instructions` field to steer the agent. + - For new documents, they are first inserted into the memory bank. + - Retrieved context is provided to the LLM as a tool response in the message history. 3. **Inference Loop**: The agent enters its main execution loop: - - The LLM receives the augmented prompt (with context and/or previous tool outputs) + - The LLM receives a user prompt (with previous tool outputs) - The LLM generates a response, potentially with tool calls - If tool calls are present: - Tool inputs are safety-checked @@ -40,19 +40,16 @@ sequenceDiagram S->>E: Input Safety Check deactivate S - E->>M: 2.1 Query Context - M-->>E: 2.2 Retrieved Documents - loop Inference Loop - E->>L: 3.1 Augment with Context - L-->>E: 3.2 Response (with/without tool calls) + E->>L: 2.1 Augment with Context + L-->>E: 2.2 Response (with/without tool calls) alt Has Tool Calls E->>S: Check Tool Input - S->>T: 4.1 Execute Tool - T-->>E: 4.2 Tool Response - E->>L: 5.1 Tool Response - L-->>E: 5.2 Synthesized Response + S->>T: 3.1 Execute Tool + T-->>E: 3.2 Tool Response + E->>L: 4.1 Tool Response + L-->>E: 4.2 Synthesized Response end opt Stop Conditions @@ -64,7 +61,7 @@ sequenceDiagram end E->>S: Output Safety Check - S->>U: 6. Final Response + S->>U: 5. Final Response ``` Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution: @@ -77,7 +74,10 @@ agent_config = AgentConfig( instructions="You are a helpful assistant", # Enable both RAG and tool usage toolgroups=[ - {"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}}, + { + "name": "builtin::rag/knowledge_search", + "args": {"vector_db_ids": ["my_docs"]}, + }, "builtin::code_interpreter", ], # Configure safety diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index e6d628193..e2e5fd6b5 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -91,7 +91,7 @@ agent_config = AgentConfig( enable_session_persistence=False, toolgroups=[ { - "name": "builtin::rag", + "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], }, diff --git a/docs/source/getting_started/index.md b/docs/source/getting_started/index.md index 554f4354a..f017a9723 100644 --- a/docs/source/getting_started/index.md +++ b/docs/source/getting_started/index.md @@ -243,7 +243,7 @@ agent_config = AgentConfig( # Define tools available to the agent toolgroups=[ { - "name": "builtin::rag", + "name": "builtin::rag/knowledge_search", "args": { "vector_db_ids": [vector_db_id], }, diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index d84418241..202c9322f 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -132,7 +132,7 @@ def rag_chat_page(): }, toolgroups=[ dict( - name="builtin::rag", + name="builtin::rag/knowledge_search", args={ "vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs], }, diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index c910598b1..b17179463 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -497,19 +497,13 @@ class ChatAgent(ShieldRunnerMixin): # TODO: simplify all of this code, it can be simpler toolgroup_args = {} toolgroups = set() - for toolgroup in self.agent_config.toolgroups: + for toolgroup in self.agent_config.toolgroups + (toolgroups_for_turn or []): if isinstance(toolgroup, AgentToolGroupWithArgs): - toolgroups.add(toolgroup.name) - toolgroup_args[toolgroup.name] = toolgroup.args + tool_group_name, tool_name = self._parse_toolgroup_name(toolgroup.name) + toolgroups.add(tool_group_name) + toolgroup_args[tool_group_name] = toolgroup.args else: toolgroups.add(toolgroup) - if toolgroups_for_turn: - for toolgroup in toolgroups_for_turn: - if isinstance(toolgroup, AgentToolGroupWithArgs): - toolgroups.add(toolgroup.name) - toolgroup_args[toolgroup.name] = toolgroup.args - else: - toolgroups.add(toolgroup) tool_defs, tool_to_group = await self._get_tool_defs(toolgroups_for_turn) if documents: @@ -542,7 +536,7 @@ class ChatAgent(ShieldRunnerMixin): async for chunk in await self.inference_api.chat_completion( self.agent_config.model, input_messages, - tools=[tool for tool in tool_defs.values()], + tools=tool_defs, tool_prompt_format=self.agent_config.tool_config.tool_prompt_format, response_format=self.agent_config.response_format, stream=True, @@ -768,7 +762,7 @@ class ChatAgent(ShieldRunnerMixin): async def _get_tool_defs( self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None - ) -> Tuple[Dict[str, ToolDefinition], Dict[str, str]]: + ) -> Tuple[List[ToolDefinition], Dict[str, str]]: # Determine which tools to include agent_config_toolgroups = set( (toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup) @@ -783,13 +777,13 @@ class ChatAgent(ShieldRunnerMixin): } ) - tool_def_map = {} + tool_name_to_def = {} tool_to_group = {} for tool_def in self.agent_config.client_tools: - if tool_def_map.get(tool_def.name, None): + if tool_name_to_def.get(tool_def.name, None): raise ValueError(f"Tool {tool_def.name} already exists") - tool_def_map[tool_def.name] = ToolDefinition( + tool_name_to_def[tool_def.name] = ToolDefinition( tool_name=tool_def.name, description=tool_def.description, parameters={ @@ -803,10 +797,17 @@ class ChatAgent(ShieldRunnerMixin): }, ) tool_to_group[tool_def.name] = "__client_tools__" - for toolgroup_name in agent_config_toolgroups: - if toolgroup_name not in toolgroups_for_turn_set: + for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: + if toolgroup_name_with_maybe_tool_name not in toolgroups_for_turn_set: continue + + toolgroup_name, tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) + if tool_name is not None and not any(tool.identifier == tool_name for tool in tools.data): + raise ValueError( + f"Tool {tool_name} not found in toolgroup {toolgroup_name}. Available tools: {', '.join([tool.identifier for tool in tools.data])}" + ) + for tool_def in tools.data: if toolgroup_name.startswith("builtin") and toolgroup_name != RAG_TOOL_GROUP: tool_name = tool_def.identifier @@ -816,10 +817,10 @@ class ChatAgent(ShieldRunnerMixin): else: built_in_type = BuiltinTool(tool_name) - if tool_def_map.get(built_in_type, None): + if tool_name_to_def.get(built_in_type, None): raise ValueError(f"Tool {built_in_type} already exists") - tool_def_map[built_in_type] = ToolDefinition( + tool_name_to_def[built_in_type] = ToolDefinition( tool_name=built_in_type, description=tool_def.description, parameters={ @@ -835,24 +836,42 @@ class ChatAgent(ShieldRunnerMixin): tool_to_group[built_in_type] = tool_def.toolgroup_id continue - if tool_def_map.get(tool_def.identifier, None): + if tool_name_to_def.get(tool_def.identifier, None): raise ValueError(f"Tool {tool_def.identifier} already exists") - tool_def_map[tool_def.identifier] = ToolDefinition( - tool_name=tool_def.identifier, - description=tool_def.description, - parameters={ - param.name: ToolParamDefinition( - param_type=param.parameter_type, - description=param.description, - required=param.required, - default=param.default, - ) - for param in tool_def.parameters - }, - ) - tool_to_group[tool_def.identifier] = tool_def.toolgroup_id + if tool_name in (None, tool_def.identifier): + tool_name_to_def[tool_def.identifier] = ToolDefinition( + tool_name=tool_def.identifier, + description=tool_def.description, + parameters={ + param.name: ToolParamDefinition( + param_type=param.parameter_type, + description=param.description, + required=param.required, + default=param.default, + ) + for param in tool_def.parameters + }, + ) + tool_to_group[tool_def.identifier] = tool_def.toolgroup_id - return tool_def_map, tool_to_group + return list(tool_name_to_def.values()), tool_to_group + + def _parse_toolgroup_name(self, toolgroup_name_with_maybe_tool_name: str) -> tuple[str, Optional[str]]: + """Parse a toolgroup name into its components. + + Args: + toolgroup_name: The toolgroup name to parse (e.g. "builtin::rag/knowledge_search") + + Returns: + A tuple of (tool_type, tool_group, tool_name) + """ + split_names = toolgroup_name_with_maybe_tool_name.split("/") + if len(split_names) == 2: + # e.g. "builtin::rag" + tool_group, tool_name = split_names + else: + tool_group, tool_name = split_names[0], None + return tool_group, tool_name async def handle_documents( self, @@ -861,8 +880,8 @@ class ChatAgent(ShieldRunnerMixin): input_messages: List[Message], tool_defs: Dict[str, ToolDefinition], ) -> None: - memory_tool = tool_defs.get(MEMORY_QUERY_TOOL, None) - code_interpreter_tool = tool_defs.get(BuiltinTool.code_interpreter, None) + memory_tool = any(tool_def.tool_name == MEMORY_QUERY_TOOL for tool_def in tool_defs) + code_interpreter_tool = any(tool_def.tool_name == BuiltinTool.code_interpreter for tool_def in tool_defs) content_items = [] url_items = [] pattern = re.compile("^(https?://|file://|data:)") diff --git a/tests/client-sdk/agents/test_agents.py b/tests/client-sdk/agents/test_agents.py index 8e2c793e6..6e3dc0739 100644 --- a/tests/client-sdk/agents/test_agents.py +++ b/tests/client-sdk/agents/test_agents.py @@ -441,7 +441,8 @@ def xtest_override_system_message_behavior(llama_stack_client, agent_config): assert "get_boiling_point" in logs_str -def test_rag_agent(llama_stack_client, agent_config): +@pytest.mark.parametrize("rag_tool_name", ["builtin::rag/knowledge_search", "builtin::rag"]) +def test_rag_agent(llama_stack_client, agent_config, rag_tool_name): urls = ["chat.rst", "llama3.rst", "memory_optimizations.rst", "lora_finetune.rst"] documents = [ Document( @@ -469,7 +470,7 @@ def test_rag_agent(llama_stack_client, agent_config): **agent_config, "toolgroups": [ dict( - name="builtin::rag", + name=rag_tool_name, args={ "vector_db_ids": [vector_db_id], }, @@ -483,10 +484,6 @@ def test_rag_agent(llama_stack_client, agent_config): "Instead of the standard multi-head attention, what attention type does Llama3-8B use?", "grouped", ), - ( - "What `tune` command to use for getting access to Llama3-8B-Instruct ?", - "download", - ), ] for prompt, expected_kw in user_prompts: response = rag_agent.create_turn( @@ -541,7 +538,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config): **agent_config, "toolgroups": [ dict( - name="builtin::rag", + name="builtin::rag/knowledge_search", args={"vector_db_ids": [vector_db_id]}, ), "builtin::code_interpreter", From 270d64007aa6d13ecc8e63149b5a489018c9d031 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 26 Feb 2025 14:44:31 -0800 Subject: [PATCH 8/8] fix: sqlite conn (#1282) # Summary: Our tests sometimes error out with ``` ========================== 11 passed, 342 warnings in 58.86s ========================== Error exporting span to SQLite: Cannot operate on a closed database. Fatal Python error: _enter_buffered_busy: could not acquire lock for <_io.BufferedWriter name=''> at interpreter shutdown, possibly due to daemon threads Python runtime state: finalizing (tstate=0x000000012af04280) Current thread 0x00000001fa29c240 (most recent call first): ``` Usually able to repro this by running 10 times. The proposed fix is to use threadsafe var for creating sqlite connection to ensure connection is only used by one thread. Not 100% if this is the fix, but am not able to repro with this. # Test Plan: Run 10 times and saw no more errors ``` for i in {1..10}; do echo "=== Starting Run $i ===" LLAMA_STACK_CONFIG=fireworks pytest -s -v tests/client-sdk/agents/test_agents.py --safety-shield meta-llama/Llama-Guard-3-8B if [[ $? -ne 0 ]]; then echo "=== Run $i FAILED with exit code $? ===" break else echo "=== Run $i PASSED ===" fi echo done ``` --- .../meta_reference/sqlite_span_processor.py | 28 +++++++++++++------ 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py index 3455c2236..168808bf8 100644 --- a/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py +++ b/llama_stack/providers/inline/telemetry/meta_reference/sqlite_span_processor.py @@ -7,6 +7,7 @@ import json import os import sqlite3 +import threading from datetime import datetime from opentelemetry.sdk.trace import SpanProcessor @@ -17,14 +18,18 @@ class SQLiteSpanProcessor(SpanProcessor): def __init__(self, conn_string): """Initialize the SQLite span processor with a connection string.""" self.conn_string = conn_string - self.conn = None + self._local = threading.local() # Thread-local storage for connections self.setup_database() - def _get_connection(self) -> sqlite3.Connection: - """Get the database connection.""" - if self.conn is None: - self.conn = sqlite3.connect(self.conn_string, check_same_thread=False) - return self.conn + def _get_connection(self): + """Get a thread-local database connection.""" + if not hasattr(self._local, "conn"): + try: + self._local.conn = sqlite3.connect(self.conn_string) + except Exception as e: + print(f"Error connecting to SQLite database: {e}") + raise e + return self._local.conn def setup_database(self): """Create the necessary tables if they don't exist.""" @@ -168,9 +173,14 @@ class SQLiteSpanProcessor(SpanProcessor): def shutdown(self): """Cleanup any resources.""" - if self.conn: - self.conn.close() - self.conn = None + # We can't access other threads' connections, so we just close our own + if hasattr(self._local, "conn"): + try: + self._local.conn.close() + except Exception as e: + print(f"Error closing SQLite connection: {e}") + finally: + del self._local.conn def force_flush(self, timeout_millis=30000): """Force export of spans."""