From 01d90eb8ab3048e184a9c2ab46016db8cb238b52 Mon Sep 17 00:00:00 2001
From: Francisco Javier Arceo
Date: Wed, 23 Jul 2025 21:20:16 -0400
Subject: [PATCH] chore: Enabling tests for Weaviate and some minor changes

Signed-off-by: Francisco Javier Arceo
---
 .../workflows/integration-vector-io-tests.yml      |  29 ++-
 .../providers/vector_io/remote_weaviate.md         |  14 ++
 llama_stack/apis/common/errors.py                  |   8 +
 llama_stack/apis/vector_io/vector_io.py            |   2 +-
 llama_stack/distribution/routers/inference.py      |  13 +-
 .../distribution/routing_tables/common.py          |   3 +-
 .../distribution/routing_tables/models.py          |   3 +-
 .../distribution/routing_tables/vector_dbs.py      |   3 +-
 .../remote/vector_io/milvus/milvus.py              |  10 +-
 .../remote/vector_io/weaviate/__init__.py          |   2 +-
 .../remote/vector_io/weaviate/config.py            |  24 ++-
 .../remote/vector_io/weaviate/weaviate.py          | 178 +++++++++++-------
 .../providers/utils/memory/vector_store.py         |   2 +-
 .../{chunk_utils.py => vector_utils.py}            |  18 ++
 scripts/generate_prompt_format.py                  |   3 +-
 .../vector_io/test_openai_vector_stores.py         |   5 +-
 ...st_chunk_utils.py => test_vector_utils.py}      |   2 +-
 17 files changed, 216 insertions(+), 103 deletions(-)
 rename llama_stack/providers/utils/vector_io/{chunk_utils.py => vector_utils.py} (58%)
 rename tests/unit/providers/vector_io/{test_chunk_utils.py => test_vector_utils.py} (97%)

diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml
index 9a02bbcf8..fc2a6a514 100644
--- a/.github/workflows/integration-vector-io-tests.yml
+++ b/.github/workflows/integration-vector-io-tests.yml
@@ -24,7 +24,7 @@ jobs:
     runs-on: ubuntu-latest
     strategy:
       matrix:
-        vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector"]
+        vector-io-provider: ["inline::faiss", "inline::sqlite-vec", "inline::milvus", "remote::chromadb", "remote::pgvector", "remote::weaviate"]
         python-version: ["3.12", "3.13"]
       fail-fast: false # we want to run all tests regardless of failure
@@ -48,6 +48,14 @@
           -e ANONYMIZED_TELEMETRY=FALSE \
           chromadb/chroma:latest

+      - name: Setup Weaviate
+        if: matrix.vector-io-provider == 'remote::weaviate'
+        run: |
+          docker run --rm -d --pull always \
+            --name weaviate \
+            -p 8080:8080 -p 50051:50051 \
+            cr.weaviate.io/semitechnologies/weaviate:1.32.0
+
       - name: Start PGVector DB
         if: matrix.vector-io-provider == 'remote::pgvector'
         run: |
@@ -93,6 +101,21 @@
           docker logs chromadb
           exit 1

+      - name: Wait for Weaviate to be ready
+        if: matrix.vector-io-provider == 'remote::weaviate'
+        run: |
+          echo "Waiting for Weaviate to be ready..."
+          for i in {1..30}; do
+            if curl -s http://localhost:8080 | grep -q "https://weaviate.io/developers/weaviate/current/"; then
+              echo "Weaviate is ready!"
+              exit 0
+            fi
+            sleep 2
+          done
+          echo "Weaviate failed to start"
+          docker logs weaviate
+          exit 1
+
       - name: Build Llama Stack
         run: |
           uv run llama stack build --template ci-tests --image-type venv
@@ -113,6 +136,10 @@
           PGVECTOR_DB: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
           PGVECTOR_USER: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
           PGVECTOR_PASSWORD: ${{ matrix.vector-io-provider == 'remote::pgvector' && 'llamastack' || '' }}
