mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-06-27 18:50:41 +00:00
chore: Updating Milvus to use OpenAIVectorStoreMixin
This commit is contained in:
parent
0883944bc3
commit
8a19f69009
6 changed files with 138 additions and 129 deletions
|
@ -14,6 +14,6 @@ from .config import MilvusVectorIOConfig
|
|||
async def get_provider_impl(config: MilvusVectorIOConfig, deps: dict[Api, Any]):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusVectorIOAdapter
|
||||
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference])
|
||||
impl = MilvusVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
|
||||
await impl.initialize()
|
||||
return impl
|
||||
|
|
|
@ -8,13 +8,24 @@ from typing import Any
|
|||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.utils.kvstore.config import (
|
||||
KVStoreConfig,
|
||||
SqliteKVStoreConfig,
|
||||
)
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MilvusVectorIOConfig(BaseModel):
|
||||
db_path: str
|
||||
kvstore: KVStoreConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
|
||||
return {"db_path": "${env.MILVUS_DB_PATH}"}
|
||||
return {
|
||||
"db_path": "${env.MILVUS_DB_PATH}",
|
||||
"kvstore": SqliteKVStoreConfig.sample_run_config(
|
||||
__distro_dir__=__distro_dir__,
|
||||
db_name="milvus_registry.db",
|
||||
),
|
||||
}
|
||||
|
|
|
@ -130,5 +130,6 @@ def available_providers() -> list[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.vector_io.milvus",
|
||||
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
|
||||
api_dependencies=[Api.inference],
|
||||
optional_api_dependencies=[Api.files],
|
||||
),
|
||||
]
|
||||
|
|
|
@ -5,34 +5,27 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import asyncio
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from numpy.typing import NDArray
|
||||
from pymilvus import MilvusClient
|
||||
|
||||
from llama_stack.apis.inference import InterleavedContent
|
||||
from llama_stack.apis.files.files import Files
|
||||
from llama_stack.apis.inference import Inference, InterleavedContent
|
||||
from llama_stack.apis.vector_dbs import VectorDB
|
||||
from llama_stack.apis.vector_io import (
|
||||
Chunk,
|
||||
QueryChunksResponse,
|
||||
SearchRankingOptions,
|
||||
VectorIO,
|
||||
VectorStoreChunkingStrategy,
|
||||
VectorStoreDeleteResponse,
|
||||
VectorStoreFileContentsResponse,
|
||||
VectorStoreFileObject,
|
||||
VectorStoreFileStatus,
|
||||
VectorStoreListFilesResponse,
|
||||
VectorStoreListResponse,
|
||||
VectorStoreObject,
|
||||
VectorStoreSearchResponsePage,
|
||||
)
|
||||
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.datatypes import VectorDBsProtocolPrivate
|
||||
from llama_stack.providers.inline.vector_io.milvus import MilvusVectorIOConfig as InlineMilvusVectorIOConfig
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
|
||||
from llama_stack.providers.utils.memory.vector_store import (
|
||||
EmbeddingIndex,
|
||||
VectorDBWithIndex,
|
||||
|
@ -42,12 +35,22 @@ from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
|
|||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
VECTOR_INDEX_PREFIX = f"vector_index:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
||||
|
||||
|
||||
class MilvusIndex(EmbeddingIndex):
|
||||
def __init__(self, client: MilvusClient, collection_name: str, consistency_level="Strong"):
|
||||
def __init__(
|
||||
self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
|
||||
):
|
||||
self.client = client
|
||||
self.collection_name = collection_name.replace("-", "_")
|
||||
self.consistency_level = consistency_level
|
||||
self.kvstore = kvstore
|
||||
|
||||
async def delete(self):
|
||||
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
|
||||
|
@ -68,11 +71,9 @@ class MilvusIndex(EmbeddingIndex):
|
|||
|
||||
data = []
|
||||
for chunk, embedding in zip(chunks, embeddings, strict=False):
|
||||
chunk_id = generate_chunk_id(chunk.metadata["document_id"], chunk.content)
|
||||
|
||||
data.append(
|
||||
{
|
||||
"chunk_id": chunk_id,
|
||||
"chunk_id": chunk.chunk_id,
|
||||
"vector": embedding,
|
||||
"chunk_content": chunk.model_dump(),
|
||||
}
|
||||
|
@ -120,16 +121,42 @@ class MilvusIndex(EmbeddingIndex):
|
|||
raise NotImplementedError("Hybrid search is not supported in Milvus")
|
||||
|
||||
|
||||
class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
||||
class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
|
||||
def __init__(
|
||||
self, config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig, inference_api: Api.inference
|
||||
self,
|
||||
config: RemoteMilvusVectorIOConfig | InlineMilvusVectorIOConfig,
|
||||
inference_api: Inference,
|
||||
files_api: Files | None,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.cache = {}
|
||||
self.client = None
|
||||
self.inference_api = inference_api
|
||||
self.files_api = files_api
|
||||
self.kvstore: KVStore | None = None
|
||||
self.vector_db_store = None
|
||||
self.openai_vector_stores: dict[str, dict[str, Any]] = {}
|
||||
self.metadata_collection_name = "openai_vector_stores_metadata"
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.kvstore = await kvstore_impl(self.config.kvstore)
|
||||
start_key = VECTOR_DBS_PREFIX
|
||||
end_key = f"{VECTOR_DBS_PREFIX}\xff"
|
||||
stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key)
|
||||
|
||||
for vector_db_data in stored_vector_dbs:
|
||||
vector_db = VectorDB.mdel_validate_json(vector_db_data)
|
||||
index = VectorDBWithIndex(
|
||||
vector_db,
|
||||
index=await MilvusIndex(
|
||||
client=self.client,
|
||||
collection_name=vector_db.identifier,
|
||||
consistency_level=self.config.consistency_level,
|
||||
kvstore=self.kvstore,
|
||||
),
|
||||
inference_api=self.inference_api,
|
||||
)
|
||||
self.cache[vector_db.identifier] = index
|
||||
if isinstance(self.config, RemoteMilvusVectorIOConfig):
|
||||
logger.info(f"Connecting to Milvus server at {self.config.uri}")
|
||||
self.client = MilvusClient(**self.config.model_dump(exclude_none=True))
|
||||
|
@ -138,6 +165,8 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
uri = os.path.expanduser(self.config.db_path)
|
||||
self.client = MilvusClient(uri=uri)
|
||||
|
||||
self.openai_vector_stores = await self._load_openai_vector_stores()
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
self.client.close()
|
||||
|
||||
|
@ -202,116 +231,62 @@ class MilvusVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
|
|||
|
||||
return await index.query_chunks(query, params)
|
||||
|
||||
async def openai_create_vector_store(
|
||||
self,
|
||||
name: str,
|
||||
file_ids: list[str] | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
chunking_strategy: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
embedding_model: str | None = None,
|
||||
embedding_dimension: int | None = 384,
|
||||
provider_id: str | None = None,
|
||||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
|
||||
async def _save_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Save vector store metadata to persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(store_info))
|
||||
self.openai_vector_stores[store_id] = store_info
|
||||
|
||||
async def openai_list_vector_stores(
|
||||
self,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
) -> VectorStoreListResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
|
||||
async def _update_openai_vector_store(self, store_id: str, store_info: dict[str, Any]) -> None:
|
||||
"""Update vector store metadata in persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(store_info))
|
||||
self.openai_vector_stores[store_id] = store_info
|
||||
|
||||
async def openai_retrieve_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
|
||||
async def _delete_openai_vector_store_from_storage(self, store_id: str) -> None:
|
||||
"""Delete vector store metadata from persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_PREFIX}{store_id}"
|
||||
await self.kvstore.delete(key)
|
||||
|
||||
async def openai_update_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
name: str | None = None,
|
||||
expires_after: dict[str, Any] | None = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> VectorStoreObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to Milvus database."""
|
||||
assert self.kvstore is not None
|
||||
key = f"{OPENAI_VECTOR_STORES_FILES_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=key, value=json.dumps(file_info))
|
||||
content_key = f"{OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX}{store_id}:{file_id}"
|
||||
await self.kvstore.set(key=content_key, value=json.dumps(file_contents))
|
||||
|
||||
async def openai_delete_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
) -> VectorStoreDeleteResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
|
||||
async def _load_openai_vector_stores(self) -> dict[str, dict[str, Any]]:
|
||||
"""Load all vector store metadata from persistent storage."""
|
||||
assert self.kvstore is not None
|
||||
start_key = OPENAI_VECTOR_STORES_PREFIX
|
||||
end_key = f"{OPENAI_VECTOR_STORES_PREFIX}\xff"
|
||||
stored = await self.kvstore.values_in_range(start_key, end_key)
|
||||
return {json.loads(s)["id"]: json.loads(s) for s in stored}
|
||||
|
||||
async def openai_search_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
query: str | list[str],
|
||||
filters: dict[str, Any] | None = None,
|
||||
max_num_results: int | None = 10,
|
||||
ranking_options: SearchRankingOptions | None = None,
|
||||
rewrite_query: bool | None = False,
|
||||
search_mode: str | None = "vector",
|
||||
) -> VectorStoreSearchResponsePage:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
|
||||
async def _save_openai_vector_store_file(
|
||||
self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
|
||||
) -> None:
|
||||
"""Save vector store file metadata to Milvus database."""
|
||||
raise NotImplementedError("Files API not yet implemented for Milvus")
|
||||
|
||||
async def openai_attach_file_to_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
chunking_strategy: VectorStoreChunkingStrategy | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
|
||||
async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
|
||||
"""Load vector store file metadata from Milvus database."""
|
||||
raise NotImplementedError("Files API not yet implemented for Milvus")
|
||||
|
||||
async def openai_list_files_in_vector_store(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
limit: int | None = 20,
|
||||
order: str | None = "desc",
|
||||
after: str | None = None,
|
||||
before: str | None = None,
|
||||
filter: VectorStoreFileStatus | None = None,
|
||||
) -> VectorStoreListFilesResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
|
||||
async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
|
||||
"""Load vector store file contents from Milvus database."""
|
||||
raise NotImplementedError("Files API not yet implemented for Milvus")
|
||||
|
||||
async def openai_retrieve_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
|
||||
async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
|
||||
"""Update vector store file metadata in Milvus database."""
|
||||
raise NotImplementedError("Files API not yet implemented for Milvus")
|
||||
|
||||
async def openai_retrieve_vector_store_file_contents(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileContentsResponse:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
|
||||
|
||||
async def openai_update_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
attributes: dict[str, Any] | None = None,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
|
||||
|
||||
async def openai_delete_vector_store_file(
|
||||
self,
|
||||
vector_store_id: str,
|
||||
file_id: str,
|
||||
) -> VectorStoreFileObject:
|
||||
raise NotImplementedError("OpenAI Vector Stores API is not supported in Milvus")
|
||||
|
||||
|
||||
def generate_chunk_id(document_id: str, chunk_text: str) -> str:
|
||||
"""Generate a unique chunk ID using a hash of document ID and chunk text."""
|
||||
hash_input = f"{document_id}:{chunk_text}".encode()
|
||||
return str(uuid.UUID(hashlib.md5(hash_input).hexdigest()))
|
||||
|
||||
|
||||
# TODO: refactor this generate_chunk_id along with the `sqlite-vec` implementation into a separate utils file
|
||||
async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
|
||||
"""Delete vector store file metadata from Milvus database."""
|
||||
raise NotImplementedError("Files API not yet implemented for Milvus")
|
||||
|
|
|
@ -42,6 +42,12 @@ logger = logging.getLogger(__name__)
|
|||
# Constants for OpenAI vector stores
|
||||
CHUNK_MULTIPLIER = 5
|
||||
|
||||
VERSION = "v3"
|
||||
VECTOR_DBS_PREFIX = f"vector_dbs:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_PREFIX = f"openai_vector_stores:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:{VERSION}::"
|
||||
OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:{VERSION}::"
|
||||
|
||||
|
||||
class OpenAIVectorStoreMixin(ABC):
|
||||
"""
|
||||
|
@ -141,7 +147,6 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
provider_vector_db_id: str | None = None,
|
||||
) -> VectorStoreObject:
|
||||
"""Creates a vector store."""
|
||||
# store and vector_db have the same id
|
||||
store_id = name or str(uuid.uuid4())
|
||||
created_at = int(time.time())
|
||||
|
||||
|
@ -315,7 +320,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
await self._delete_openai_vector_store_from_storage(vector_store_id)
|
||||
|
||||
# Delete from in-memory cache
|
||||
del self.openai_vector_stores[vector_store_id]
|
||||
self.openai_vector_stores.pop(vector_store_id, None)
|
||||
|
||||
# Also delete the underlying vector DB
|
||||
try:
|
||||
|
@ -574,6 +579,7 @@ class OpenAIVectorStoreMixin(ABC):
|
|||
|
||||
# Save vector store file to persistent storage (provider-specific)
|
||||
dict_chunks = [c.model_dump() for c in chunks]
|
||||
# This should be updated to include chunk_id
|
||||
await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
|
||||
|
||||
# Update file_ids and file_counts in vector store metadata
|
||||
|
|
|
@ -20,6 +20,15 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
|
||||
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
|
||||
for p in vector_io_providers:
|
||||
if p.provider_type in ["inline::faiss", "inline::sqlite-vec", "inline::milvus"]:
|
||||
return
|
||||
|
||||
pytest.skip("OpenAI vector stores are not supported by any provider")
|
||||
|
||||
|
||||
def skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models):
|
||||
vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
|
||||
for p in vector_io_providers:
|
||||
if p.provider_type in ["inline::faiss", "inline::sqlite-vec"]:
|
||||
|
@ -443,6 +452,7 @@ def test_openai_vector_store_search_with_max_num_results(
|
|||
def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store attach file."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
|
||||
|
@ -494,6 +504,7 @@ def test_openai_vector_store_attach_file(compat_client_with_empty_stores, client
|
|||
def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store attach files on creation."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
|
||||
|
@ -551,6 +562,7 @@ def test_openai_vector_store_attach_files_on_creation(compat_client_with_empty_s
|
|||
def test_openai_vector_store_list_files(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store list files."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient")
|
||||
|
@ -624,6 +636,7 @@ def test_openai_vector_store_list_files_invalid_vector_store(compat_client_with_
|
|||
def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store retrieve file contents."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files retrieve contents is not yet supported with LlamaStackClient")
|
||||
|
@ -665,6 +678,7 @@ def test_openai_vector_store_retrieve_file_contents(compat_client_with_empty_sto
|
|||
def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store delete file."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files list is not yet supported with LlamaStackClient")
|
||||
|
@ -718,10 +732,11 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client
|
|||
|
||||
|
||||
# TODO: Remove this xfail once we have a way to remove embeddings from vector store
|
||||
@pytest.mark.xfail(reason="Vector Store Files delete doesn't remove embeddings from vecntor store", strict=True)
|
||||
@pytest.mark.xfail(reason="Vector Store Files delete doesn't remove embeddings from vector store", strict=True)
|
||||
def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store delete file removes from vector store."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files attach is not yet supported with LlamaStackClient")
|
||||
|
@ -763,6 +778,7 @@ def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client
|
|||
def test_openai_vector_store_update_file(compat_client_with_empty_stores, client_with_models):
|
||||
"""Test OpenAI vector store update file."""
|
||||
skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
|
||||
skip_if_provider_doesnt_support_openai_vector_store_files_api(client_with_models)
|
||||
|
||||
if isinstance(compat_client_with_empty_stores, LlamaStackClient):
|
||||
pytest.skip("Vector Store Files update is not yet supported with LlamaStackClient")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue