diff --git a/llama_stack/core/routers/vector_io.py b/llama_stack/core/routers/vector_io.py
index 0e3f9d8d9..d2145f3b1 100644
--- a/llama_stack/core/routers/vector_io.py
+++ b/llama_stack/core/routers/vector_io.py
@@ -6,12 +6,16 @@
 
 import asyncio
 import uuid
-from typing import Any
+from typing import Annotated, Any
+
+from fastapi import Body
 
 from llama_stack.apis.common.content_types import InterleavedContent
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.vector_io import (
     Chunk,
+    OpenAICreateVectorStoreFileBatchRequestWithExtraBody,
+    OpenAICreateVectorStoreRequestWithExtraBody,
     QueryChunksResponse,
     SearchRankingOptions,
     VectorIO,
@@ -120,18 +124,13 @@ class VectorIORouter(VectorIO):
     # OpenAI Vector Stores API endpoints
     async def openai_create_vector_store(
         self,
-        name: str,
-        file_ids: list[str] | None = None,
-        expires_after: dict[str, Any] | None = None,
-        chunking_strategy: dict[str, Any] | None = None,
-        metadata: dict[str, Any] | None = None,
-        embedding_model: str | None = None,
-        embedding_dimension: int | None = None,
-        provider_id: str | None = None,
+        params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
     ) -> VectorStoreObject:
-        logger.debug(f"VectorIORouter.openai_create_vector_store: name={name}, provider_id={provider_id}")
+        logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={params.provider_id}")
 
         # If no embedding model is provided, use the first available one
+        embedding_model = params.embedding_model
+        embedding_dimension = params.embedding_dimension
         if embedding_model is None:
             embedding_model_info = await self._get_first_embedding_model()
             if embedding_model_info is None:
@@ -144,22 +143,23 @@ class VectorIORouter(VectorIO):
             vector_db_id=vector_db_id,
             embedding_model=embedding_model,
             embedding_dimension=embedding_dimension,
-            provider_id=provider_id,
+            provider_id=params.provider_id,
             provider_vector_db_id=vector_db_id,
-            vector_db_name=name,
+            vector_db_name=params.name,
         )
         provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
-        return await provider.openai_create_vector_store(
-            name=name,
-            file_ids=file_ids,
-            expires_after=expires_after,
-            chunking_strategy=chunking_strategy,
-            metadata=metadata,
-            embedding_model=embedding_model,
-            embedding_dimension=embedding_dimension,
-            provider_id=registered_vector_db.provider_id,
-            provider_vector_db_id=registered_vector_db.provider_resource_id,
-        )
+
+        # Update params with resolved values
+        params.embedding_model = embedding_model
+        params.embedding_dimension = embedding_dimension
+        params.provider_id = registered_vector_db.provider_id
+
+        # Add provider_vector_db_id to extra_body if not already there
+        if params.model_extra is None:
+            params.model_extra = {}
+        params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
+
+        return await provider.openai_create_vector_store(params)
 
     async def openai_list_vector_stores(
         self,
@@ -370,16 +370,14 @@ class VectorIORouter(VectorIO):
     async def openai_create_vector_store_file_batch(
         self,
         vector_store_id: str,
-        file_ids: list[str],
-        attributes: dict[str, Any] | None = None,
-        chunking_strategy: VectorStoreChunkingStrategy | None = None,
+        params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
     ) -> VectorStoreFileBatchObject:
-        logger.debug(f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(file_ids)} files")
+        logger.debug(
+            f"VectorIORouter.openai_create_vector_store_file_batch: {vector_store_id}, {len(params.file_ids)} files"
+        )
         return await self.routing_table.openai_create_vector_store_file_batch(
             vector_store_id=vector_store_id,
-            file_ids=file_ids,
-            attributes=attributes,
-            chunking_strategy=chunking_strategy,
+            params=params,
         )
 
     async def openai_retrieve_vector_store_file_batch(