+          ENABLE_WEAVIATE: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'true' || '' }}
+          WEAVIATE_API_KEY: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'llamastack' || '' }}
+          WEAVIATE_CLUSTER_URL: ${{ matrix.vector-io-provider == 'remote::weaviate' && 'http://localhost:8080' || '' }}
+
         run: |
           uv run pytest -sv --stack-config="inference=inline::sentence-transformers,vector_io=${{ matrix.vector-io-provider }}" \
             tests/integration/vector_io \
diff --git a/docs/source/providers/vector_io/remote_weaviate.md b/docs/source/providers/vector_io/remote_weaviate.md
index d930515d5..793032ba9 100644
--- a/docs/source/providers/vector_io/remote_weaviate.md
+++ b/docs/source/providers/vector_io/remote_weaviate.md
@@ -33,9 +33,23 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
 
 See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
 
+## Configuration
+
+| Field | Type | Required | Default | Description |
+|-------|------|----------|---------|-------------|
+| `host` | `str \| None` | No | localhost | |
+| `port` | `int \| None` | No | 8080 | |
+| `weaviate_api_key` | `str \| None` | No | | The API key for the Weaviate instance |
+| `weaviate_cluster_url` | `str \| None` | No | | The URL of the Weaviate cluster |
+| `kvstore` | `utils.kvstore.config.RedisKVStoreConfig \| utils.kvstore.config.SqliteKVStoreConfig \| utils.kvstore.config.PostgresKVStoreConfig \| utils.kvstore.config.MongoDBKVStoreConfig, annotation=NoneType, required=False, default='sqlite', discriminator='type'` | No | | Config for KV store backend (SQLite only for now) |
+
 ## Sample Configuration
 
 ```yaml
+host: ${env.WEAVIATE_HOST:=localhost}
+port: ${env.WEAVIATE_PORT:=8080}
+weaviate_api_key: null
+weaviate_cluster_url: null
 kvstore:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/dummy}/weaviate_registry.db
diff --git a/llama_stack/apis/common/errors.py b/llama_stack/apis/common/errors.py
index 80f297bce..fb52dc772 100644
--- a/llama_stack/apis/common/errors.py
+++ b/llama_stack/apis/common/errors.py
@@ -11,3 +11,11 @@ class UnsupportedModelError(ValueError):
     def __init__(self, model_name: str, supported_models_list: list[str]):
         message = f"'{model_name}' model is not supported. Supported models are: {', '.join(supported_models_list)}"
         super().__init__(message)
+
+
+class ModelNotFoundError(ValueError):
+    """raised when Llama Stack cannot find a referenced model"""
+
+    def __init__(self, model_name: str) -> None:
+        message = f"Model '{model_name}' not found. Use client.models.list() to list available models."
+        super().__init__(message)
diff --git a/llama_stack/apis/vector_io/vector_io.py b/llama_stack/apis/vector_io/vector_io.py
index 853c4656c..d09a9f2f3 100644
--- a/llama_stack/apis/vector_io/vector_io.py
+++ b/llama_stack/apis/vector_io/vector_io.py
@@ -16,7 +16,7 @@ from pydantic import BaseModel, Field
 from llama_stack.apis.inference import InterleavedContent
 from llama_stack.apis.vector_dbs import VectorDB
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
-from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
+from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
 from llama_stack.schema_utils import json_schema_type, webmethod
 from llama_stack.strong_typing.schema import register_schema
diff --git a/llama_stack/distribution/routers/inference.py b/llama_stack/distribution/routers/inference.py
index c864b0eb0..6152acd57 100644
--- a/llama_stack/distribution/routers/inference.py
+++ b/llama_stack/distribution/routers/inference.py
@@ -17,6 +17,7 @@ from llama_stack.apis.common.content_types import (
     InterleavedContent,
     InterleavedContentItem,
 )
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.inference import (
     BatchChatCompletionResponse,
     BatchCompletionResponse,
@@ -188,7 +189,7 @@
             sampling_params = SamplingParams()
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         if tool_config:
@@ -317,7 +318,7 @@
         )
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model_id}' is an embedding model and does not support chat completions")
         provider = await self.routing_table.get_provider_impl(model_id)
@@ -390,7 +391,7 @@
         logger.debug(f"InferenceRouter.embeddings: {model_id}")
         model = await self.routing_table.get_model(model_id)
         if model is None:
-            raise ValueError(f"Model '{model_id}' not found")
+            raise ModelNotFoundError(model_id)
         if model.model_type == ModelType.llm:
             raise ValueError(f"Model '{model_id}' is an LLM model and does not support embeddings")
         provider = await self.routing_table.get_provider_impl(model_id)
@@ -430,7 +431,7 @@
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
 
@@ -491,7 +492,7 @@
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type == ModelType.embedding:
             raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
 
@@ -562,7 +563,7 @@
         )
         model_obj = await self.routing_table.get_model(model)
         if model_obj is None:
-            raise ValueError(f"Model '{model}' not found")
+            raise ModelNotFoundError(model)
         if model_obj.model_type != ModelType.embedding:
             raise ValueError(f"Model '{model}' is not an embedding model")
diff --git a/llama_stack/distribution/routing_tables/common.py b/llama_stack/distribution/routing_tables/common.py
index caf0780fd..a759ea8dd 100644
--- a/llama_stack/distribution/routing_tables/common.py
+++ b/llama_stack/distribution/routing_tables/common.py
@@ -6,6 +6,7 @@
 
 from typing import Any
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import Model
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import ScoringFn
@@ -257,7 +258,7 @@ async def lookup_model(routing_table: CommonRoutingTableImpl, model_id: str) ->
     models = await routing_table.get_all_with_type("model")
     matching_models = [m for m in models if m.provider_resource_id == model_id]
     if len(matching_models) == 0:
-        raise ValueError(f"Model '{model_id}' not found")
+        raise ModelNotFoundError(model_id)
 
     if len(matching_models) > 1:
         raise ValueError(f"Multiple providers found for '{model_id}': {[m.provider_id for m in matching_models]}")
diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py
index 3928307c6..ae1fe2882 100644
--- a/llama_stack/distribution/routing_tables/models.py
+++ b/llama_stack/distribution/routing_tables/models.py
@@ -7,6 +7,7 @@
 import time
 from typing import Any
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.distribution.datatypes import (
     ModelWithOwner,
@@ -111,7 +112,7 @@
     async def unregister_model(self, model_id: str) -> None:
         existing_model = await self.get_model(model_id)
         if existing_model is None:
-            raise ValueError(f"Model {model_id} not found")
+            raise ModelNotFoundError(model_id)
         await self.unregister_object(existing_model)
 
     async def update_registered_models(
diff --git a/llama_stack/distribution/routing_tables/vector_dbs.py b/llama_stack/distribution/routing_tables/vector_dbs.py
index 58ecf24da..eb4cd8cd9 100644
--- a/llama_stack/distribution/routing_tables/vector_dbs.py
+++ b/llama_stack/distribution/routing_tables/vector_dbs.py
@@ -8,6 +8,7 @@ from typing import Any
 
 from pydantic import TypeAdapter
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.apis.models import ModelType
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.vector_dbs import ListVectorDBsResponse, VectorDB, VectorDBs
@@ -63,7 +64,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
             raise ValueError("No provider available. Please configure a vector_io provider.")
         model = await lookup_model(self, embedding_model)
         if model is None:
-            raise ValueError(f"Model {embedding_model} not found")
+            raise ModelNotFoundError(embedding_model)
         if model.model_type != ModelType.embedding:
             raise ValueError(f"Model {embedding_model} is not an embedding model")
         if "embedding_dimension" not in model.metadata:
diff --git a/llama_stack/providers/remote/vector_io/milvus/milvus.py b/llama_stack/providers/remote/vector_io/milvus/milvus.py
index f1652a80e..634db0140 100644
--- a/llama_stack/providers/remote/vector_io/milvus/milvus.py
+++ b/llama_stack/providers/remote/vector_io/milvus/milvus.py
@@ -7,7 +7,6 @@
 import asyncio
 import logging
 import os
-import re
 from typing import Any
 
 from numpy.typing import NDArray
@@ -30,6 +29,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
+from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
 
 from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
 
@@ -43,14 +43,6 @@ OPENAI_VECTOR_STORES_FILES_PREFIX = f"openai_vector_stores_files:milvus:{VERSION
 OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_contents:milvus:{VERSION}::"
 
 
-def sanitize_collection_name(name: str) -> str:
-    """
-    Sanitize collection name to ensure it only contains numbers, letters, and underscores.
-    Any other characters are replaced with underscores.
-    """
-    return re.sub(r"[^a-zA-Z0-9_]", "_", name)
-
-
 class MilvusIndex(EmbeddingIndex):
     def __init__(
         self, client: MilvusClient, collection_name: str, consistency_level="Strong", kvstore: KVStore | None = None
diff --git a/llama_stack/providers/remote/vector_io/weaviate/__init__.py b/llama_stack/providers/remote/vector_io/weaviate/__init__.py
index 22e116c22..9272b21e2 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/__init__.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/__init__.py
@@ -12,6 +12,6 @@ from .config import WeaviateVectorIOConfig
 async def get_adapter_impl(config: WeaviateVectorIOConfig, deps: dict[Api, ProviderSpec]):
     from .weaviate import WeaviateVectorIOAdapter
 
-    impl = WeaviateVectorIOAdapter(config, deps[Api.inference])
+    impl = WeaviateVectorIOAdapter(config, deps[Api.inference], deps.get(Api.files, None))
     await impl.initialize()
     return impl
diff --git a/llama_stack/providers/remote/vector_io/weaviate/config.py b/llama_stack/providers/remote/vector_io/weaviate/config.py
index 4283b8d3b..b111e9032 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/config.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/config.py
@@ -12,18 +12,30 @@ from llama_stack.providers.utils.kvstore.config import (
     KVStoreConfig,
     SqliteKVStoreConfig,
 )
+from llama_stack.schema_utils import json_schema_type
 
 
-class WeaviateRequestProviderData(BaseModel):
-    weaviate_api_key: str
-    weaviate_cluster_url: str
+@json_schema_type
+class WeaviateVectorIOConfig(BaseModel):
+    host: str | None = Field(default="localhost")
+    port: int | None = Field(default=8080)
+    weaviate_api_key: str | None = Field(description="The API key for the Weaviate instance", default=None)
+    weaviate_cluster_url: str | None = Field(description="The URL of the Weaviate cluster", default=None)
     kvstore: KVStoreConfig | None = Field(description="Config for KV store backend (SQLite only for now)", default=None)
 
-
-class WeaviateVectorIOConfig(BaseModel):
     @classmethod
-    def sample_run_config(cls, __distro_dir__: str, **kwargs: Any) -> dict[str, Any]:
+    def sample_run_config(
+        cls,
+        __distro_dir__: str,
+        host: str = "${env.WEAVIATE_HOST:=localhost}",
+        port: int = "${env.WEAVIATE_PORT:=8080}",
+        **kwargs: Any,
+    ) -> dict[str, Any]:
         return {
+            "host": "${env.WEAVIATE_HOST:=localhost}",
+            "port": "${env.WEAVIATE_PORT:=8080}",
+            "weaviate_api_key": None,
+            "weaviate_cluster_url": None,
             "kvstore": SqliteKVStoreConfig.sample_run_config(
                 __distro_dir__=__distro_dir__,
                 db_name="weaviate_registry.db",
diff --git a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
index 543835e20..b61b63d8c 100644
--- a/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
+++ b/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
@@ -21,12 +21,16 @@ from llama_stack.distribution.request_headers import NeedsRequestProviderData
 from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.api import KVStore
+from llama_stack.providers.utils.memory.openai_vector_store_mixin import (
+    OpenAIVectorStoreMixin,
+)
 from llama_stack.providers.utils.memory.vector_store import (
     EmbeddingIndex,
     VectorDBWithIndex,
 )
+from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collection_name
 
-from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
+from .config import WeaviateVectorIOConfig
 
 log = logging.getLogger(__name__)
 
@@ -39,11 +43,19 @@ OPENAI_VECTOR_STORES_FILES_CONTENTS_PREFIX = f"openai_vector_stores_files_conten
 
 
 class WeaviateIndex(EmbeddingIndex):
-    def __init__(self, client: weaviate.Client, collection_name: str, kvstore: KVStore | None = None):
+    def __init__(
+        self,
+        client: weaviate.Client,
+        collection_name: str,
+        kvstore: KVStore | None = None,
+    ):
         self.client = client
-        self.collection_name = collection_name
+        self.collection_name = sanitize_collection_name(collection_name, weaviate_format=True)
         self.kvstore = kvstore
 
+    async def initialize(self):
+        pass
+
     async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
         assert len(chunks) == len(embeddings), (
             f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
@@ -67,10 +79,13 @@
         collection.data.insert_many(data_objects)
 
     async def delete_chunk(self, chunk_id: str) -> None:
-        raise NotImplementedError("delete_chunk is not supported in Chroma")
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
+        collection.data.delete_many(where=Filter.by_property("id").contains_any([chunk_id]))
 
     async def query_vector(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
-        collection = self.client.collections.get(self.collection_name)
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        collection = self.client.collections.get(sanitized_collection_name)
 
         results = collection.query.near_vector(
             near_vector=embedding.tolist(),
@@ -94,8 +109,17 @@
 
         return QueryChunksResponse(chunks=chunks, scores=scores)
 
-    async def delete(self, chunk_ids: list[str]) -> None:
-        collection = self.client.collections.get(self.collection_name)
+    async def delete(self, chunk_ids: list[str] | None = None) -> None:
+        """
+        Delete chunks by IDs if provided, otherwise drop the entire collection.
+        """
+        sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
+        if chunk_ids is None:
+            # Drop entire collection if it exists
+            if self.client.collections.exists(sanitized_collection_name):
+                self.client.collections.delete(sanitized_collection_name)
+            return
+        collection = self.client.collections.get(sanitized_collection_name)
         collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
 
     async def query_keyword(
@@ -119,6 +143,7 @@ class WeaviateIndex(EmbeddingIndex):
 
 
 class WeaviateVectorIOAdapter(
+    OpenAIVectorStoreMixin,
     VectorIO,
     NeedsRequestProviderData,
     VectorDBsProtocolPrivate,
@@ -140,42 +165,53 @@ class WeaviateVectorIOAdapter(
         self.metadata_collection_name = "openai_vector_stores_metadata"
 
     def _get_client(self) -> weaviate.Client:
-        provider_data = self.get_request_provider_data()
-        assert provider_data is not None, "Request provider data must be set"
-        assert isinstance(provider_data, WeaviateRequestProviderData)
-
-        key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
-        if key in self.client_cache:
-            return self.client_cache[key]
-
-        client = weaviate.connect_to_weaviate_cloud(
-            cluster_url=provider_data.weaviate_cluster_url,
-            auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
-        )
+        if self.config.weaviate_cluster_url is None:
+            key = "local_test"
+            client = weaviate.connect_to_local(
+                host=self.config.host,
+                port=self.config.port,
+            )
+        else:
+            key = f"{self.config.weaviate_cluster_url}::{self.config.weaviate_api_key}"
+            if key in self.client_cache:
+                return self.client_cache[key]
+            client = weaviate.connect_to_weaviate_cloud(
+                cluster_url=self.config.weaviate_cluster_url,
+                auth_credentials=Auth.api_key(self.config.weaviate_api_key),
+            )
         self.client_cache[key] = client
         return client
 
     async def initialize(self) -> None:
         """Set up KV store and load existing vector DBs and OpenAI vector stores."""
-        # Initialize KV store for metadata
-        self.kvstore = await kvstore_impl(self.config.kvstore)
+        # Initialize KV store for metadata if configured
+        if self.config.kvstore is not None:
+            self.kvstore = await kvstore_impl(self.config.kvstore)
+        else:
+            self.kvstore = None
+            log.info("No kvstore configured, registry will not persist across restarts")
 
         # Load existing vector DB definitions
-        start_key = VECTOR_DBS_PREFIX
-        end_key = f"{VECTOR_DBS_PREFIX}\xff"
-        stored = await self.kvstore.values_in_range(start_key, end_key)
-        for raw in stored:
-            vector_db = VectorDB.model_validate_json(raw)
-            client = self._get_client()
-            idx = WeaviateIndex(client=client, collection_name=vector_db.identifier, kvstore=self.kvstore)
-            self.cache[vector_db.identifier] = VectorDBWithIndex(
-                vector_db=vector_db,
-                index=idx,
-                inference_api=self.inference_api,
-            )
+        if self.kvstore is not None:
+            start_key = VECTOR_DBS_PREFIX
+            end_key = f"{VECTOR_DBS_PREFIX}\xff"
+            stored = await self.kvstore.values_in_range(start_key, end_key)
+            for raw in stored:
+                vector_db = VectorDB.model_validate_json(raw)
+                client = self._get_client()
+                idx = WeaviateIndex(
+                    client=client,
+                    collection_name=vector_db.identifier,
+                    kvstore=self.kvstore,
+                )
+                self.cache[vector_db.identifier] = VectorDBWithIndex(
+                    vector_db=vector_db,
+                    index=idx,
+                    inference_api=self.inference_api,
+                )
 
-        # Load OpenAI vector stores metadata into cache
-        await self.initialize_openai_vector_stores()
+            # Load OpenAI vector stores metadata into cache
+            await self.initialize_openai_vector_stores()
 
     async def shutdown(self) -> None:
         for client in self.client_cache.values():
@@ -186,11 +222,11 @@ class WeaviateVectorIOAdapter(
         vector_db: VectorDB,
     ) -> None:
         client = self._get_client()
-
+        sanitized_collection_name = sanitize_collection_name(vector_db.identifier, weaviate_format=True)
         # Create collection if it doesn't exist
-        if not client.collections.exists(vector_db.identifier):
+        if not client.collections.exists(sanitized_collection_name):
             client.collections.create(
-                name=vector_db.identifier,
+                name=sanitized_collection_name,
                 vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                 properties=[
                     wvc.config.Property(
@@ -200,30 +236,41 @@ class WeaviateVectorIOAdapter(
                 ],
             )
 
-        self.cache[vector_db.identifier] = VectorDBWithIndex(
+        self.cache[sanitized_collection_name] = VectorDBWithIndex(
             vector_db,
-            WeaviateIndex(client=client, collection_name=vector_db.identifier),
+            WeaviateIndex(client=client, collection_name=sanitized_collection_name),
             self.inference_api,
         )
 
-    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
-        if vector_db_id in self.cache:
-            return self.cache[vector_db_id]
+    async def unregister_vector_db(self, vector_db_id: str) -> None:
+        client = self._get_client()
+        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
+        if sanitized_collection_name not in self.cache or client.collections.exists(sanitized_collection_name) is False:
+            log.warning(f"Vector DB {sanitized_collection_name} not found")
+            return
+        client.collections.delete(sanitized_collection_name)
+        await self.cache[sanitized_collection_name].index.delete()
+        del self.cache[sanitized_collection_name]
 
-        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
+    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
+        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
+        if sanitized_collection_name in self.cache:
+            return self.cache[sanitized_collection_name]
+
+        vector_db = await self.vector_db_store.get_vector_db(sanitized_collection_name)
         if not vector_db:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
+            raise ValueError(f"Vector DB {sanitized_collection_name} not found")
 
         client = self._get_client()
         if not client.collections.exists(vector_db.identifier):
-            raise ValueError(f"Collection with name `{vector_db.identifier}` not found")
+            raise ValueError(f"Collection with name `{sanitized_collection_name}` not found")
 
         index = VectorDBWithIndex(
             vector_db=vector_db,
-            index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
+            index=WeaviateIndex(client=client, collection_name=sanitized_collection_name),
             inference_api=self.inference_api,
         )
-        self.cache[vector_db_id] = index
+        self.cache[sanitized_collection_name] = index
         return index
 
     async def insert_chunks(
@@ -232,9 +279,10 @@ class WeaviateVectorIOAdapter(
         chunks: list[Chunk],
         ttl_seconds: int | None = None,
     ) -> None:
-        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
+        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
         if not index:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
+            raise ValueError(f"Vector DB {sanitized_collection_name} not found")
 
         await index.insert_chunks(chunks)
 
@@ -244,29 +292,17 @@ class WeaviateVectorIOAdapter(
         query: InterleavedContent,
         params: dict[str, Any] | None = None,
     ) -> QueryChunksResponse:
-        index = await self._get_and_cache_vector_db_index(vector_db_id)
+        sanitized_collection_name = sanitize_collection_name(vector_db_id, weaviate_format=True)
+        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
         if not index:
-            raise ValueError(f"Vector DB {vector_db_id} not found")
+            raise ValueError(f"Vector DB {sanitized_collection_name} not found")
 
         return await index.query_chunks(query, params)
 
-    # OpenAI Vector Stores File operations are not supported in Weaviate
-    async def _save_openai_vector_store_file(
-        self, store_id: str, file_id: str, file_info: dict[str, Any], file_contents: list[dict[str, Any]]
-    ) -> None:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
-
-    async def _load_openai_vector_store_file(self, store_id: str, file_id: str) -> dict[str, Any]:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
-
-    async def _load_openai_vector_store_file_contents(self, store_id: str, file_id: str) -> list[dict[str, Any]]:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
-
-    async def _update_openai_vector_store_file(self, store_id: str, file_id: str, file_info: dict[str, Any]) -> None:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
-
-    async def _delete_openai_vector_store_file_from_storage(self, store_id: str, file_id: str) -> None:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
-
     async def delete_chunks(self, store_id: str, chunk_ids: list[str]) -> None:
-        raise NotImplementedError("OpenAI Vector Stores API is not supported in Weaviate")
+        sanitized_collection_name = sanitize_collection_name(store_id, weaviate_format=True)
+        index = await self._get_and_cache_vector_db_index(sanitized_collection_name)
+        if not index:
+            raise ValueError(f"Vector DB {sanitized_collection_name} not found")
+
+        await index.delete(chunk_ids)
diff --git a/llama_stack/providers/utils/memory/vector_store.py b/llama_stack/providers/utils/memory/vector_store.py
index 4a8749cba..484475e9d 100644
--- a/llama_stack/providers/utils/memory/vector_store.py
+++ b/llama_stack/providers/utils/memory/vector_store.py
@@ -30,7 +30,7 @@ from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
-from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
+from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
 
 log = logging.getLogger(__name__)
diff --git a/llama_stack/providers/utils/vector_io/chunk_utils.py b/llama_stack/providers/utils/vector_io/vector_utils.py
similarity index 58%
rename from llama_stack/providers/utils/vector_io/chunk_utils.py
rename to llama_stack/providers/utils/vector_io/vector_utils.py
index 01afa6ec8..f2888043e 100644
--- a/llama_stack/providers/utils/vector_io/chunk_utils.py
+++ b/llama_stack/providers/utils/vector_io/vector_utils.py
@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 
 import hashlib
+import re
 import uuid
 
@@ -19,3 +20,20 @@ def generate_chunk_id(document_id: str, chunk_text: str, chunk_window: str | Non
     if chunk_window:
         hash_input += f":{chunk_window}".encode()
     return str(uuid.UUID(hashlib.md5(hash_input, usedforsecurity=False).hexdigest()))
+
+
+def proper_case(s: str) -> str:
+    """Convert a string to proper case (first letter uppercase, rest lowercase)."""
+    return s[0].upper() + s[1:].lower() if s else s
+
+
+def sanitize_collection_name(name: str, weaviate_format=False) -> str:
+    """
+    Sanitize collection name to ensure it only contains numbers, letters, and underscores.
+    Any other characters are replaced with underscores.
+    """
+    if not weaviate_format:
+        s = re.sub(r"[^a-zA-Z0-9_]", "_", name)
+    else:
+        s = proper_case(re.sub(r"[^a-zA-Z0-9]", "", name))
+    return s
diff --git a/scripts/generate_prompt_format.py b/scripts/generate_prompt_format.py
index 5598e35f6..855033f95 100755
--- a/scripts/generate_prompt_format.py
+++ b/scripts/generate_prompt_format.py
@@ -15,6 +15,7 @@ from pathlib import Path
 
 import fire
 
+from llama_stack.apis.common.errors import ModelNotFoundError
 from llama_stack.models.llama.llama3.generation import Llama3
 from llama_stack.models.llama.llama4.generation import Llama4
 from llama_stack.models.llama.sku_list import resolve_model
@@ -34,7 +35,7 @@ def run_main(
     llama_model = resolve_model(model_id)
     if not llama_model:
-        raise ValueError(f"Model {model_id} not found")
+        raise ModelNotFoundError(model_id)
 
     cls = Llama4 if llama4 else Llama3
     generator = cls.build(
diff --git a/tests/integration/vector_io/test_openai_vector_stores.py b/tests/integration/vector_io/test_openai_vector_stores.py
index a34c5b410..c314de0ae 100644
--- a/tests/integration/vector_io/test_openai_vector_stores.py
+++ b/tests/integration/vector_io/test_openai_vector_stores.py
@@ -29,6 +29,7 @@ def skip_if_provider_doesnt_support_openai_vector_stores(client_with_models):
         "inline::chromadb",
         "remote::pgvector",
         "remote::chromadb",
+        "remote::weaviate",
     ]:
         return
 
@@ -109,11 +110,11 @@ def test_openai_create_vector_store(compat_client_with_empty_stores, client_with
 
     # Create a vector store
     vector_store = client.vector_stores.create(
-        name="test_vector_store", metadata={"purpose": "testing", "environment": "integration"}
+        name="Vs_test_vector_store", metadata={"purpose": "testing", "environment": "integration"}
     )
 
     assert vector_store is not None
-    assert vector_store.name == "test_vector_store"
+    assert vector_store.name == "Vs_test_vector_store"
     assert vector_store.object == "vector_store"
     assert vector_store.status in ["completed", "in_progress"]
     assert vector_store.metadata["purpose"] == "testing"
diff --git a/tests/unit/providers/vector_io/test_chunk_utils.py b/tests/unit/providers/vector_io/test_vector_utils.py
similarity index 97%
rename from tests/unit/providers/vector_io/test_chunk_utils.py
rename to tests/unit/providers/vector_io/test_vector_utils.py
index 535b76d73..a5d803a82 100644
--- a/tests/unit/providers/vector_io/test_chunk_utils.py
+++ b/tests/unit/providers/vector_io/test_vector_utils.py
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.vector_io import Chunk, ChunkMetadata
-from llama_stack.providers.utils.vector_io.chunk_utils import generate_chunk_id
+from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
 
 # This test is a unit test for the chunk_utils.py helpers. This should only contain
 # tests which are specific to this file. More general (API-level) tests should be placed in