mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
fix: Agent uses the first configured vector_db_id when documents are provided (#1276)
# What does this PR do? The agent API allows to query multiple DBs using the `vector_db_ids` argument of the `rag` tool: ```py toolgroups=[ { "name": "builtin::rag", "args": {"vector_db_ids": [vector_db_id]}, } ], ``` This means that multiple DBs can be used to compose an aggregated context by executing the query on each of them. When documents are passed to the next agent turn, there is no explicit way to configure the vector DB where the embeddings will be ingested. In such cases, we can assume that: - if any `vector_db_ids` is given, we use the first one (it probably makes sense to assume that it's the only one in the list, otherwise we should loop on all the given DBs to have a consistent ingestion) - if no `vector_db_ids` is given, we can use the current logic to generate a default DB using the default provider. If multiple providers are defined, the API will fail as expected: the user has to provide details on where to ingest the documents. (Closes #1270) ## Test Plan The issue description details how to replicate the problem. [//]: # (## Documentation) --------- Signed-off-by: Daniele Martinoli <dmartino@redhat.com>
This commit is contained in:
parent
78962be996
commit
fb998683e0
3 changed files with 62 additions and 50 deletions
|
@ -309,13 +309,14 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
|
|||
if provider_vector_db_id is None:
|
||||
provider_vector_db_id = vector_db_id
|
||||
if provider_id is None:
|
||||
# If provider_id not specified, use the only provider if it supports this shield type
|
||||
if len(self.impls_by_provider_id) == 1:
|
||||
if len(self.impls_by_provider_id) > 0:
|
||||
provider_id = list(self.impls_by_provider_id.keys())[0]
|
||||
if len(self.impls_by_provider_id) > 1:
|
||||
logger.warning(
|
||||
f"No provider specified and multiple providers available. Arbitrarily selected the first provider {provider_id}."
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||
)
|
||||
raise ValueError("No provider available. Please configure a vector_io provider.")
|
||||
model = await self.get_object_by_identifier("model", embedding_model)
|
||||
if model is None:
|
||||
raise ValueError(f"Model {embedding_model} not found")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue