llama-stack/llama_stack/providers/remote/vector_io/weaviate/weaviate.py
Ihar Hrachyshka 9e6561a1ec
chore: enable pyupgrade fixes (#1806)
# What does this PR do?

The goal of this PR is code base modernization.

Schema reflection code needed a minor adjustment to handle UnionTypes
and collections.abc.AsyncIterator. (Both are preferred for latest Python
releases.)

Note to reviewers: almost all changes here are automatically generated
by pyupgrade. Some additional unused imports were cleaned up. The only
change worth noting can be found under `docs/openapi_generator` and
`llama_stack/strong_typing/schema.py`, where reflection code was updated
to deal with "newer" types.

Signed-off-by: Ihar Hrachyshka <ihar.hrachyshka@gmail.com>
2025-05-01 14:23:50 -07:00

189 lines
6.5 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import Any
import weaviate
import weaviate.classes as wvc
from numpy.typing import NDArray
from weaviate.classes.init import Auth
from weaviate.classes.query import Filter
from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
from llama_stack.distribution.request_headers import NeedsRequestProviderData
from llama_stack.providers.datatypes import Api, VectorDBsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import (
EmbeddingIndex,
VectorDBWithIndex,
)
from .config import WeaviateRequestProviderData, WeaviateVectorIOConfig
# Module-level logger, named after this module.
log = logging.getLogger(__name__)
class WeaviateIndex(EmbeddingIndex):
    """Embedding index backed by a single Weaviate collection.

    Each chunk is stored as one Weaviate object whose ``chunk_content``
    property holds the chunk serialized as JSON and whose vector is the
    chunk's embedding.
    """

    def __init__(self, client: weaviate.Client, collection_name: str):
        """Bind the index to ``collection_name`` on an already-connected ``client``."""
        self.client = client
        self.collection_name = collection_name

    async def add_chunks(self, chunks: list[Chunk], embeddings: NDArray):
        """Insert ``chunks`` together with their ``embeddings`` into the collection.

        ``embeddings[i]`` is stored as the vector of ``chunks[i]``. The two
        sequences must have the same length; an empty input is a no-op.
        """
        assert len(chunks) == len(embeddings), (
            f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
        )
        # Nothing to insert: avoid issuing an empty batch request.
        if not chunks:
            return

        data_objects = [
            wvc.data.DataObject(
                properties={
                    "chunk_content": chunk.model_dump_json(),
                },
                vector=embedding.tolist(),
            )
            for chunk, embedding in zip(chunks, embeddings, strict=True)
        ]

        # Inserting chunks into a prespecified Weaviate collection
        collection = self.client.collections.get(self.collection_name)

        # TODO: make this async friendly
        collection.data.insert_many(data_objects)

    async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
        """Return up to ``k`` stored chunks nearest to ``embedding``.

        The score of a hit is the reciprocal of the distance reported by
        Weaviate; a distance of zero (exact match) yields ``float("inf")``.
        Hits scoring below ``score_threshold`` are dropped, as are objects
        whose ``chunk_content`` cannot be parsed back into a ``Chunk``
        (the parse failure is logged).
        """
        collection = self.client.collections.get(self.collection_name)

        results = collection.query.near_vector(
            near_vector=embedding.tolist(),
            limit=k,
            return_metadata=wvc.query.MetadataQuery(distance=True),
        )

        chunks = []
        scores = []
        for doc in results.objects:
            chunk_json = doc.properties["chunk_content"]
            try:
                chunk_dict = json.loads(chunk_json)
                chunk = Chunk(**chunk_dict)
            except Exception:
                log.exception(f"Failed to parse document: {chunk_json}")
                continue

            distance = doc.metadata.distance
            # An exact match has distance 0; guard against ZeroDivisionError.
            score = 1.0 / distance if distance != 0 else float("inf")
            if score < score_threshold:
                continue

            chunks.append(chunk)
            scores.append(score)

        return QueryChunksResponse(chunks=chunks, scores=scores)

    async def delete(self, chunk_ids: list[str]) -> None:
        """Delete the objects whose ``id`` property matches any of ``chunk_ids``.

        An empty ``chunk_ids`` list is a no-op.
        """
        if not chunk_ids:
            return
        collection = self.client.collections.get(self.collection_name)
        # NOTE(review): the filter targets an "id" property, while add_chunks
        # only writes "chunk_content" - confirm this filter can match anything.
        collection.data.delete_many(where=Filter.by_property("id").contains_any(chunk_ids))
class WeaviateVectorIOAdapter(
    VectorIO,
    NeedsRequestProviderData,
    VectorDBsProtocolPrivate,
):
    """Vector IO provider that stores each vector DB in a Weaviate collection.

    Connection details (cluster URL and API key) are read from the
    per-request provider data, and one Weaviate client is cached per
    distinct URL/key pair.
    """

    def __init__(self, config: WeaviateVectorIOConfig, inference_api: Api.inference) -> None:
        """Store the configuration and inference API and start with empty caches."""
        self.config = config
        self.inference_api = inference_api
        # Weaviate clients keyed by "<cluster_url>::<api_key>".
        self.client_cache = {}
        # VectorDBWithIndex instances keyed by vector DB identifier.
        # NOTE(review): entries are keyed by identifier only, so an index keeps
        # the client of whichever request created it - confirm that is intended.
        self.cache = {}

    def _get_client(self) -> weaviate.Client:
        """Return a Weaviate Cloud client for the current request's credentials.

        A client is created on first use for a given cluster URL / API key
        pair and reused afterwards.
        """
        provider_data = self.get_request_provider_data()
        assert provider_data is not None, "Request provider data must be set"
        assert isinstance(provider_data, WeaviateRequestProviderData)

        key = f"{provider_data.weaviate_cluster_url}::{provider_data.weaviate_api_key}"
        if key in self.client_cache:
            return self.client_cache[key]

        client = weaviate.connect_to_weaviate_cloud(
            cluster_url=provider_data.weaviate_cluster_url,
            auth_credentials=Auth.api_key(provider_data.weaviate_api_key),
        )
        self.client_cache[key] = client
        return client

    async def initialize(self) -> None:
        """No start-up work; clients are created lazily per request."""
        pass

    async def shutdown(self) -> None:
        """Close every cached Weaviate client and drop all cached state.

        A failure to close one client is logged and does not prevent the
        remaining clients from being closed.
        """
        clients = list(self.client_cache.values())
        # Drop the caches first so closed clients (and the indexes holding
        # them) can never be handed out again.
        self.client_cache.clear()
        self.cache.clear()
        for client in clients:
            try:
                client.close()
            except Exception:
                log.exception("Failed to close Weaviate client")

    async def register_vector_db(
        self,
        vector_db: VectorDB,
    ) -> None:
        """Ensure a collection exists for ``vector_db`` and cache an index for it.

        The collection is named after ``vector_db.identifier``, uses no
        server-side vectorizer, and has a single text property
        ``chunk_content``.
        """
        client = self._get_client()

        # Create collection if it doesn't exist
        if not client.collections.exists(vector_db.identifier):
            client.collections.create(
                name=vector_db.identifier,
                vectorizer_config=wvc.config.Configure.Vectorizer.none(),
                properties=[
                    wvc.config.Property(
                        name="chunk_content",
                        data_type=wvc.config.DataType.TEXT,
                    ),
                ],
            )

        self.cache[vector_db.identifier] = VectorDBWithIndex(
            vector_db,
            WeaviateIndex(client=client, collection_name=vector_db.identifier),
            self.inference_api,
        )

    async def _get_and_cache_vector_db_index(self, vector_db_id: str) -> VectorDBWithIndex | None:
        """Return the cached index for ``vector_db_id``, building it on a miss.

        Raises:
            ValueError: if the vector DB is not known to ``vector_db_store`` or
                its Weaviate collection does not exist.
        """
        if vector_db_id in self.cache:
            return self.cache[vector_db_id]

        vector_db = await self.vector_db_store.get_vector_db(vector_db_id)
        if not vector_db:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        client = self._get_client()
        if not client.collections.exists(vector_db.identifier):
            raise ValueError(f"Collection with name `{vector_db.identifier}` not found")

        index = VectorDBWithIndex(
            vector_db=vector_db,
            index=WeaviateIndex(client=client, collection_name=vector_db.identifier),
            inference_api=self.inference_api,
        )
        self.cache[vector_db_id] = index
        return index

    async def insert_chunks(
        self,
        vector_db_id: str,
        chunks: list[Chunk],
        ttl_seconds: int | None = None,
    ) -> None:
        """Insert ``chunks`` into the vector DB ``vector_db_id``.

        ``ttl_seconds`` is accepted but not used by this provider.

        Raises:
            ValueError: if the vector DB cannot be resolved.
        """
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        # Defensive: the helper's signature allows None even though it raises.
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        await index.insert_chunks(chunks)

    async def query_chunks(
        self,
        vector_db_id: str,
        query: InterleavedContent,
        params: dict[str, Any] | None = None,
    ) -> QueryChunksResponse:
        """Run ``query`` against the vector DB ``vector_db_id`` and return the matches.

        Raises:
            ValueError: if the vector DB cannot be resolved.
        """
        index = await self._get_and_cache_vector_db_index(vector_db_id)
        # Defensive: the helper's signature allows None even though it raises.
        if not index:
            raise ValueError(f"Vector DB {vector_db_id} not found")

        return await index.query_chunks(query, params)