fix: extract llama-stack params from model_extra, not as explicit fields

This commit is contained in:
Ashwin Bharambe 2025-10-11 17:21:43 -07:00
parent 8fa91f98ef
commit 3568ccdc81
3 changed files with 23 additions and 32 deletions

View file

@@ -477,9 +477,6 @@ class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
:param expires_after: (Optional) Expiration policy for the vector store :param expires_after: (Optional) Expiration policy for the vector store
:param chunking_strategy: (Optional) Strategy for splitting files into chunks :param chunking_strategy: (Optional) Strategy for splitting files into chunks
:param metadata: Set of key-value pairs that can be attached to the vector store :param metadata: Set of key-value pairs that can be attached to the vector store
:param embedding_model: (Optional) The embedding model to use for this vector store
:param embedding_dimension: (Optional) The dimension of the embedding vectors (default: 384)
:param provider_id: (Optional) The ID of the provider to use for this vector store
""" """
name: str | None = None name: str | None = None
@@ -487,9 +484,6 @@ class OpenAICreateVectorStoreRequestWithExtraBody(BaseModel, extra="allow"):
expires_after: dict[str, Any] | None = None expires_after: dict[str, Any] | None = None
chunking_strategy: dict[str, Any] | None = None chunking_strategy: dict[str, Any] | None = None
metadata: dict[str, Any] | None = None metadata: dict[str, Any] | None = None
embedding_model: str | None = None
embedding_dimension: int | None = 384
provider_id: str | None = None
# extra_body can be accessed via .model_extra # extra_body can be accessed via .model_extra

View file

@@ -126,11 +126,15 @@ class VectorIORouter(VectorIO):
self, self,
params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)], params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject: ) -> VectorStoreObject:
logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={params.provider_id}") # Extract llama-stack-specific parameters from extra_body
extra = params.model_extra or {}
embedding_model = extra.get("embedding_model")
embedding_dimension = extra.get("embedding_dimension", 384)
provider_id = extra.get("provider_id")
logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one # If no embedding model is provided, use the first available one
embedding_model = params.embedding_model
embedding_dimension = params.embedding_dimension
if embedding_model is None: if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model() embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None: if embedding_model_info is None:
@@ -143,22 +147,13 @@ class VectorIORouter(VectorIO):
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model=embedding_model, embedding_model=embedding_model,
embedding_dimension=embedding_dimension, embedding_dimension=embedding_dimension,
provider_id=params.provider_id, provider_id=provider_id,
provider_vector_db_id=vector_db_id, provider_vector_db_id=vector_db_id,
vector_db_name=params.name, vector_db_name=params.name,
) )
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier) provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
# Update params with resolved values # Pass params as-is to provider - it will extract what it needs from model_extra
params.embedding_model = embedding_model
params.embedding_dimension = embedding_dimension
params.provider_id = registered_vector_db.provider_id
# Add provider_vector_db_id to extra_body if not already there
if params.model_extra is None:
params.model_extra = {}
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
return await provider.openai_create_vector_store(params) return await provider.openai_create_vector_store(params)
async def openai_list_vector_stores( async def openai_list_vector_stores(

View file

@@ -347,30 +347,32 @@ class OpenAIVectorStoreMixin(ABC):
"""Creates a vector store.""" """Creates a vector store."""
created_at = int(time.time()) created_at = int(time.time())
# Extract provider_vector_db_id from extra_body if present # Extract llama-stack-specific parameters from extra_body
provider_vector_db_id = None extra = params.model_extra or {}
if params.model_extra and "provider_vector_db_id" in params.model_extra: provider_vector_db_id = extra.get("provider_vector_db_id")
provider_vector_db_id = params.model_extra["provider_vector_db_id"] embedding_model = extra.get("embedding_model")
embedding_dimension = extra.get("embedding_dimension", 384)
provider_id = extra.get("provider_id")
# Derive the canonical vector_db_id (allow override, else generate) # Derive the canonical vector_db_id (allow override, else generate)
vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}") vector_db_id = provider_vector_db_id or generate_object_id("vector_store", lambda: f"vs_{uuid.uuid4()}")
if params.provider_id is None: if provider_id is None:
raise ValueError("Provider ID is required") raise ValueError("Provider ID is required")
if params.embedding_model is None: if embedding_model is None:
raise ValueError("Embedding model is required") raise ValueError("Embedding model is required")
# Embedding dimension is required (defaulted to 384 if not provided) # Embedding dimension is required (defaulted to 384 if not provided)
if params.embedding_dimension is None: if embedding_dimension is None:
raise ValueError("Embedding dimension is required") raise ValueError("Embedding dimension is required")
# Register the VectorDB backing this vector store # Register the VectorDB backing this vector store
vector_db = VectorDB( vector_db = VectorDB(
identifier=vector_db_id, identifier=vector_db_id,
embedding_dimension=params.embedding_dimension, embedding_dimension=embedding_dimension,
embedding_model=params.embedding_model, embedding_model=embedding_model,
provider_id=params.provider_id, provider_id=provider_id,
provider_resource_id=vector_db_id, provider_resource_id=vector_db_id,
vector_db_name=params.name, vector_db_name=params.name,
) )
@@ -404,8 +406,8 @@ class OpenAIVectorStoreMixin(ABC):
# Add provider information to metadata if provided # Add provider information to metadata if provided
metadata = params.metadata or {} metadata = params.metadata or {}
if params.provider_id: if provider_id:
metadata["provider_id"] = params.provider_id metadata["provider_id"] = provider_id
if provider_vector_db_id: if provider_vector_db_id:
metadata["provider_vector_db_id"] = provider_vector_db_id metadata["provider_vector_db_id"] = provider_vector_db_id
store_info["metadata"] = metadata store_info["metadata"] = metadata