This commit is contained in:
Ashwin Bharambe 2025-10-11 21:52:30 -07:00
parent 58fcaa445e
commit bf59d26362
3 changed files with 24 additions and 8 deletions

View file

@ -135,6 +135,8 @@ class VectorIORouter(VectorIO):
logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}") logger.debug(f"VectorIORouter.openai_create_vector_store: name={params.name}, provider_id={provider_id}")
# If no embedding model is provided, use the first available one # If no embedding model is provided, use the first available one
# TODO: this branch will soon be deleted so you _must_ provide the embedding_model when
# creating a vector store
if embedding_model is None: if embedding_model is None:
embedding_model_info = await self._get_first_embedding_model() embedding_model_info = await self._get_first_embedding_model()
if embedding_model_info is None: if embedding_model_info is None:
@ -153,7 +155,14 @@ class VectorIORouter(VectorIO):
) )
provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier) provider = await self.routing_table.get_provider_impl(registered_vector_db.identifier)
# Pass params as-is to provider - it will extract what it needs from model_extra # Update model_extra with registered values so provider uses the already-registered vector_db
if params.model_extra is None:
params.model_extra = {}
params.model_extra["provider_vector_db_id"] = registered_vector_db.provider_resource_id
params.model_extra["provider_id"] = registered_vector_db.provider_id
params.model_extra["embedding_model"] = embedding_model
params.model_extra["embedding_dimension"] = embedding_dimension
return await provider.openai_create_vector_store(params) return await provider.openai_create_vector_store(params)
async def openai_list_vector_stores( async def openai_list_vector_stores(

View file

@ -10,8 +10,9 @@ import mimetypes
import time import time
import uuid import uuid
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Any from typing import Annotated, Any
from fastapi import Body
from pydantic import TypeAdapter from pydantic import TypeAdapter
from llama_stack.apis.common.errors import VectorStoreNotFoundError from llama_stack.apis.common.errors import VectorStoreNotFoundError
@ -342,7 +343,7 @@ class OpenAIVectorStoreMixin(ABC):
async def openai_create_vector_store( async def openai_create_vector_store(
self, self,
params: OpenAICreateVectorStoreRequestWithExtraBody, params: Annotated[OpenAICreateVectorStoreRequestWithExtraBody, Body(...)],
) -> VectorStoreObject: ) -> VectorStoreObject:
"""Creates a vector store.""" """Creates a vector store."""
created_at = int(time.time()) created_at = int(time.time())
@ -978,7 +979,7 @@ class OpenAIVectorStoreMixin(ABC):
async def openai_create_vector_store_file_batch( async def openai_create_vector_store_file_batch(
self, self,
vector_store_id: str, vector_store_id: str,
params: OpenAICreateVectorStoreFileBatchRequestWithExtraBody, params: Annotated[OpenAICreateVectorStoreFileBatchRequestWithExtraBody, Body(...)],
) -> VectorStoreFileBatchObject: ) -> VectorStoreFileBatchObject:
"""Create a vector store file batch.""" """Create a vector store file batch."""
if vector_store_id not in self.openai_vector_stores: if vector_store_id not in self.openai_vector_stores:

View file

@ -21,6 +21,7 @@ from llama_stack.apis.common.content_types import (
URL, URL,
InterleavedContent, InterleavedContent,
) )
from llama_stack.apis.inference import OpenAIEmbeddingsRequestWithExtraBody
from llama_stack.apis.tools import RAGDocument from llama_stack.apis.tools import RAGDocument
from llama_stack.apis.vector_dbs import VectorDB from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse from llama_stack.apis.vector_io import Chunk, ChunkMetadata, QueryChunksResponse
@ -274,10 +275,11 @@ class VectorDBWithIndex:
_validate_embedding(c.embedding, i, self.vector_db.embedding_dimension) _validate_embedding(c.embedding, i, self.vector_db.embedding_dimension)
if chunks_to_embed: if chunks_to_embed:
resp = await self.inference_api.openai_embeddings( params = OpenAIEmbeddingsRequestWithExtraBody(
self.vector_db.embedding_model, model=self.vector_db.embedding_model,
[c.content for c in chunks_to_embed], input=[c.content for c in chunks_to_embed],
) )
resp = await self.inference_api.openai_embeddings(params)
for c, data in zip(chunks_to_embed, resp.data, strict=False): for c, data in zip(chunks_to_embed, resp.data, strict=False):
c.embedding = data.embedding c.embedding = data.embedding
@ -316,7 +318,11 @@ class VectorDBWithIndex:
if mode == "keyword": if mode == "keyword":
return await self.index.query_keyword(query_string, k, score_threshold) return await self.index.query_keyword(query_string, k, score_threshold)
embeddings_response = await self.inference_api.openai_embeddings(self.vector_db.embedding_model, [query_string]) params = OpenAIEmbeddingsRequestWithExtraBody(
model=self.vector_db.embedding_model,
input=[query_string],
)
embeddings_response = await self.inference_api.openai_embeddings(params)
query_vector = np.array(embeddings_response.data[0].embedding, dtype=np.float32) query_vector = np.array(embeddings_response.data[0].embedding, dtype=np.float32)
if mode == "hybrid": if mode == "hybrid":
return await self.index.query_hybrid( return await self.index.query_hybrid(