mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 04:00:42 +00:00
update based on feedback
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
917deb28fe
commit
92107f316c
2 changed files with 9 additions and 8 deletions
|
|
@@ -117,7 +117,7 @@ class VectorIORouter(VectorIO):
|
|||
# Extract llama-stack-specific parameters from extra_body
|
||||
extra = params.model_extra or {}
|
||||
embedding_model = extra.get("embedding_model")
|
||||
embedding_dimension = extra.get("embedding_dimension", 384)
|
||||
embedding_dimension = extra.get("embedding_dimension")
|
||||
provider_id = extra.get("provider_id")
|
||||
|
||||
logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")
|
||||
|
|
@@ -126,20 +126,21 @@ class VectorIORouter(VectorIO):
|
|||
if embedding_model is None:
|
||||
raise ValueError("embedding_model is required in extra_body when creating a vector store")
|
||||
|
||||
# Always extract embedding dimension from the model registry
|
||||
if embedding_dimension is None:
|
||||
embedding_dimension = await self._get_embedding_model_dimension(embedding_model)
|
||||
|
||||
# Auto-select provider if not specified
|
||||
if provider_id is None:
|
||||
if len(self.routing_table.impls_by_provider_id) == 1:
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
logger.info(f"No provider_id specified, using the only available vector_io provider: {provider_id}")
|
||||
else:
|
||||
num_providers = len(self.routing_table.impls_by_provider_id)
|
||||
if num_providers == 0:
|
||||
raise ValueError("No vector_io providers available")
|
||||
if num_providers > 1:
|
||||
available_providers = list(self.routing_table.impls_by_provider_id.keys())
|
||||
raise ValueError(
|
||||
f"Multiple vector_io providers available. Please specify provider_id in extra_body. "
|
||||
f"Available providers: {available_providers}"
|
||||
)
|
||||
provider_id = list(self.routing_table.impls_by_provider_id.keys())[0]
|
||||
|
||||
vector_db_id = f"vs_{uuid.uuid4()}"
|
||||
registered_vector_db = await self.routing_table.register_vector_db(
|
||||
|
|
|
|||
|
|
@@ -353,7 +353,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
provider_vector_db_id = extra.get("provider_vector_db_id")
|
||||
embedding_model = extra.get("embedding_model")
|
||||
embedding_dimension = extra.get("embedding_dimension", 384)
|
||||
# use provider_id from router or default to this provider's own ID (need for --stack-config)
|
||||
# use provider_id set by router; fallback to provider's own ID when used directly via --stack-config
|
||||
provider_id = extra.get("provider_id") or getattr(self, "__provider_id__", None)
|
||||
|
||||
# Derive the canonical vector_db_id (allow override, else generate)
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue