forked from phoenix-oss/llama-stack-mirror
fix: Agent uses the first configured vector_db_id when documents are provided (#1276)
# What does this PR do?

The agent API allows querying multiple vector DBs via the `vector_db_ids` argument of the `rag` tool:

```py
toolgroups=[
    {
        "name": "builtin::rag",
        "args": {"vector_db_ids": [vector_db_id]},
    }
],
```

This means that multiple DBs can be used to compose an aggregated context by executing the query on each of them.

When documents are passed to the next agent turn, there is no explicit way to configure the vector DB where the embeddings will be ingested. In such cases, we can assume that:

- If any `vector_db_ids` are given, we use the first one (it probably makes sense to assume it is the only one in the list; otherwise we would have to loop over all the given DBs to keep ingestion consistent).
- If no `vector_db_ids` are given, we keep the current logic and generate a default DB using the default provider. If multiple providers are defined, the API fails as expected: the user has to provide details on where to ingest the documents.

(Closes #1270)

## Test Plan

The issue description details how to replicate the problem.

---------

Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
parent 78962be996
commit fb998683e0

3 changed files with 62 additions and 50 deletions
```diff
@@ -122,7 +122,7 @@ response = agent.create_turn(
     ],
     documents=[
         {
-            "content": "https://raw.githubusercontent.com/example/doc.rst",
+            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
             "mime_type": "text/plain",
         }
     ],
```
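For context, the end-to-end scenario this PR fixes looks roughly like the sketch below, combining the `toolgroups` configuration from the description with the documents-on-a-turn pattern from the docs change above. The client setup, model name, and session handling are illustrative assumptions, not part of this change:

```py
from llama_stack_client import LlamaStackClient
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.types.agent_create_params import AgentConfig

client = LlamaStackClient(base_url="http://localhost:8321")  # illustrative endpoint

agent_config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model
    instructions="You are a helpful assistant.",
    toolgroups=[
        {
            "name": "builtin::rag",
            # Several DBs may be listed; with this fix, documents attached to a
            # turn are ingested into the FIRST entry of this list.
            "args": {"vector_db_ids": ["my_documents"]},
        }
    ],
    enable_session_persistence=False,
)
agent = Agent(client, agent_config)
session_id = agent.create_session("rag-session")

# The attached document is embedded and ingested into "my_documents",
# then available as RAG context for this and later turns.
response = agent.create_turn(
    messages=[{"role": "user", "content": "How can I reduce memory usage?"}],
    documents=[
        {
            "content": "https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/memory_optimizations.rst",
            "mime_type": "text/plain",
        }
    ],
    session_id=session_id,
)
```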
```diff
@@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
         if provider_vector_db_id is None:
             provider_vector_db_id = vector_db_id
         if provider_id is None:
-            # If provider_id not specified, use the only provider if it supports this shield type
-            if len(self.impls_by_provider_id) == 1:
+            if len(self.impls_by_provider_id) > 0:
                 provider_id = list(self.impls_by_provider_id.keys())[0]
+                if len(self.impls_by_provider_id) > 1:
+                    logger.warning(
+                        f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
+                    )
             else:
-                raise ValueError(
-                    "No provider specified and multiple providers available. Please specify a provider_id."
-                )
+                raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await self.get_object_by_identifier("model", embedding_model)
         if model is None:
             raise ValueError(f"Model {embedding_model} not found")
```
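Seen from the client side, the behavioral contract of this routing-table change is roughly the following (a sketch; `client.vector_dbs.register` follows the llama-stack client API, and the IDs and embedding model are illustrative):

```py
# After this change, registering a vector DB without provider_id behaves as:
#   - one vector_io provider configured -> it is selected silently
#   - several providers configured      -> the first is selected, a warning is logged
#   - no providers configured           -> ValueError("No provider available. ...")
client.vector_dbs.register(
    vector_db_id="my_documents",         # illustrative ID
    embedding_model="all-MiniLM-L6-v2",  # illustrative embedding model
    embedding_dimension=384,
    # provider_id intentionally omitted: the routing table picks one as above
)
```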
```diff
@@ -16,10 +16,11 @@ from llama_stack.apis.agents import (
     AgentTurnResponseTurnCompletePayload,
     StepType,
 )
-from llama_stack.apis.common.content_types import URL
+from llama_stack.apis.common.content_types import URL, TextDelta
 from llama_stack.apis.inference import (
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
+    ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
     CompletionMessage,
     LogProbConfig,
@@ -27,12 +28,15 @@ from llama_stack.apis.inference import (
     ResponseFormat,
     SamplingParams,
     ToolChoice,
+    ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
     UserMessage,
 )
 from llama_stack.apis.safety import RunShieldResponse
 from llama_stack.apis.tools import (
+    ListToolGroupsResponse,
+    ListToolsResponse,
     Tool,
     ToolDef,
     ToolGroup,
@@ -40,7 +44,7 @@ from llama_stack.apis.tools import (
     ToolInvocationResult,
 )
 from llama_stack.apis.vector_io import QueryChunksResponse
-from llama_stack.models.llama.datatypes import BuiltinTool
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason
 from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
     MEMORY_QUERY_TOOL,
 )
```
```diff
@@ -54,36 +58,37 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 class MockInferenceAPI:
     async def chat_completion(
         self,
-        model: str,
+        model_id: str,
         messages: List[Message],
         sampling_params: Optional[SamplingParams] = SamplingParams(),
+        response_format: Optional[ResponseFormat] = None,
         tools: Optional[List[ToolDefinition]] = None,
         tool_choice: Optional[ToolChoice] = None,
         tool_prompt_format: Optional[ToolPromptFormat] = None,
-        response_format: Optional[ResponseFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
+        tool_config: Optional[ToolConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
         async def stream_response():
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="start",
-                    delta="",
+                    event_type=ChatCompletionResponseEventType.start,
+                    delta=TextDelta(text=""),
                 )
             )
 
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="progress",
-                    delta="AI is a fascinating field...",
+                    event_type=ChatCompletionResponseEventType.progress,
+                    delta=TextDelta(text="AI is a fascinating field..."),
                 )
             )
 
             yield ChatCompletionResponseStreamChunk(
                 event=ChatCompletionResponseEvent(
-                    event_type="complete",
-                    delta="",
-                    stop_reason="end_of_turn",
+                    event_type=ChatCompletionResponseEventType.complete,
+                    delta=TextDelta(text=""),
+                    stop_reason=StopReason.end_of_turn,
                 )
             )
```
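A test consuming the updated mock now asserts against typed events rather than raw strings. Here is a sketch, assuming `chat_completion` returns its `stream_response()` generator when `stream=True` (the unchanged remainder of the mock) and that pytest-asyncio is available:

```py
import pytest


@pytest.mark.asyncio
async def test_mock_stream_is_typed():
    mock = MockInferenceAPI()
    response = await mock.chat_completion(model_id="mock-model", messages=[], stream=True)
    chunks = [chunk async for chunk in response]

    # Events carry enum types and TextDelta objects instead of bare strings.
    assert chunks[0].event.event_type == ChatCompletionResponseEventType.start
    assert chunks[1].event.delta.text == "AI is a fascinating field..."
    assert chunks[-1].event.stop_reason == StopReason.end_of_turn
```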
```diff
@@ -133,35 +138,39 @@ class MockToolGroupsAPI:
             provider_resource_id=toolgroup_id,
         )
 
-    async def list_tool_groups(self) -> List[ToolGroup]:
-        return []
+    async def list_tool_groups(self) -> ListToolGroupsResponse:
+        return ListToolGroupsResponse(data=[])
 
-    async def list_tools(self, tool_group_id: Optional[str] = None) -> List[Tool]:
-        if tool_group_id == MEMORY_TOOLGROUP:
-            return [
-                Tool(
-                    identifier=MEMORY_QUERY_TOOL,
-                    provider_resource_id=MEMORY_QUERY_TOOL,
-                    toolgroup_id=MEMORY_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::rag",
-                    parameters=[],
-                )
-            ]
-        if tool_group_id == CODE_INTERPRETER_TOOLGROUP:
-            return [
-                Tool(
-                    identifier="code_interpreter",
-                    provider_resource_id="code_interpreter",
-                    toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
-                    tool_host=ToolHost.client,
-                    description="Mock tool",
-                    provider_id="builtin::code_interpreter",
-                    parameters=[],
-                )
-            ]
-        return []
+    async def list_tools(self, toolgroup_id: Optional[str] = None) -> ListToolsResponse:
+        if toolgroup_id == MEMORY_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier=MEMORY_QUERY_TOOL,
+                        provider_resource_id=MEMORY_QUERY_TOOL,
+                        toolgroup_id=MEMORY_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::rag",
+                        parameters=[],
+                    )
+                ]
+            )
+        if toolgroup_id == CODE_INTERPRETER_TOOLGROUP:
+            return ListToolsResponse(
+                data=[
+                    Tool(
+                        identifier="code_interpreter",
+                        provider_resource_id="code_interpreter",
+                        toolgroup_id=CODE_INTERPRETER_TOOLGROUP,
+                        tool_host=ToolHost.client,
+                        description="Mock tool",
+                        provider_id="builtin::code_interpreter",
+                        parameters=[],
+                    )
+                ]
+            )
+        return ListToolsResponse(data=[])
 
     async def get_tool(self, tool_name: str) -> Tool:
         return Tool(
@@ -174,7 +183,7 @@ class MockToolGroupsAPI:
             parameters=[],
         )
 
-    async def unregister_tool_group(self, tool_group_id: str) -> None:
+    async def unregister_tool_group(self, toolgroup_id: str) -> None:
         pass
 
 
```
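Callers of the mock accordingly unwrap the response objects instead of iterating bare lists, e.g. (a sketch; `MEMORY_TOOLGROUP` and `MEMORY_QUERY_TOOL` are the constants already used in this test file):

```py
async def test_mock_list_tools_unwraps_data():
    tool_groups_api = MockToolGroupsAPI()
    # ListToolsResponse wraps the Tool objects under .data
    tools = (await tool_groups_api.list_tools(toolgroup_id=MEMORY_TOOLGROUP)).data
    assert tools[0].identifier == MEMORY_QUERY_TOOL
```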
```diff
@@ -382,10 +391,11 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
     chat_agent = await impl.get_agent(response.agent_id)
 
     tool_defs, _ = await chat_agent._get_tool_defs()
+    tool_defs_names = [t.tool_name for t in tool_defs]
     if expected_memory:
-        assert MEMORY_QUERY_TOOL in tool_defs
+        assert MEMORY_QUERY_TOOL in tool_defs_names
     if expected_code_interpreter:
-        assert BuiltinTool.code_interpreter in tool_defs
+        assert BuiltinTool.code_interpreter in tool_defs_names
     if expected_memory and expected_code_interpreter:
         # override the tools for turn
         new_tool_defs, _ = await chat_agent._get_tool_defs(
@@ -396,5 +406,6 @@ async def test_chat_agent_tools(get_agents_impl, toolgroups, expected_memory, ex
                 )
             ]
         )
-        assert MEMORY_QUERY_TOOL in new_tool_defs
-        assert BuiltinTool.code_interpreter not in new_tool_defs
+        new_tool_defs_names = [t.tool_name for t in new_tool_defs]
+        assert MEMORY_QUERY_TOOL in new_tool_defs_names
+        assert BuiltinTool.code_interpreter not in new_tool_defs_names
```