fix: Agent uses the first configured vector_db_id when documents are provided (#1276)

# What does this PR do?
The agent API allows querying multiple vector DBs using the `vector_db_ids`
argument of the `rag` tool:
```py
        toolgroups=[
            {
                "name": "builtin::rag",
                "args": {"vector_db_ids": [vector_db_id]},
            }
        ],
```
This means that multiple DBs can be queried in a single turn: the query is
executed against each of them and the results are composed into an aggregated
context.
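
For example, a hedged sketch (`docs_db` and `code_db` are hypothetical
`vector_db_ids` assumed to be registered already; the client API shape may
vary slightly across versions):

```py
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:8321")
agent = Agent(
    client,
    AgentConfig(
        model="meta-llama/Llama-3.1-8B-Instruct",
        instructions="You are a helpful assistant.",
        # The rag tool queries each listed DB and aggregates the retrieved chunks.
        toolgroups=[
            {
                "name": "builtin::rag",
                "args": {"vector_db_ids": ["docs_db", "code_db"]},
            }
        ],
        enable_session_persistence=False,
    ),
)
```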

When documents are passed to the next agent turn, there is no explicit
way to configure the vector DB where their embeddings will be ingested. In
such cases, we assume that (see the sketch after this list):
- if `vector_db_ids` is given, we use the first one (it probably makes
sense to assume it's the only one in the list; otherwise we should loop
over all the given DBs to get a consistent ingestion)
- if no `vector_db_ids` is given, we keep the current logic and generate a
default DB using the default provider. If multiple providers are defined,
the API will fail as expected: the user has to provide details on where
to ingest the documents.
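
A compact sketch of this selection rule (illustrative only: `pick_ingestion_db`
and `create_default_vector_db` are hypothetical names standing in for the
agent's actual logic):

```py
def pick_ingestion_db(vector_db_ids: list[str] | None) -> str:
    """Pick the vector DB that receives the embeddings of a turn's documents."""
    if vector_db_ids:
        # Use the first configured DB. A consistent alternative would be to
        # ingest into every listed DB, at the cost of duplicated embeddings.
        return vector_db_ids[0]
    # No DB configured: fall back to generating a default DB on the default
    # provider (see the routing-table diff below for how that provider is
    # selected when several are configured).
    return create_default_vector_db()  # hypothetical stand-in for the default-DB logic
```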

(Closes #1270)

## Test Plan
The issue description details how to replicate the problem.
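
A hedged reproduction sketch along those lines (identifiers, the model name,
and the exact client API shape are assumptions):

```py
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:8321")

# Register an explicit vector DB; before this fix, documents attached to a
# turn could be ingested into a freshly generated default DB instead.
client.vector_dbs.register(
    vector_db_id="my_docs",
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
)

agent = Agent(
    client,
    AgentConfig(
        model="meta-llama/Llama-3.1-8B-Instruct",
        instructions="You are a helpful assistant.",
        toolgroups=[{"name": "builtin::rag", "args": {"vector_db_ids": ["my_docs"]}}],
        enable_session_persistence=False,
    ),
)
session_id = agent.create_session("repro")
response = agent.create_turn(
    session_id=session_id,
    messages=[{"role": "user", "content": "Summarize the attached document."}],
    documents=[
        {
            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
            "mime_type": "text/plain",
        }
    ],
)
```

With the fix, ingestion lands in `my_docs` rather than in a new default DB.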


---------

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
Authored by Daniele Martinoli on 2025-03-05 06:44:13 +01:00, committed by GitHub
commit fb998683e0 (parent 78962be996)
3 changed files with 62 additions and 50 deletions


```diff
@@ -122,7 +122,7 @@ response = agent.create_turn(
     ],
     documents=[
         {
-            "content": "https://raw.githubusercontent.com/example/doc.rst",
+            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
             "mime_type": "text/plain",
         }
     ],
```


```diff
@@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         if provider_vector_db_id is None:
             provider_vector_db_id = vector_db_id
         if provider_id is None:
-            # If provider_id not specified, use the only provider if it supports this shield type
-            if len(self.impls_by_provider_id) == 1:
+            if len(self.impls_by_provider_id) > 0:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
+                if len(self.impls_by_provider_id) > 1:
+                    logger.warning(
+                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
+                    )
             else:
-                raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
-                )
+                raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
```


```diff
@@ -16,10 +16,11 @@ from llama_stack.apis.agents import (
     AgentTurnResponseTurnCompletePayload,
     StepType,
 )
-from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.content_types import URL, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
     LogProbConfig,
@@ -27,12 +28,15 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
     UserMessage,
 )
 from llama_stack.apis.safety import RunShieldResponse
 from llama_stack.apis.tools import (
+    ListToolGroupsResponse,
+    ListToolsResponse,
     Tool,
     ToolDef,
     ToolGroup,
@@ -40,7 +44,7 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
@@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 class MockInferenceAPI:
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         async def stream_response():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="start",
-                    delta="",
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta=TextDelta(text=""),
                 )
             )
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="progress",
-                    delta="AI is a fascinating field...",
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=TextDelta(text="AI is a fascinating field..."),
                 )
             )
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="complete",
-                    delta="",
-                    stop_reason="end_of_turn",
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=TextDelta(text=""),
+                    stop_reason=StopReason.end_of_turn,
                 )
             )
@@ -133,35 +138,39 @@ class MockToolGroupsAPI:
             provider_resource_id=toolgroup_id,
         )
 
-    async def list_tool_groups(self) -> List[ToolGroup]:
-        return []
+    async def list_tool_groups(self) -> ListToolGroupsResponse:
+        return ListToolGroupsResponse(data=[])
 
-    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
-        if tool_group_id == MEMORY_TOOLGROUP:
-            return [
-                Tool(
-                    identifier=MEMORY_QUERY_TOOL,
-                    provider_resource_id=MEMORY_QUERY_TOOL,
-                    toolgroup_id=MEMORY_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::rag",
-                    parameters=[],
-                )
-            ]
-        if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
-            return [
-                Tool(
-                    identifier="code_interpreter",
-                    provider_resource_id="code_interpreter",
-                    toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::code_interpreter",
-                    parameters=[],
-                )
-            ]
-        return []
+    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+        if toolgroup_id == MEMORY_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier=MEMORY_QUERY_TOOL,
+                        provider_resource_id=MEMORY_QUERY_TOOL,
+                        toolgroup_id=MEMORY_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::rag",
+                        parameters=[],
+                    )
+                ]
+            )
+        if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier="code_interpreter",
+                        provider_resource_id="code_interpreter",
+                        toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::code_interpreter",
+                        parameters=[],
+                    )
+                ]
+            )
+        return ListToolsResponse(data=[])
 
     async def get_tool(self, tool_name: str) -> Tool:
         return Tool(
@@ -174,7 +183,7 @@ class MockToolGroupsAPI:
             parameters=[],
         )
 
-    async def unregister_tool_group(self, tool_group_id: str) -> None:
+    async def unregister_tool_group(self, toolgroup_id: str) -> None:
         pass
@@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
     chat_agent = await impl.get_agent(response.agent_id)
 
     tool_defs, _ = await chat_agent._get_tool_defs()
+    tool_defs_names = [t.tool_name for t in tool_defs]
     if expected_memory:
-        assert MEMORY_QUERY_TOOL in tool_defs
+        assert MEMORY_QUERY_TOOL in tool_defs_names
     if expected_code_interpreter:
-        assert BuiltinTool.code_interpreter in tool_defs
+        assert BuiltinTool.code_interpreter in tool_defs_names
     if expected_memory and expected_code_interpreter:
         # override the tools for turn
         new_tool_defs, _ = await chat_agent._get_tool_defs(
@@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
             )
         ]
     )
-    assert MEMORY_QUERY_TOOL in new_tool_defs
-    assert BuiltinTool.code_interpreter not in new_tool_defs
+    new_tool_defs_names = [t.tool_name for t in new_tool_defs]
+    assert MEMORY_QUERY_TOOL in new_tool_defs_names
+    assert BuiltinTool.code_interpreter not in new_tool_defs_names
```