mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 18:00:36 +00:00

Merge 2367a4ff80 into sapling-pr-archive-ehhuang
commit 9885c522c3

12 changed files with 848 additions and 75 deletions
pyproject.toml
@@ -25,8 +25,8 @@ classifiers = [
 ]
 dependencies = [
     "aiohttp",
     "fastapi>=0.115.0,<1.0", # server
     "fire", # for MCP in LLS client
     "httpx",
     "jinja2>=3.1.6",
     "jsonschema",
@@ -34,7 +34,7 @@ dependencies = [
     "openai>=2.5.0",
     "prompt-toolkit",
     "python-dotenv",
     "pyjwt[crypto]>=2.10.0", # Pull crypto to support RS256 for jwt. Requires 2.10.0+ for ssl_context support.
     "pydantic>=2.11.9",
     "rich",
     "starlette",
@@ -42,13 +42,13 @@ dependencies = [
     "tiktoken",
     "pillow",
     "h11>=0.16.0",
     "python-multipart>=0.0.20", # For fastapi Form
     "uvicorn>=0.34.0", # server
     "opentelemetry-sdk>=1.30.0", # server
     "opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
     "aiosqlite>=0.21.0", # server - for metadata store
     "asyncpg", # for metadata store
     "sqlalchemy[asyncio]>=2.0.41", # server - for conversations
 ]

 [project.optional-dependencies]
@@ -192,6 +192,7 @@ explicit = true
 [tool.uv.sources]
 torch = [{ index = "pytorch-cpu" }]
 torchvision = [{ index = "pytorch-cpu" }]
+llama-stack-client = { path = "../llama-stack-client-python", editable = true }

 [tool.ruff]
 line-length = 120
@@ -260,7 +260,7 @@ class VectorStoreSearchResponsePage(BaseModel):
     """

     object: str = "vector_store.search_results.page"
-    search_query: str
+    search_query: str | list[str]
     data: list[VectorStoreSearchResponse]
     has_more: bool = False
     next_page: str | None = None
@@ -143,6 +143,13 @@ providers:
       persistence:
         namespace: vector_io::weaviate
         backend: kv_default
+  - provider_id: openai-vector-store
+    provider_type: remote::openai
+    config:
+      api_key: ${env.OPENAI_API_KEY:=}
+      persistence:
+        namespace: vector_io::openai_vector_store
+        backend: kv_default
   files:
   - provider_id: meta-reference-files
     provider_type: inline::localfs
25  src/llama_stack/providers/inline/vector_io/openai/__init__.py  Normal file
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.providers.datatypes import Api

from .config import OpenAIVectorIOConfig


async def get_provider_impl(config: OpenAIVectorIOConfig, deps: dict[Api, Any]):
    from .openai import OpenAIVectorIOAdapter

    assert isinstance(config, OpenAIVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = OpenAIVectorIOAdapter(
        config,
        deps[Api.inference],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl
33  src/llama_stack/providers/inline/vector_io/openai/config.py  Normal file
@@ -0,0 +1,33 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from pydantic import BaseModel, Field

from llama_stack.core.storage.datatypes import KVStoreReference
from llama_stack.schema_utils import json_schema_type


@json_schema_type
class OpenAIVectorIOConfig(BaseModel):
    api_key: str | None = Field(
        None,
        description="OpenAI API key. If not provided, will use OPENAI_API_KEY environment variable.",
    )
    persistence: KVStoreReference = Field(
        description="KVStore reference for persisting vector store metadata.",
    )

    @classmethod
    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
        return {
            "api_key": "${OPENAI_API_KEY}",
            "persistence": KVStoreReference(
                backend="kv_default",
                namespace="vector_io::openai",
            ).model_dump(exclude_none=True),
        }
512  src/llama_stack/providers/inline/vector_io/openai/openai.py  Normal file
@@ -0,0 +1,512 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.apis.files import Files, OpenAIFileObject
from llama_stack.apis.inference import Inference, InterleavedContent
from llama_stack.apis.vector_io import (
    Chunk,
    QueryChunksResponse,
    SearchRankingOptions,
    VectorIO,
    VectorStoreFileDeleteResponse,
    VectorStoreFileLastError,
    VectorStoreFileObject,
    VectorStoreSearchResponsePage,
)
from llama_stack.apis.vector_stores import VectorStore
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, VectorStoresProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin
from llama_stack.providers.utils.memory.vector_store import ChunkForDeletion

from .config import OpenAIVectorIOConfig

logger = get_logger(name=__name__, category="vector_io")

# Prefix for storing the mapping from Llama Stack vector store IDs to OpenAI vector store IDs
VECTOR_STORE_ID_MAPPING_PREFIX = "openai_vector_store_id_mapping::"
# Prefix for storing the mapping from Llama Stack file IDs to OpenAI file IDs
FILE_ID_MAPPING_PREFIX = "openai_file_id_mapping::"


class OpenAIVectorIOAdapter(OpenAIVectorStoreMixin, VectorIO, VectorStoresProtocolPrivate):
    def __init__(
        self,
        config: OpenAIVectorIOConfig,
        inference_api: Inference,
        files_api: Files | None,
    ) -> None:
        super().__init__(files_api=files_api, kvstore=None)
        self.config = config
        self.inference_api = inference_api
        self.openai_client = None

    async def initialize(self) -> None:
        self.kvstore = await kvstore_impl(self.config.persistence)

        # Initialize OpenAI client
        try:
            from openai import AsyncOpenAI
        except ImportError as e:
            raise RuntimeError(
                "OpenAI Python client library is not installed. Please install it with: pip install openai"
            ) from e

        api_key = self.config.api_key or None
        if api_key == "${OPENAI_API_KEY}":
            api_key = None

        try:
            self.openai_client = AsyncOpenAI(api_key=api_key)
        except Exception as e:
            raise RuntimeError("Failed to initialize OpenAI client") from e

        # Load existing OpenAI vector stores into the in-memory cache
        await self.initialize_openai_vector_stores()

    async def _store_vector_store_id_mapping(self, llama_stack_id: str, openai_id: str) -> None:
        """Store mapping from Llama Stack vector store ID to OpenAI vector store ID."""
        if self.kvstore:
            key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{llama_stack_id}"
            await self.kvstore.set(key, openai_id)

    async def _get_openai_vector_store_id(self, llama_stack_id: str) -> str:
        """Get OpenAI vector store ID from Llama Stack vector store ID.

        Raises ValueError if mapping is not found.
        """
        if self.kvstore:
            key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{llama_stack_id}"
            try:
                openai_id = await self.kvstore.get(key)
                if openai_id:
                    return openai_id
            except Exception:
                pass
        # If not found in mapping, raise an error instead of assuming
        raise ValueError(f"No OpenAI vector store mapping found for Llama Stack ID: {llama_stack_id}")

    async def _delete_vector_store_id_mapping(self, llama_stack_id: str) -> None:
        """Delete mapping for a vector store ID."""
        if self.kvstore:
            key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{llama_stack_id}"
            try:
                await self.kvstore.delete(key)
            except Exception:
                pass

    async def _store_file_id_mapping(self, llama_stack_file_id: str, openai_file_id: str) -> None:
        """Store mapping from Llama Stack file ID to OpenAI file ID."""
        if self.kvstore:
            key = f"{FILE_ID_MAPPING_PREFIX}{llama_stack_file_id}"
            await self.kvstore.set(key, openai_file_id)

    async def _get_openai_file_id(self, llama_stack_file_id: str) -> str | None:
        """Get OpenAI file ID from Llama Stack file ID. Returns None if not found."""
        if self.kvstore:
            key = f"{FILE_ID_MAPPING_PREFIX}{llama_stack_file_id}"
            try:
                openai_id = await self.kvstore.get(key)
                if openai_id:
                    return openai_id
            except Exception:
                pass
        return None

    async def _get_llama_stack_file_id(self, openai_file_id: str) -> str | None:
        """Get Llama Stack file ID from OpenAI file ID. Returns None if not found."""
        if self.kvstore:
            # For reverse lookup, we need to search through all mappings
            prefix = FILE_ID_MAPPING_PREFIX
            start_key = prefix
            end_key = f"{prefix}\xff"
            try:
                items = await self.kvstore.items_in_range(start_key, end_key)
                for key, value in items:
                    if value == openai_file_id:
                        # Extract the Llama Stack file ID from the key
                        return key[len(prefix) :]
            except Exception:
                pass
        return None

    async def _delete_file_id_mapping(self, llama_stack_file_id: str) -> None:
        """Delete mapping for a file ID."""
        if self.kvstore:
            key = f"{FILE_ID_MAPPING_PREFIX}{llama_stack_file_id}"
            try:
                await self.kvstore.delete(key)
            except Exception:
                pass

    async def shutdown(self) -> None:
        # Clean up mixin resources (file batch tasks)
        await super().shutdown()

    async def health(self) -> HealthResponse:
        """
        Performs a health check by verifying connectivity to OpenAI API.
        """
        try:
            if self.openai_client is None:
                return HealthResponse(
                    status=HealthStatus.ERROR,
                    message="OpenAI client not initialized",
                )

            # Try to list models as a simple health check
            await self.openai_client.models.list()
            return HealthResponse(status=HealthStatus.OK)
        except Exception as e:
            return HealthResponse(
                status=HealthStatus.ERROR,
                message=f"Health check failed: {str(e)}",
            )

    async def register_vector_store(self, vector_store: VectorStore) -> None:
        """Register a vector store by creating it in OpenAI's API."""
        if self.openai_client is None:
            raise RuntimeError("OpenAI client not initialized")

        # Create vector store in OpenAI
        created_store = await self.openai_client.vector_stores.create(
            name=vector_store.vector_store_name or vector_store.identifier,
        )

        # Store mapping from Llama Stack ID to OpenAI ID
        await self._store_vector_store_id_mapping(vector_store.identifier, created_store.id)

        logger.info(f"Created OpenAI vector store: {created_store.id} for identifier: {vector_store.identifier}")

    async def unregister_vector_store(self, vector_store_id: str) -> None:
        """Delete a vector store from OpenAI's API."""
        if self.openai_client is None:
            raise RuntimeError("OpenAI client not initialized")

        try:
            # Look up the OpenAI ID from our mapping
            if self.kvstore:
                key = f"{VECTOR_STORE_ID_MAPPING_PREFIX}{vector_store_id}"
                try:
                    openai_vector_store_id = await self.kvstore.get(key)
                    if openai_vector_store_id:
                        await self.openai_client.vector_stores.delete(openai_vector_store_id)
                        logger.info(
                            f"Deleted OpenAI vector store: {openai_vector_store_id} for identifier: {vector_store_id}"
                        )
                    else:
                        logger.warning(f"No OpenAI vector store mapping found for {vector_store_id}, skipping deletion")
                except Exception as e:
                    logger.warning(f"Failed to delete vector store {vector_store_id} from OpenAI: {e}", exc_info=True)
            # Clean up the mapping
            await self._delete_vector_store_id_mapping(vector_store_id)
        except Exception as e:
            logger.warning(f"Error in unregister_vector_store for {vector_store_id}: {e}", exc_info=True)

    async def insert_chunks(
        self,
        vector_store_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """
        OpenAI Vector Stores API doesn't support direct chunk insertion.
        Use file attachment instead via openai_attach_file_to_vector_store.
        """
        raise NotImplementedError(
            "Direct chunk insertion is not supported by OpenAI Vector Stores API. "
            "Please use file attachment instead via the openai_attach_file_to_vector_store endpoint."
        )

    async def query_chunks(
        self,
        vector_store_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """
        OpenAI Vector Stores API doesn't support direct chunk queries.
        Use the OpenAI vector store search API instead.
        """
        raise NotImplementedError(
            "Direct chunk querying is not supported by OpenAI Vector Stores API. "
            "Please use the openai_search_vector_store endpoint instead."
        )

    async def delete_chunks(
        self,
        store_id: str,
        chunks_for_deletion: list[ChunkForDeletion],
    ) -> None:
        """
        OpenAI Vector Stores API doesn't support direct chunk deletion.
        Delete files from the vector store instead.
        """
        raise NotImplementedError(
            "Direct chunk deletion is not supported by OpenAI Vector Stores API. "
            "Please delete files from the vector store instead via openai_delete_vector_store_file."
        )

    async def _prepare_and_attach_file_chunks(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
        chunking_strategy: Any,
        created_at: int,
    ) -> tuple[Any, list[Chunk], Any]:
        """
        Override to download file from Llama Stack, upload to OpenAI,
        and attach to OpenAI vector store instead of storing chunks locally.

        Returns: (VectorStoreFileObject, empty chunks list, file response)
        """
        # Translate Llama Stack ID to OpenAI ID
        try:
            openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)
        except ValueError as e:
            logger.error(f"Cannot attach file to vector store {vector_store_id}: {e}")
            return (
                VectorStoreFileObject(
                    id=file_id,
                    attributes=attributes,
                    chunking_strategy=chunking_strategy,
                    created_at=created_at,
                    status="failed",
                    vector_store_id=vector_store_id,
                    last_error=VectorStoreFileLastError(
                        code="server_error",
                        message=str(e),
                    ),
                ),
                [],
                None,
            )

        vector_store_file_object = VectorStoreFileObject(
            id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
            created_at=created_at,
            status="in_progress",
            vector_store_id=vector_store_id,
        )

        # Prepare file: download from Llama Stack if needed, upload to OpenAI
        try:
            file_obj: OpenAIFileObject = await self.files_api.openai_retrieve_file(file_id)
            file_content_response = await self.files_api.openai_retrieve_file_content(file_id)
            file_data = file_content_response.body

            import io

            file_buffer = io.BytesIO(file_data)
            file_buffer.name = file_obj.filename

            openai_file = await self.openai_client.files.create(
                file=file_buffer,
                purpose="assistants",
            )

            logger.info(f"Uploaded file {file_id} to OpenAI as {openai_file.id}")
            openai_file_id = openai_file.id
            # Store mapping for later lookup
            await self._store_file_id_mapping(file_id, openai_file_id)
        except Exception as e:
            logger.debug(f"Could not retrieve file {file_id} from Llama Stack: {e}. Using file_id directly.")
            openai_file_id = file_id

        # Attach file to OpenAI vector store
        try:
            attached_file = await self.openai_client.vector_stores.files.create(
                vector_store_id=openai_vector_store_id,
                file_id=openai_file_id,
            )

            logger.info(
                f"Attached file {openai_file_id} to OpenAI vector store {openai_vector_store_id}, "
                f"status: {attached_file.status}"
            )

            # Use the status from OpenAI's response, don't assume it's completed
            vector_store_file_object.status = attached_file.status
            file_response = file_obj if "file_obj" in locals() else None
        except Exception as e:
            logger.error(f"Failed to attach file {openai_file_id} to vector store: {e}")
            vector_store_file_object.status = "failed"
            vector_store_file_object.last_error = VectorStoreFileLastError(
                code="server_error",
                message=str(e),
            )
            file_response = file_obj if "file_obj" in locals() else None

        # Return VectorStoreFileObject and empty chunks (OpenAI handles storage)
        return vector_store_file_object, [], file_response

    async def openai_search_vector_store(
        self,
        vector_store_id: str,
        query: str | list[str],
        filters: dict[str, Any] | None = None,
        max_num_results: int | None = 10,
        ranking_options: SearchRankingOptions | None = None,
        rewrite_query: bool | None = False,
        search_mode: str | None = "vector",
    ) -> VectorStoreSearchResponsePage:
        """Search a vector store using OpenAI's native search API."""
        assert self.openai_client is not None

        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)
        logger.info(f"openai_vector_store_id: {openai_vector_store_id}")
        response = await self.openai_client.vector_stores.search(
            vector_store_id=openai_vector_store_id,
            query=query,
            filters=filters,
            max_num_results=max_num_results,
            ranking_options=ranking_options,
            rewrite_query=rewrite_query,
        )
        payload = response.model_dump()
        logger.info(f"payload: {payload}")
        # Remap OpenAI file IDs back to Llama Stack file IDs in results
        if payload.get("data"):
            for result in payload["data"]:
                if result.get("file_id"):
                    llama_stack_file_id = await self._get_llama_stack_file_id(result["file_id"])
                    if llama_stack_file_id:
                        result["file_id"] = llama_stack_file_id

        return VectorStoreSearchResponsePage(**payload)

    async def openai_delete_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileDeleteResponse:
        """Delete a file from a vector store."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        if self.openai_client is None:
            raise RuntimeError("OpenAI client not initialized")

        try:
            # Get the OpenAI file ID for this Llama Stack file ID
            openai_file_id = await self._get_openai_file_id(file_id)
            if not openai_file_id:
                # If no mapping, use the file_id as-is (may be native OpenAI file ID)
                openai_file_id = file_id

            # Get OpenAI vector store ID
            openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)

            # Delete file from OpenAI vector store
            await self.openai_client.vector_stores.files.delete(
                vector_store_id=openai_vector_store_id,
                file_id=openai_file_id,
            )

            logger.info(f"Deleted file {openai_file_id} from OpenAI vector store {openai_vector_store_id}")

            # Delete the file from OpenAI if it was created by us
            if await self._get_openai_file_id(file_id):
                try:
                    await self.openai_client.files.delete(openai_file_id)
                    logger.info(f"Deleted OpenAI file {openai_file_id}")
                except Exception as e:
                    logger.debug(f"Could not delete OpenAI file {openai_file_id}: {e}")

            # Clean up mappings
            await self._delete_file_id_mapping(file_id)

            # Update vector store metadata
            store_info = self.openai_vector_stores[vector_store_id].copy()
            if file_id in store_info["file_ids"]:
                store_info["file_ids"].remove(file_id)
                store_info["file_counts"]["total"] -= 1
                store_info["file_counts"]["completed"] -= 1
            self.openai_vector_stores[vector_store_id] = store_info
            await self._save_openai_vector_store(vector_store_id, store_info)

            return VectorStoreFileDeleteResponse(
                id=file_id,
                deleted=True,
            )

        except Exception as e:
            logger.error(f"Error deleting file {file_id} from vector store {vector_store_id}: {e}")
            raise

    async def openai_retrieve_vector_store_file(
        self,
        vector_store_id: str,
        file_id: str,
    ) -> VectorStoreFileObject:
        """Retrieve a vector store file and check status from OpenAI if still in_progress."""
        if vector_store_id not in self.openai_vector_stores:
            raise ValueError(f"Vector store {vector_store_id} not found")

        if self.openai_client is None:
            raise RuntimeError("OpenAI client not initialized")

        # Get the cached file info
        file_info = await self._load_openai_vector_store_file(vector_store_id, file_id)
        file_object = VectorStoreFileObject(**file_info)

        # If status is still in_progress, check the actual status from OpenAI
        if file_object.status == "in_progress":
            try:
                # Get OpenAI file ID for this Llama Stack file ID
                openai_file_id = await self._get_openai_file_id(file_id)
                if not openai_file_id:
                    openai_file_id = file_id

                # Get OpenAI vector store ID
                openai_vector_store_id = await self._get_openai_vector_store_id(vector_store_id)

                # Retrieve the file status from OpenAI
                openai_file = await self.openai_client.vector_stores.files.retrieve(
                    vector_store_id=openai_vector_store_id,
                    file_id=openai_file_id,
                )

                # Update the status from OpenAI
                file_object.status = openai_file.status

                # If status changed, update it in storage
                if openai_file.status != "in_progress":
                    file_info["status"] = openai_file.status
                    # Update file counts in vector store metadata
                    store_info = self.openai_vector_stores[vector_store_id].copy()
                    if file_object.status == "completed":
                        store_info["file_counts"]["in_progress"] = max(
                            0, store_info["file_counts"].get("in_progress", 0) - 1
                        )
                        store_info["file_counts"]["completed"] = (
                            store_info["file_counts"].get("completed", 0) + 1
                        )
                    elif file_object.status == "failed":
                        store_info["file_counts"]["in_progress"] = max(
                            0, store_info["file_counts"].get("in_progress", 0) - 1
                        )
                        store_info["file_counts"]["failed"] = store_info["file_counts"].get("failed", 0) + 1

                    self.openai_vector_stores[vector_store_id] = store_info
                    await self._save_openai_vector_store_file(vector_store_id, file_id, file_info)
                    await self._save_openai_vector_store(vector_store_id, store_info)

            except Exception as e:
                logger.debug(f"Could not retrieve file status from OpenAI: {e}. Using cached status.")

        return file_object
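Editor's note (not part of the diff): the two mapping prefixes above translate Llama Stack identifiers into the OpenAI-side identifiers the adapter creates. A minimal sketch of the resulting key/value layout, using a plain dict in place of the configured KV store (the real adapter goes through `kvstore_impl(self.config.persistence)`); the concrete IDs are made up for illustration:

```python
# Illustrative only: a dict stands in for the persisted KV store.
VECTOR_STORE_ID_MAPPING_PREFIX = "openai_vector_store_id_mapping::"
FILE_ID_MAPPING_PREFIX = "openai_file_id_mapping::"

kv: dict[str, str] = {}

# register_vector_store(): Llama Stack store ID -> OpenAI store ID (hypothetical values)
kv[f"{VECTOR_STORE_ID_MAPPING_PREFIX}vs_llama_123"] = "vs_openai_abc"

# _store_file_id_mapping(): Llama Stack file ID -> OpenAI file ID (hypothetical values)
kv[f"{FILE_ID_MAPPING_PREFIX}file-llama-1"] = "file-openai-1"

# Forward lookup is a direct get; the reverse lookup in _get_llama_stack_file_id()
# scans the prefix range and strips the prefix from the matching key.
prefix = FILE_ID_MAPPING_PREFIX
reverse = {v: k[len(prefix):] for k, v in kv.items() if k.startswith(prefix)}
assert reverse["file-openai-1"] == "file-llama-1"
```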
@@ -78,6 +78,75 @@ pip install faiss-cpu
 ## Documentation
 See [Faiss' documentation](https://faiss.ai/) or the [Faiss Wiki](https://github.com/facebookresearch/faiss/wiki) for
 more details about Faiss in general.
+""",
+    ),
+    RemoteProviderSpec(
+        api=Api.vector_io,
+        adapter_type="openai",
+        provider_type="remote::openai",
+        pip_packages=["openai"] + DEFAULT_VECTOR_IO_DEPS,
+        module="llama_stack.providers.remote.vector_io.openai",
+        config_class="llama_stack.providers.inline.vector_io.openai.OpenAIVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files, Api.models],
+        description="""
+[OpenAI Vector Stores](https://platform.openai.com/docs/guides/vector-stores) is a remote vector database provider for Llama Stack that uses OpenAI's Vector Stores API.
+It allows you to store and query vectors using OpenAI's embeddings and vector store infrastructure.
+
+## Features
+
+- Direct integration with OpenAI's Vector Stores API
+- File attachment support for batch vector store operations
+- OpenAI-compatible API endpoints
+- Vector search via OpenAI's native search API
+- Metadata filtering
+
+## Search Modes
+
+**Supported:**
+- **Vector Search** (`mode="vector"`): Performs vector similarity search using OpenAI embeddings
+
+**Not Supported:**
+- **Keyword Search** (`mode="keyword"`): Not supported by OpenAI Vector Stores API
+- **Hybrid Search** (`mode="hybrid"`): Not supported by OpenAI Vector Stores API
+
+## Configuration
+
+To use this provider, you need to provide:
+
+1. **API Key**: Either pass `api_key` in the config or set the `OPENAI_API_KEY` environment variable
+2. **Persistence**: A KVStore backend for storing metadata
+
+### Example Configuration
+
+```yaml
+vector_io:
+  - provider_id: openai
+    provider_type: remote::openai
+    config:
+      api_key: ${OPENAI_API_KEY}
+      persistence:
+        backend: kv_default
+        namespace: vector_io::openai
+```
+
+## Installation
+
+Install the OpenAI Python client:
+
+```bash
+pip install openai
+```
+
+## Limitations
+
+- Content is added via file attachments; direct chunk insertion, querying, and deletion are not supported
+- Embeddings are computed by OpenAI's Vector Stores service rather than by a Llama Stack embedding model
+- For queries, only vector search mode is natively supported
+
+## Documentation
+
+See [OpenAI Vector Stores API documentation](https://platform.openai.com/docs/guides/vector-stores) for more details.
 """,
     ),
     # NOTE: sqlite-vec cannot be bundled into the container image because it does not have a
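Editor's note (not part of the diff): as a usage sketch, the flow exercised by the integration tests below maps onto the OpenAI-compatible client surface roughly as follows. The base URL, names, and file contents are placeholders; `OpenAI(...)` here is the standard `openai` client pointed at a running Llama Stack server with the `remote::openai` vector_io provider configured:

```python
import io
import time

from openai import OpenAI

# Hypothetical endpoint for a locally running Llama Stack server.
client = OpenAI(base_url="http://localhost:8321/v1", api_key="none")

# Create a vector store; the remote::openai provider mirrors it in OpenAI.
vector_store = client.vector_stores.create(name="my_docs")

# Upload a file and attach it to the vector store.
uploaded = client.files.create(file=("notes.txt", io.BytesIO(b"Llama Stack notes")), purpose="assistants")
attached = client.vector_stores.files.create(vector_store_id=vector_store.id, file_id=uploaded.id)

# Attachment is asynchronous on the OpenAI side; poll until it completes, as the tests do.
deadline = time.time() + 10
while attached.status != "completed" and time.time() < deadline:
    attached = client.vector_stores.files.retrieve(vector_store_id=vector_store.id, file_id=uploaded.id)

# Only vector search mode is supported by this provider.
page = client.vector_stores.search(vector_store_id=vector_store.id, query="what is llama stack?")
for result in page.data:
    print(result.file_id, result.score)
```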
26  src/llama_stack/providers/remote/vector_io/openai/__init__.py  Normal file
@@ -0,0 +1,26 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any

from llama_stack.providers.datatypes import Api

from llama_stack.providers.inline.vector_io.openai.config import OpenAIVectorIOConfig


async def get_adapter_impl(config: OpenAIVectorIOConfig, deps: dict[Api, Any]):
    """Remote adapter for OpenAI Vector Store provider - delegates to inline implementation."""
    from llama_stack.providers.inline.vector_io.openai.openai import OpenAIVectorIOAdapter

    assert isinstance(config, OpenAIVectorIOConfig), f"Unexpected config type: {type(config)}"

    impl = OpenAIVectorIOAdapter(
        config,
        deps[Api.inference],
        deps.get(Api.files),
    )
    await impl.initialize()
    return impl
@@ -729,6 +729,70 @@ class OpenAIVectorStoreMixin(ABC):
         ]
         return content
 
+    async def _prepare_and_attach_file_chunks(
+        self,
+        vector_store_id: str,
+        file_id: str,
+        attributes: dict[str, Any],
+        chunking_strategy: VectorStoreChunkingStrategy,
+        created_at: int,
+    ) -> tuple[VectorStoreFileObject, list[Chunk], OpenAIFileObject | None]:
+        """
+        Implementation-specific method for preparing and attaching file chunks to vector store.
+        Subclasses can override this to customize how files are prepared and attached.
+
+        Returns: (VectorStoreFileObject, chunks, file_response) tuple
+        """
+        if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
+            max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
+            chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
+        else:
+            # Default values from OpenAI API spec
+            max_chunk_size_tokens = 800
+            chunk_overlap_tokens = 400
+
+        file_response = await self.files_api.openai_retrieve_file(file_id)
+        mime_type, _ = mimetypes.guess_type(file_response.filename)
+        content_response = await self.files_api.openai_retrieve_file_content(file_id)
+
+        content = content_from_data_and_mime_type(content_response.body, mime_type)
+
+        chunk_attributes = attributes.copy()
+        chunk_attributes["filename"] = file_response.filename
+
+        chunks = make_overlapped_chunks(
+            file_id,
+            content,
+            max_chunk_size_tokens,
+            chunk_overlap_tokens,
+            chunk_attributes,
+        )
+
+        vector_store_file_object = VectorStoreFileObject(
+            id=file_id,
+            attributes=attributes,
+            chunking_strategy=chunking_strategy,
+            created_at=created_at,
+            status="in_progress",
+            vector_store_id=vector_store_id,
+        )
+
+        if not chunks:
+            vector_store_file_object.status = "failed"
+            vector_store_file_object.last_error = VectorStoreFileLastError(
+                code="server_error",
+                message="No chunks were generated from the file",
+            )
+        else:
+            # Default implementation: insert chunks directly
+            await self.insert_chunks(
+                vector_store_id=vector_store_id,
+                chunks=chunks,
+            )
+            vector_store_file_object.status = "completed"
+
+        return vector_store_file_object, chunks, file_response
+
     async def openai_attach_file_to_vector_store(
         self,
         vector_store_id: str,
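Editor's note (not part of the diff): the hook added above is the extension point the OpenAI adapter overrides; the mixin's default chunks the file locally and calls `insert_chunks`, while a subclass can delegate storage to a remote service. A minimal, hypothetical override illustrating the contract (the `RemoteBackedVectorIO` class and `self._remote_attach` helper are made up for illustration):

```python
from typing import Any

from llama_stack.apis.files import OpenAIFileObject
from llama_stack.apis.vector_io import Chunk, VectorStoreChunkingStrategy, VectorStoreFileObject
from llama_stack.providers.utils.memory.openai_vector_store_mixin import OpenAIVectorStoreMixin


class RemoteBackedVectorIO(OpenAIVectorStoreMixin):
    async def _prepare_and_attach_file_chunks(
        self,
        vector_store_id: str,
        file_id: str,
        attributes: dict[str, Any],
        chunking_strategy: VectorStoreChunkingStrategy,
        created_at: int,
    ) -> tuple[VectorStoreFileObject, list[Chunk], OpenAIFileObject | None]:
        # Hypothetical helper: hand the file to the remote backend and get back a status.
        status = await self._remote_attach(vector_store_id, file_id)
        file_object = VectorStoreFileObject(
            id=file_id,
            attributes=attributes,
            chunking_strategy=chunking_strategy,
            created_at=created_at,
            status=status,  # e.g. "in_progress" or "completed", as reported remotely
            vector_store_id=vector_store_id,
        )
        # Return no chunks: the caller then persists an empty chunk list locally.
        return file_object, [], None
```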
@@ -750,69 +814,43 @@ class OpenAIVectorStoreMixin(ABC):
         attributes = attributes or {}
         chunking_strategy = chunking_strategy or VectorStoreChunkingStrategyAuto()
         created_at = int(time.time())
-        chunks: list[Chunk] = []
-        file_response: OpenAIFileObject | None = None
-
-        vector_store_file_object = VectorStoreFileObject(
-            id=file_id,
-            attributes=attributes,
-            chunking_strategy=chunking_strategy,
-            created_at=created_at,
-            status="in_progress",
-            vector_store_id=vector_store_id,
-        )
 
         if not hasattr(self, "files_api") or not self.files_api:
-            vector_store_file_object.status = "failed"
-            vector_store_file_object.last_error = VectorStoreFileLastError(
-                code="server_error",
-                message="Files API is not available",
+            vector_store_file_object = VectorStoreFileObject(
+                id=file_id,
+                attributes=attributes,
+                chunking_strategy=chunking_strategy,
+                created_at=created_at,
+                status="failed",
+                vector_store_id=vector_store_id,
+                last_error=VectorStoreFileLastError(
+                    code="server_error",
+                    message="Files API is not available",
+                ),
             )
             return vector_store_file_object
 
-        if isinstance(chunking_strategy, VectorStoreChunkingStrategyStatic):
-            max_chunk_size_tokens = chunking_strategy.static.max_chunk_size_tokens
-            chunk_overlap_tokens = chunking_strategy.static.chunk_overlap_tokens
-        else:
-            # Default values from OpenAI API spec
-            max_chunk_size_tokens = 800
-            chunk_overlap_tokens = 400
-
         try:
-            file_response = await self.files_api.openai_retrieve_file(file_id)
-            mime_type, _ = mimetypes.guess_type(file_response.filename)
-            content_response = await self.files_api.openai_retrieve_file_content(file_id)
-
-            content = content_from_data_and_mime_type(content_response.body, mime_type)
-
-            chunk_attributes = attributes.copy()
-            chunk_attributes["filename"] = file_response.filename
-
-            chunks = make_overlapped_chunks(
-                file_id,
-                content,
-                max_chunk_size_tokens,
-                chunk_overlap_tokens,
-                chunk_attributes,
+            vector_store_file_object, chunks, file_response = await self._prepare_and_attach_file_chunks(
+                vector_store_id=vector_store_id,
+                file_id=file_id,
+                attributes=attributes,
+                chunking_strategy=chunking_strategy,
+                created_at=created_at,
             )
-            if not chunks:
-                vector_store_file_object.status = "failed"
-                vector_store_file_object.last_error = VectorStoreFileLastError(
-                    code="server_error",
-                    message="No chunks were generated from the file",
-                )
-            else:
-                await self.insert_chunks(
-                    vector_store_id=vector_store_id,
-                    chunks=chunks,
-                )
-                vector_store_file_object.status = "completed"
         except Exception as e:
             logger.exception("Error attaching file to vector store")
-            vector_store_file_object.status = "failed"
-            vector_store_file_object.last_error = VectorStoreFileLastError(
-                code="server_error",
-                message=str(e),
+            vector_store_file_object = VectorStoreFileObject(
+                id=file_id,
+                attributes=attributes,
+                chunking_strategy=chunking_strategy,
+                created_at=created_at,
+                status="failed",
+                vector_store_id=vector_store_id,
+                last_error=VectorStoreFileLastError(
+                    code="server_error",
+                    message=str(e),
+                ),
             )
 
         # Create OpenAI vector store file metadata
@@ -820,8 +858,8 @@ class OpenAIVectorStoreMixin(ABC):
         file_info["filename"] = file_response.filename if file_response else ""
 
         # Save vector store file to persistent storage (provider-specific)
-        dict_chunks = [c.model_dump() for c in chunks]
-        # This should be updated to include chunk_id
+        # Only save chunks if they were generated (some providers like OpenAI handle storage remotely)
+        dict_chunks = [c.model_dump() for c in chunks] if chunks else []
         await self._save_openai_vector_store_file(vector_store_id, file_id, file_info, dict_chunks)
 
         # Update file_ids and file_counts in vector store metadata
@@ -371,6 +371,7 @@ def vector_provider_wrapper(func):
     # For CI tests (replay/record), only use providers that are available in ci-tests environment
     if os.environ.get("LLAMA_STACK_TEST_INFERENCE_MODE") in ("replay", "record"):
         all_providers = ["faiss", "sqlite-vec"]
+        all_providers = ["openai-vector-store"]
     else:
         # For live tests, try all providers (they'll skip if not available)
         all_providers = [
@@ -80,6 +80,17 @@ def skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models
     )
 
 
+def skip_if_provider_is_openai_vector_store(client_with_models):
+    """Skip tests that require direct chunk insertion/querying (not supported by OpenAI)."""
+    vector_io_providers = [p for p in client_with_models.providers.list() if p.api == "vector_io"]
+    for p in vector_io_providers:
+        if p.provider_type == "remote::openai":
+            pytest.skip(
+                "OpenAI Vector Stores provider does not support direct chunk insertion/querying. "
+                "Use file attachment instead."
+            )
+
+
 @pytest.fixture(scope="session")
 def sample_chunks():
     from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
@@ -144,8 +155,8 @@ def compat_client_with_empty_stores(compat_client):
     yield compat_client
 
     # Clean up after the test
-    clear_vector_stores()
-    clear_files()
+    # clear_vector_stores()
+    # clear_files()
 
 
 @vector_provider_wrapper
@@ -365,6 +376,7 @@ def test_openai_vector_store_with_chunks(
 ):
     """Test vector store functionality with actual chunks using both OpenAI and native APIs."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    skip_if_provider_is_openai_vector_store(client_with_models)
 
     compat_client = compat_client_with_empty_stores
     llama_client = client_with_models
@@ -430,6 +442,7 @@ def test_openai_vector_store_search_relevance(
 ):
     """Test that OpenAI vector store search returns relevant results for different queries."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    skip_if_provider_is_openai_vector_store(client_with_models)
 
     compat_client = compat_client_with_empty_stores
     llama_client = client_with_models
@@ -482,6 +495,7 @@ def test_openai_vector_store_search_with_ranking_options(
 ):
     """Test OpenAI vector store search with ranking options."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    skip_if_provider_is_openai_vector_store(client_with_models)
 
     compat_client = compat_client_with_empty_stores
     llama_client = client_with_models
@@ -542,6 +556,7 @@ def test_openai_vector_store_search_with_high_score_filter(
 ):
     """Test that searching with text very similar to a document and high score threshold returns only that document."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    skip_if_provider_is_openai_vector_store(client_with_models)
 
     compat_client = compat_client_with_empty_stores
     llama_client = client_with_models
@@ -608,6 +623,7 @@ def test_openai_vector_store_search_with_max_num_results(
 ):
     """Test OpenAI vector store search with max_num_results."""
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
+    skip_if_provider_is_openai_vector_store(client_with_models)
 
     compat_client = compat_client_with_empty_stores
     llama_client = client_with_models
@@ -678,6 +694,13 @@ def test_openai_vector_store_attach_file(
     assert file_attach_response.object == "vector_store.file"
     assert file_attach_response.id == file.id
     assert file_attach_response.vector_store_id == vector_store.id
+
+    start_time = time.time()
+    while file_attach_response.status != "completed" and time.time() - start_time < 10:
+        file_attach_response = compat_client.vector_stores.files.retrieve(
+            vector_store_id=vector_store.id,
+            file_id=file.id,
+        )
     assert file_attach_response.status == "completed"
     assert file_attach_response.chunking_strategy.type == "auto"
     assert file_attach_response.created_at > 0
@@ -1178,6 +1201,7 @@ def test_openai_vector_store_search_modes(
 ):
     skip_if_provider_doesnt_support_openai_vector_stores(client_with_models)
     skip_if_provider_doesnt_support_openai_vector_stores_search(client_with_models, search_mode)
+    skip_if_provider_is_openai_vector_store(client_with_models)
 
     vector_store = llama_stack_client.vector_stores.create(
         name=f"search_mode_test_{search_mode}",
51  uv.lock  generated
@@ -2096,8 +2096,8 @@ requires-dist = [
     { name = "httpx" },
     { name = "jinja2", specifier = ">=3.1.6" },
     { name = "jsonschema" },
-    { name = "llama-stack-client", specifier = ">=0.3.0" },
-    { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.3.0" },
+    { name = "llama-stack-client", editable = "../llama-stack-client-python" },
+    { name = "llama-stack-client", marker = "extra == 'ui'", editable = "../llama-stack-client-python" },
     { name = "openai", specifier = ">=2.5.0" },
     { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" },
     { name = "opentelemetry-sdk", specifier = ">=1.30.0" },
@@ -2232,8 +2232,8 @@ unit = [
 
 [[package]]
 name = "llama-stack-client"
-version = "0.3.0"
-source = { registry = "https://pypi.org/simple" }
+version = "0.4.0a1"
+source = { editable = "../llama-stack-client-python" }
 dependencies = [
     { name = "anyio" },
     { name = "click" },
@@ -2251,10 +2251,47 @@ dependencies = [
     { name = "tqdm" },
     { name = "typing-extensions" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/1d/d9/3c720f420fc80ce51de1a0ad90c53edc613617b68980137dcf716a86198a/llama_stack_client-0.3.0.tar.gz", hash = "sha256:1e974a74d0da285e18ba7df30b9a324e250782b130253bcef3e695830c5bb03d", size = 340443, upload-time = "2025-10-21T23:58:25.855Z" }
-wheels = [
-    { url = "https://files.pythonhosted.org/packages/96/27/1c65035ce58100be22409c98e4d65b1cdaeff7811ea968f9f844641330d7/llama_stack_client-0.3.0-py3-none-any.whl", hash = "sha256:9f85d84d508ef7da44b96ca8555d7783da717cfc9135bab6a5530fe8c852690d", size = 425234, upload-time = "2025-10-21T23:58:24.246Z" },
-]
+
+[package.metadata]
+requires-dist = [
+    { name = "aiohttp", marker = "extra == 'aiohttp'" },
+    { name = "anyio", specifier = ">=3.5.0,<5" },
+    { name = "click" },
+    { name = "distro", specifier = ">=1.7.0,<2" },
+    { name = "fire" },
+    { name = "httpx", specifier = ">=0.23.0,<1" },
+    { name = "httpx-aiohttp", marker = "extra == 'aiohttp'", specifier = ">=0.1.9" },
+    { name = "pandas" },
+    { name = "prompt-toolkit" },
+    { name = "pyaml" },
+    { name = "pydantic", specifier = ">=1.9.0,<3" },
+    { name = "requests" },
+    { name = "rich" },
+    { name = "sniffio" },
+    { name = "termcolor" },
+    { name = "tqdm" },
+    { name = "typing-extensions", specifier = ">=4.7,<5" },
+]
+provides-extras = ["aiohttp"]
+
+[package.metadata.requires-dev]
+dev = [
+    { name = "black" },
+    { name = "dirty-equals", specifier = ">=0.6.0" },
+    { name = "importlib-metadata", specifier = ">=6.7.0" },
+    { name = "mypy" },
+    { name = "pre-commit" },
+    { name = "pyright", specifier = "==1.1.399" },
+    { name = "pytest", specifier = ">=7.1.1" },
+    { name = "pytest-asyncio" },
+    { name = "pytest-xdist", specifier = ">=3.6.1" },
+    { name = "respx" },
+    { name = "rich", specifier = ">=13.7.1" },
+    { name = "ruff" },
+    { name = "time-machine" },
+]
+pydantic-v1 = [{ name = "pydantic", specifier = ">=1.9.0,<2" }]
+pydantic-v2 = [{ name = "pydantic", specifier = ">=2,<3" }]
 
 [[package]]
 name = "lm-format-enforcer"