Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-27 06:28:50 +00:00)

Commit 754fb32c59: merge 05d3fffbdf into 21bae296f2
13 changed files with 146 additions and 12 deletions

This commit threads a new `delete_chunk` operation through the vector-io providers and wires it into the OpenAI-compatible vector-store file-deletion flow.
```diff
@@ -237,7 +237,10 @@ API responses, specify the adapter here.
 def remote_provider_spec(
-    api: Api, adapter: AdapterSpec, api_dependencies: list[Api] | None = None
+    api: Api,
+    adapter: AdapterSpec,
+    api_dependencies: list[Api] | None = None,
+    optional_api_dependencies: list[Api] | None = None,
 ) -> RemoteProviderSpec:
     return RemoteProviderSpec(
         api=api,
@@ -245,6 +248,7 @@ def remote_provider_spec(
         config_class=adapter.config_class,
         adapter=adapter,
         api_dependencies=api_dependencies or [],
+        optional_api_dependencies=optional_api_dependencies or [],
     )
```
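A hedged usage sketch of the new parameter (the import path and `my_adapter` are assumptions for illustration; the registry hunk further down shows the real call site passing `optional_api_dependencies=[Api.files]`):

```python
# Assumed import location; adjust to wherever Api, AdapterSpec, and
# remote_provider_spec actually live in the tree.
from llama_stack.providers.datatypes import Api, remote_provider_spec

spec = remote_provider_spec(
    api=Api.vector_io,
    adapter=my_adapter,  # hypothetical AdapterSpec instance
    api_dependencies=[Api.inference],
    # New: the provider still loads when the Files API is not configured.
    optional_api_dependencies=[Api.files],
)
```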
```diff
@@ -55,6 +55,11 @@ class FaissIndex(EmbeddingIndex):
         self.kvstore = kvstore
         self.bank_id = bank_id
 
+        # A list of chunk id's in the same order as they are in the index,
+        # must be updated when chunks are added or removed
+        self.chunk_id_lock = asyncio.Lock()
+        self.chunk_ids: list[Any] = []
+
     @classmethod
     async def create(cls, dimension: int, kvstore: KVStore | None = None, bank_id: str | None = None):
         instance = cls(dimension, kvstore, bank_id)
@@ -75,6 +80,7 @@ class FaissIndex(EmbeddingIndex):
         buffer = io.BytesIO(base64.b64decode(data["faiss_index"]))
         try:
             self.index = faiss.deserialize_index(np.load(buffer, allow_pickle=False))
+            self.chunk_ids = [chunk.chunk_id for chunk in self.chunk_by_index.values()]
         except Exception as e:
             logger.debug(e, exc_info=True)
             raise ValueError(
```
```diff
@@ -114,11 +120,33 @@ class FaissIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             self.chunk_by_index[indexlen + i] = chunk
 
-        self.index.add(np.array(embeddings).astype(np.float32))
+        async with self.chunk_id_lock:
+            self.index.add(np.array(embeddings).astype(np.float32))
+            self.chunk_ids.extend([chunk.chunk_id for chunk in chunks])
 
         # Save updated index
         await self._save_index()
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        if chunk_id not in self.chunk_ids:
+            return
+
+        async with self.chunk_id_lock:
+            index = self.chunk_ids.index(chunk_id)
+            self.index.remove_ids(np.array([index]))
+
+            new_chunk_by_index = {}
+            for idx, chunk in self.chunk_by_index.items():
+                # Shift all chunks after the removed chunk to the left
+                if idx > index:
+                    new_chunk_by_index[idx - 1] = chunk
+                else:
+                    new_chunk_by_index[idx] = chunk
+            self.chunk_by_index = new_chunk_by_index
+            self.chunk_ids.pop(index)
+
+        await self._save_index()
+
     async def query_vector(
         self,
         embedding: NDArray,
```
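The key-shifting loop above exists because a FAISS flat index compacts its storage on removal: every vector stored after the deleted position moves down by one slot. A standalone sketch (assumes `faiss-cpu` and `numpy` are installed):

```python
import faiss
import numpy as np

index = faiss.IndexFlatL2(4)
vecs = np.random.rand(3, 4).astype(np.float32)
index.add(vecs)

# Remove the vector at position 1; positions 2..n-1 shift down by one,
# which is why the diff remaps chunk_by_index keys with idx - 1.
index.remove_ids(np.array([1], dtype=np.int64))

assert index.ntotal == 2
assert np.allclose(index.reconstruct(1), vecs[2])  # old row 2 is now row 1
```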
```diff
@@ -260,3 +288,8 @@ class FaissVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
             raise ValueError(f"Vector DB {vector_db_id} not found")
 
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a faiss index"""
+        faiss_index = self.cache[store_id].index
+        await faiss_index.delete_chunk(chunk_id)
```
```diff
@@ -425,6 +425,35 @@ class SQLiteVecIndex(EmbeddingIndex):
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        """Remove a chunk from the SQLite vector store."""
+
+        def _delete_chunk():
+            connection = _create_sqlite_connection(self.db_path)
+            cur = connection.cursor()
+            try:
+                cur.execute("BEGIN TRANSACTION")
+
+                # Delete from metadata table
+                cur.execute(f"DELETE FROM {self.metadata_table} WHERE id = ?", (chunk_id,))
+
+                # Delete from vector table
+                cur.execute(f"DELETE FROM {self.vector_table} WHERE id = ?", (chunk_id,))
+
+                # Delete from FTS table
+                cur.execute(f"DELETE FROM {self.fts_table} WHERE id = ?", (chunk_id,))
+
+                connection.commit()
+            except Exception as e:
+                connection.rollback()
+                logger.error(f"Error deleting chunk {chunk_id}: {e}")
+                raise
+            finally:
+                cur.close()
+                connection.close()
+
+        await asyncio.to_thread(_delete_chunk)
+
 
 class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     """
```
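The `asyncio.to_thread` wrapper keeps the event loop free while the blocking sqlite3 work runs. A minimal standalone sketch of the same pattern (the `chunks` table name is hypothetical):

```python
import asyncio
import sqlite3

async def delete_row(db_path: str, row_id: str) -> None:
    def _work() -> None:
        conn = sqlite3.connect(db_path)
        try:
            # Parameterized id, as in the diff; "chunks" is a hypothetical table.
            conn.execute("DELETE FROM chunks WHERE id = ?", (row_id,))
            conn.commit()
        finally:
            conn.close()

    # Run the blocking call in a worker thread, not on the event loop.
    await asyncio.to_thread(_work)
```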
```diff
@@ -520,3 +549,12 @@ class SQLiteVecVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         if not index:
             raise ValueError(f"Vector DB {vector_db_id} not found")
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a sqlite_vec index."""
+        index = await self._get_and_cache_vector_db_index(store_id)
+        if not index:
+            raise ValueError(f"Vector DB {store_id} not found")
+
+        # Use the index's delete_chunk method
+        await index.index.delete_chunk(chunk_id)
```
```diff
@@ -410,6 +410,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
             """,
         ),
         api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
     ),
     remote_provider_spec(
         Api.vector_io,
```
```diff
@@ -112,6 +112,9 @@ class ChromaIndex(EmbeddingIndex):
     ) -> QueryChunksResponse:
         raise NotImplementedError("Keyword search is not supported in Chroma")
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        raise NotImplementedError("delete_chunk is not supported in Chroma")
+
     async def query_hybrid(
         self,
         embedding: NDArray,
@@ -208,3 +211,6 @@ class ChromaVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         index = VectorDBWithIndex(vector_db, ChromaIndex(self.client, collection), self.inference_api)
         self.cache[vector_db_id] = index
         return index
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Chroma")
```
```diff
@@ -247,6 +247,16 @@ class MilvusIndex(EmbeddingIndex):
     ) -> QueryChunksResponse:
         raise NotImplementedError("Hybrid search is not supported in Milvus")
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        """Remove a chunk from the Milvus collection."""
+        try:
+            await asyncio.to_thread(
+                self.client.delete, collection_name=self.collection_name, filter=f'chunk_id == "{chunk_id}"'
+            )
+        except Exception as e:
+            logger.error(f"Error deleting chunk {chunk_id} from Milvus collection {self.collection_name}: {e}")
+            raise
+
 
 class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
```
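For reference, a sketch of the underlying pymilvus call the diff wraps in `asyncio.to_thread`, assuming pymilvus >= 2.4 where `MilvusClient.delete` accepts a `filter` expression; the URI and collection name are placeholders:

```python
from pymilvus import MilvusClient

client = MilvusClient(uri="./milvus_demo.db")  # placeholder: Milvus Lite file URI
client.delete(
    collection_name="chunks",          # placeholder collection
    filter='chunk_id == "abc-123"',    # same filter-expression form as the diff
)
```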
```diff
@@ -369,3 +379,12 @@ class MilvusVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         )
 
         return await index.query_chunks(query, params)
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a milvus vector store."""
+        index = await self._get_and_cache_vector_db_index(store_id)
+        if not index:
+            raise ValueError(f"Vector DB {store_id} not found")
+
+        # Use the index's delete_chunk method
+        await index.index.delete_chunk(chunk_id)
```
```diff
@@ -12,6 +12,6 @@ from .config import PGVectorVectorIOConfig
 async def get_adapter_impl(config: PGVectorVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .pgvector import PGVectorVectorIOAdapter
 
-    impl = PGVectorVectorIOAdapter(config, deps[Api.inference])
+    impl = PGVectorVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
     await impl.initialize()
     return impl
```
```diff
@@ -99,7 +99,7 @@ class PGVectorIndex(EmbeddingIndex):
         for i, chunk in enumerate(chunks):
             values.append(
                 (
-                    f"{chunk.metadata['document_id']}:chunk-{i}",
+                    f"{chunk.chunk_id}",
                     Json(chunk.model_dump()),
                     embeddings[i].tolist(),
                 )
```

Keying rows by `chunk_id` rather than the positional `document_id:chunk-{i}` id is what lets the new `delete_chunk` below address individual rows.
```diff
@@ -159,6 +159,11 @@ class PGVectorIndex(EmbeddingIndex):
         with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
             cur.execute(f"DROP TABLE IF EXISTS {self.table_name}")
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        """Remove a chunk from the PostgreSQL table."""
+        with self.conn.cursor(cursor_factory=psycopg2.extras.DictCursor) as cur:
+            cur.execute(f"DELETE FROM {self.table_name} WHERE id = %s", (chunk_id,))
+
 
 class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
     def __init__(
```
```diff
@@ -265,3 +270,12 @@ class PGVectorVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorDBsProtocolPrivate):
         index = PGVectorIndex(vector_db, vector_db.embedding_dimension, self.conn)
         self.cache[vector_db_id] = VectorDBWithIndex(vector_db, index, self.inference_api)
         return self.cache[vector_db_id]
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a PostgreSQL vector store."""
+        index = await self._get_and_cache_vector_db_index(store_id)
+        if not index:
+            raise ValueError(f"Vector DB {store_id} not found")
+
+        # Use the index's delete_chunk method
+        await index.index.delete_chunk(chunk_id)
```
```diff
@@ -82,6 +82,9 @@ class QdrantIndex(EmbeddingIndex):
 
         await self.client.upsert(collection_name=self.collection_name, points=points)
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        raise NotImplementedError("delete_chunk is not supported in qdrant")
+
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         results = (
             await self.client.query_points(
@@ -307,3 +310,6 @@ class QdrantVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate):
         file_id: str,
     ) -> VectorStoreFileObject:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Qdrant")
```
```diff
@@ -66,6 +66,9 @@ class WeaviateIndex(EmbeddingIndex):
         # TODO: make this async friendly
         collection.data.insert_many(data_objects)
 
+    async def delete_chunk(self, chunk_id: str) -> None:
+        raise NotImplementedError("delete_chunk is not supported in Weaviate")
+
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         collection = self.client.collections.get(self.collection_name)
 
@@ -264,3 +267,6 @@ class WeaviateVectorIOAdapter(
 
     async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
         raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
```
```diff
@@ -152,6 +152,11 @@ class OpenAIVectorStoreMixin(ABC):
         """Load existing OpenAI vector stores into the in-memory cache."""
         self.openai_vector_stores = await self._load_openai_vector_stores()
 
+    @abstractmethod
+    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
+        """Delete a chunk from a vector store."""
+        pass
+
     @abstractmethod
     async def register_vector_db(self, vector_db: VectorDB) -> None:
         """Register a vector database (provider-specific implementation)."""
```
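A hypothetical adapter sketch of the contract this abstract method imposes; it mirrors the pattern the sqlite_vec, Milvus, and PGVector adapters adopt in this commit:

```python
class MyVectorIOAdapter(OpenAIVectorStoreMixin):
    async def _delete_openai_chunk_from_vector_store(self, store_id: str, chunk_id: str) -> None:
        # Resolve the cached index, then delegate to its delete_chunk.
        index = await self._get_and_cache_vector_db_index(store_id)  # helper shown in the diff
        if not index:
            raise ValueError(f"Vector DB {store_id} not found")
        await index.index.delete_chunk(chunk_id)
```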
```diff
@@ -763,17 +768,17 @@
         if vector_store_id not in self.openai_vector_stores:
             raise ValueError(f"Vector store {vector_store_id} not found")
 
+        dict_chunks = await self._load_openai_vector_store_file_contents(vector_store_id, file_id)
+        chunks = [Chunk.model_validate(c) for c in dict_chunks]
+        for c in chunks:
+            if c.chunk_id:
+                await self._delete_openai_chunk_from_vector_store(vector_store_id, str(c.chunk_id))
+
         store_info = self.openai_vector_stores[vector_store_id].copy()
 
         file = await self.openai_retrieve_vector_store_file(vector_store_id, file_id)
         await self._delete_openai_vector_store_file_from_storage(vector_store_id, file_id)
 
-        # TODO: We need to actually delete the embeddings from the underlying vector store...
-        # Also uncomment the corresponding integration test marked as xfail
-        #
-        # test_openai_vector_store_delete_file_removes_from_vector_store in
-        # tests/integration/vector_io/test_openai_vector_stores.py
-
         # Update in-memory cache
         store_info["file_ids"].remove(file_id)
         store_info["file_counts"][file.status] -= 1
```
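From the client side, deleting a vector-store file should now also purge that file's chunks. A hedged sketch against an OpenAI-compatible client (method names follow the OpenAI Python SDK; `client`, `store`, and `file` are assumed to exist):

```python
# Assumption: `client` is an OpenAI-compatible client pointed at Llama Stack.
client.vector_stores.files.delete(file_id=file.id, vector_store_id=store.id)

# After this commit, a search should no longer surface chunks from that file.
results = client.vector_stores.search(vector_store_id=store.id, query="anything")
```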
```diff
@@ -231,6 +231,10 @@ class EmbeddingIndex(ABC):
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         raise NotImplementedError()
 
+    @abstractmethod
+    async def delete_chunk(self, chunk_id: str):
+        raise NotImplementedError()
+
     @abstractmethod
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
         raise NotImplementedError()
```
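A toy in-memory sketch of the widened `EmbeddingIndex` contract, showing only the minimal surface a new backend must add (storage is a plain dict, purely illustrative; the remaining abstract methods are omitted):

```python
class ToyIndex(EmbeddingIndex):
    """Illustrative only: query_vector and the other abstract methods are omitted."""

    def __init__(self) -> None:
        self.embeddings: dict[str, list[float]] = {}

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray) -> None:
        for chunk, emb in zip(chunks, embeddings):
            self.embeddings[chunk.chunk_id] = list(emb)

    async def delete_chunk(self, chunk_id: str) -> None:
        # Deleting an unknown id is a no-op, matching the faiss early return.
        self.embeddings.pop(chunk_id, None)
```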
```diff
@@ -735,8 +735,6 @@ def test_openai_vector_store_delete_file(compat_client_with_empty_stores, client_with_models):
     assert updated_vector_store.file_counts.in_progress == 0
 
 
-# TODO: Remove this xfail once we have a way to remove embeddings from vector store
-@pytest.mark.xfail(reason="Vector Store Files delete doesn't remove embeddings from vector store", strict=True)
 def test_openai_vector_store_delete_file_removes_from_vector_store(compat_client_with_empty_stores, client_with_models):
     """Test OpenAI vector store delete file removes from vector store."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
```