From fb998683e053a8d1de80f47528131e1743ad0d29 Mon Sep 17 00:00:00 2001 From: Daniele Martinoli <86618610+dmartinol@users.noreply.github.com> Date: Wed, 5 Mar 2025 06:44:13 +0100 Subject: [PATCH] fix: Agent uses the first configured vector_db_id when documents are provided (#1276) # What does this PR do? The agent API allows to query multiple DBs using the `vector_db_ids` argument of the `rag` tool: ```py toolgroups=[ { "name": "builtin::rag", "args": {"vector_db_ids": [vector_db_id]}, } ], ``` This means that multiple DBs can be used to compose an aggregated context by executing the query on each of them. When documents are passed to the next agent turn, there is no explicit way to configure the vector DB where the embeddings will be ingested. In such cases, we can assume that: - if any `vector_db_ids` is given, we use the first one (it probably makes sense to assume that it's the only one in the list, otherwise we should loop on all the given DBs to have a consistent ingestion) - if no `vector_db_ids` is given, we can use the current logic to generate a default DB using the default provider. If multiple providers are defined, the API will fail as expected: the user has to provide details on where to ingest the documents. (Closes #1270) ## Test Plan The issue description details how to replicate the problem. [//]: # (## Documentation) --------- Signed-off-by: Daniele Martinoli --- docs/source/building_applications/rag.md | 2 +- .../distribution/routers/routing_tables.py | 11 ++- .../meta_reference/tests/test_chat_agent.py | 99 ++++++++++--------- 3 files changed, 62 insertions(+), 50 deletions(-) diff --git a/docs/source/building_applications/rag.md b/docs/source/building_applications/rag.md index 03b71e057..acbc07ca4 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -122,7 +122,7 @@ response = agent.create_turn( ], documents=[ { - "content": "https://raw.githubusercontent.com/example/doc.rst", + "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst", "mime_type": "text/plain", } ], diff --git a/llama_stack/distribution/routers/routing_tables.py b/llama_stack/distribution/routers/routing_tables.py index 73f9c9672..1be43ec8b 100644 --- a/llama_stack/distribution/routers/routing_tables.py +++ b/llama_stack/distribution/routers/routing_tables.py @@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs): if provider_vector_db_id is None: provider_vector_db_id = vector_db_id if provider_id is None: - # If provider_id not specified, use the only provider if it supports this shield type - if len(self.impls_by_provider_id) == 1: + if len(self.impls_by_provider_id) > 0: provider_id = list(self.impls_by_provider_id.keys())[0] + if len(self.impls_by_provider_id) > 1: + logger.warning( + f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}." + ) else: - raise ValueError( - "No provider specified and multiple providers available. Please specify a provider_id." - ) + raise ValueError("No provider available. Please configure a vector_io provider.") model = await self.get_object_by_identifier("model", embedding_model) if model is None: raise ValueError(f"Model {embedding_model} not found") diff --git a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py index b802937b6..84ab364b7 100644 --- a/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py +++ b/llama_stack/providers/inline/agents/meta_reference/tests/test_chat_agent.py @@ -16,10 +16,11 @@ from llama_stack.apis.agents import ( AgentTurnResponseTurnCompletePayload, StepType, ) -from llama_stack.apis.common.content_types import URL +from llama_stack.apis.common.content_types import URL, TextDelta from llama_stack.apis.inference import ( ChatCompletionResponse, ChatCompletionResponseEvent, + ChatCompletionResponseEventType, ChatCompletionResponseStreamChunk, CompletionMessage, LogProbConfig, @@ -27,12 +28,15 @@ from llama_stack.apis.inference import ( ResponseFormat, SamplingParams, ToolChoice, + ToolConfig, ToolDefinition, ToolPromptFormat, UserMessage, ) from llama_stack.apis.safety import RunShieldResponse from llama_stack.apis.tools import ( + ListToolGroupsResponse, + ListToolsResponse, Tool, ToolDef, ToolGroup, @@ -40,7 +44,7 @@ from llama_stack.apis.tools import ( ToolInvocationResult, ) from llama_stack.apis.vector_io import QueryChunksResponse -from llama_stack.models.llama.datatypes import BuiltinTool +from llama_stack.models.llama.datatypes import BuiltinTool, StopReason from llama_stack.providers.inline.agents.meta_reference.agent_instance import ( MEMORY_QUERY_TOOL, ) @@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig class MockInferenceAPI: async def chat_completion( self, - model: str, + model_id: str, messages: List[Message], sampling_params: Optional[SamplingParams] = SamplingParams(), - response_format: Optional[ResponseFormat] = None, tools: Optional[List[ToolDefinition]] = None, tool_choice: Optional[ToolChoice] = None, tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: async def stream_response(): yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="start", - delta="", + event_type=ChatCompletionResponseEventType.start, + delta=TextDelta(text=""), ) ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="progress", - delta="AI is a fascinating field...", + event_type=ChatCompletionResponseEventType.progress, + delta=TextDelta(text="AI is a fascinating field..."), ) ) yield ChatCompletionResponseStreamChunk( event=ChatCompletionResponseEvent( - event_type="complete", - delta="", - stop_reason="end_of_turn", + event_type=ChatCompletionResponseEventType.complete, + delta=TextDelta(text=""), + stop_reason=StopReason.end_of_turn, ) ) @@ -133,35 +138,39 @@ class MockToolGroupsAPI: provider_resource_id=toolgroup_id, ) - async def list_tool_groups(self) -> List[ToolGroup]: - return [] + async def list_tool_groups(self) -> ListToolGroupsResponse: + return ListToolGroupsResponse(data=[]) - async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]: - if tool_group_id == MEMORY_TOOLGROUP: - return [ - Tool( - identifier=MEMORY_QUERY_TOOL, - provider_resource_id=MEMORY_QUERY_TOOL, - toolgroup_id=MEMORY_TOOLGROUP, - tool_host=ToolHost.client, - description="Mock tool", - provider_id="builtin::rag", - parameters=[], - ) - ] - if tool_group_id == CODE_INTERPRETER_TOOLGROUP: - return [ - Tool( - identifier="code_interpreter", - provider_resource_id="code_interpreter", - toolgroup_id=CODE_INTERPRETER_TOOLGROUP, - tool_host=ToolHost.client, - description="Mock tool", - provider_id="builtin::code_interpreter", - parameters=[], - ) - ] - return [] + async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse: + if toolgroup_id == MEMORY_TOOLGROUP: + return ListToolsResponse( + data=[ + Tool( + identifier=MEMORY_QUERY_TOOL, + provider_resource_id=MEMORY_QUERY_TOOL, + toolgroup_id=MEMORY_TOOLGROUP, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="builtin::rag", + parameters=[], + ) + ] + ) + if toolgroup_id == CODE_INTERPRETER_TOOLGROUP: + return ListToolsResponse( + data=[ + Tool( + identifier="code_interpreter", + provider_resource_id="code_interpreter", + toolgroup_id=CODE_INTERPRETER_TOOLGROUP, + tool_host=ToolHost.client, + description="Mock tool", + provider_id="builtin::code_interpreter", + parameters=[], + ) + ] + ) + return ListToolsResponse(data=[]) async def get_tool(self, tool_name: str) -> Tool: return Tool( @@ -174,7 +183,7 @@ class MockToolGroupsAPI: parameters=[], ) - async def unregister_tool_group(self, tool_group_id: str) -> None: + async def unregister_tool_group(self, toolgroup_id: str) -> None: pass @@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex chat_agent = await impl.get_agent(response.agent_id) tool_defs, _ = await chat_agent._get_tool_defs() + tool_defs_names = [t.tool_name for t in tool_defs] if expected_memory: - assert MEMORY_QUERY_TOOL in tool_defs + assert MEMORY_QUERY_TOOL in tool_defs_names if expected_code_interpreter: - assert BuiltinTool.code_interpreter in tool_defs + assert BuiltinTool.code_interpreter in tool_defs_names if expected_memory and expected_code_interpreter: # override the tools for turn new_tool_defs, _ = await chat_agent._get_tool_defs( @@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex ) ] ) - assert MEMORY_QUERY_TOOL in new_tool_defs - assert BuiltinTool.code_interpreter not in new_tool_defs + new_tool_defs_names = [t.tool_name for t in new_tool_defs] + assert MEMORY_QUERY_TOOL in new_tool_defs_names + assert BuiltinTool.code_interpreter not in new_tool_defs_names