mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
Pin weaviate-client version
This commit is contained in:
parent
980c7c244d
commit
920c2a3e12
5 changed files with 18 additions and 25 deletions
|
@ -500,7 +500,7 @@ See [PGVector's documentation](https://github.com/pgvector/pgvector) for more de
|
|||
api=Api.vector_io,
|
||||
adapter_type="weaviate",
|
||||
provider_type="remote::weaviate",
|
||||
pip_packages=["weaviate-client"],
|
||||
pip_packages=["weaviate-client>=4.16.5"],
|
||||
module="llama_stack.providers.remote.vector_io.weaviate",
|
||||
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
|
||||
|
|
|
@ -96,9 +96,9 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
k: Limit of number of results to return
|
||||
score_threshold: Minimum similarity score threshold
|
||||
Returns:
|
||||
QueryChunksResponse with chunks and scores
|
||||
QueryChunksResponse with chunks and scores.
|
||||
"""
|
||||
log.info(
|
||||
log.debug(
|
||||
f"WEAVIATE VECTOR SEARCH CALLED: embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}"
|
||||
)
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
|
@ -135,7 +135,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
log.info(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
log.debug(f"WEAVIATE VECTOR SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def delete(self, chunk_ids: list[str] | None = None) -> None:
|
||||
|
@ -166,7 +166,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
Returns:
|
||||
QueryChunksResponse with chunks and scores
|
||||
"""
|
||||
log.info(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
|
||||
log.debug(f"WEAVIATE KEYWORD SEARCH CALLED: query='{query_string}', k={k}, threshold={score_threshold}")
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
collection = self.client.collections.get(sanitized_collection_name)
|
||||
|
||||
|
@ -199,7 +199,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
log.info(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
|
||||
log.debug(f"WEAVIATE KEYWORD SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}.")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
async def query_hybrid(
|
||||
|
@ -223,7 +223,7 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
Returns:
|
||||
QueryChunksResponse with combined results
|
||||
"""
|
||||
log.info(
|
||||
log.debug(
|
||||
f"WEAVIATE HYBRID SEARCH CALLED: query='{query_string}', embedding_shape={embedding.shape}, k={k}, threshold={score_threshold}, reranker={reranker_type}"
|
||||
)
|
||||
sanitized_collection_name = sanitize_collection_name(self.collection_name, weaviate_format=True)
|
||||
|
@ -265,11 +265,10 @@ class WeaviateIndex(EmbeddingIndex):
|
|||
if score < score_threshold:
|
||||
continue
|
||||
|
||||
log.info(f"Document {chunk.metadata.get('document_id')} has score {score}")
|
||||
chunks.append(chunk)
|
||||
scores.append(score)
|
||||
|
||||
log.info(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
log.debug(f"WEAVIATE HYBRID SEARCH RESULTS: Found {len(chunks)} chunks with scores {scores}")
|
||||
return QueryChunksResponse(chunks=chunks, scores=scores)
|
||||
|
||||
|
||||
|
@ -297,7 +296,7 @@ class WeaviateVectorIOAdapter(
|
|||
|
||||
def _get_client(self) -> weaviate.WeaviateClient:
|
||||
if "localhost" in self.config.weaviate_cluster_url:
|
||||
log.info("using Weaviate locally in container")
|
||||
log.info("Using Weaviate locally in container")
|
||||
host, port = self.config.weaviate_cluster_url.split(":")
|
||||
key = "local_test"
|
||||
client = weaviate.connect_to_local(
|
||||
|
|
|
@ -49,7 +49,6 @@ dependencies = [
|
|||
"opentelemetry-exporter-otlp-proto-http>=1.30.0", # server
|
||||
"aiosqlite>=0.21.0", # server - for metadata store
|
||||
"asyncpg", # for metadata store
|
||||
"weaviate-client>=4.16.5",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
@ -23,13 +23,13 @@ pymilvus_mock.AnnSearchRequest = MagicMock
|
|||
with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
|
||||
from llama_stack.providers.remote.vector_io.milvus.milvus import MilvusIndex
|
||||
|
||||
# This test is a unit test for the MilvusIndex class. This should only contain
|
||||
# This test is a unit test for the MilvusVectorIOAdapter class. This should only contain
|
||||
# tests which are specific to this class. More general (API-level) tests should be placed in
|
||||
# tests/integration/vector_io/
|
||||
#
|
||||
# How to run this test:
|
||||
#
|
||||
# pytest tests/unit/providers/vector_io/remote/test_milvus.py \
|
||||
# pytest tests/unit/providers/vector_io/test_milvus.py \
|
||||
# -v -s --tb=short --disable-warnings --asyncio-mode=auto
|
||||
|
||||
MILVUS_PROVIDER = "milvus"
|
||||
|
@ -324,6 +324,3 @@ async def test_query_hybrid_search_default_rrf(
|
|||
call_args = mock_milvus_client.hybrid_search.call_args
|
||||
ranker = call_args[1]["ranker"]
|
||||
assert ranker is not None
|
||||
|
||||
|
||||
# TODO: Write tests for the MilvusVectorIOAdapter class.
|
||||
|
|
2
uv.lock
generated
2
uv.lock
generated
|
@ -1777,7 +1777,6 @@ dependencies = [
|
|||
{ name = "termcolor" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "uvicorn" },
|
||||
{ name = "weaviate-client" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
|
@ -1905,7 +1904,6 @@ requires-dist = [
|
|||
{ name = "termcolor" },
|
||||
{ name = "tiktoken" },
|
||||
{ name = "uvicorn", specifier = ">=0.34.0" },
|
||||
{ name = "weaviate-client", specifier = ">=4.16.5" },
|
||||
]
|
||||
provides-extras = ["ui"]
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue