mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-12 12:06:04 +00:00
fixes
This commit is contained in:
parent
016388b788
commit
27b3d9d223
2 changed files with 9 additions and 12 deletions
|
|
@ -6,15 +6,12 @@
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from pydantic import TypeAdapter
|
|
||||||
|
|
||||||
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
from llama_stack.apis.common.errors import ModelNotFoundError, ModelTypeError
|
||||||
from llama_stack.apis.models import ModelType
|
from llama_stack.apis.models import ModelType
|
||||||
from llama_stack.apis.resource import ResourceType
|
from llama_stack.apis.resource import ResourceType
|
||||||
|
|
||||||
# Removed VectorStores import to avoid exposing public API
|
# Removed VectorStores import to avoid exposing public API
|
||||||
from llama_stack.apis.vector_io.vector_io import (
|
from llama_stack.apis.vector_io.vector_io import (
|
||||||
OpenAICreateVectorStoreRequestWithExtraBody,
|
|
||||||
SearchRankingOptions,
|
SearchRankingOptions,
|
||||||
VectorStoreChunkingStrategy,
|
VectorStoreChunkingStrategy,
|
||||||
VectorStoreDeleteResponse,
|
VectorStoreDeleteResponse,
|
||||||
|
|
|
||||||
|
|
@ -23,14 +23,14 @@ class TestRagQuery:
|
||||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||||
)
|
)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await rag_tool.query(content=MagicMock(), vector_store_ids=[])
|
await rag_tool.query(content=MagicMock(), vector_db_ids=[])
|
||||||
|
|
||||||
async def test_query_chunk_metadata_handling(self):
|
async def test_query_chunk_metadata_handling(self):
|
||||||
rag_tool = MemoryToolRuntimeImpl(
|
rag_tool = MemoryToolRuntimeImpl(
|
||||||
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
config=MagicMock(), vector_io_api=MagicMock(), inference_api=MagicMock(), files_api=MagicMock()
|
||||||
)
|
)
|
||||||
content = "test query content"
|
content = "test query content"
|
||||||
vector_store_ids = ["db1"]
|
vector_db_ids = ["db1"]
|
||||||
|
|
||||||
chunk_metadata = ChunkMetadata(
|
chunk_metadata = ChunkMetadata(
|
||||||
document_id="doc1",
|
document_id="doc1",
|
||||||
|
|
@ -55,7 +55,7 @@ class TestRagQuery:
|
||||||
query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
|
query_response = QueryChunksResponse(chunks=[chunk], scores=[1.0])
|
||||||
|
|
||||||
rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
|
rag_tool.vector_io_api.query_chunks = AsyncMock(return_value=query_response)
|
||||||
result = await rag_tool.query(content=content, vector_store_ids=vector_store_ids)
|
result = await rag_tool.query(content=content, vector_db_ids=vector_db_ids)
|
||||||
|
|
||||||
assert result is not None
|
assert result is not None
|
||||||
expected_metadata_string = (
|
expected_metadata_string = (
|
||||||
|
|
@ -90,7 +90,7 @@ class TestRagQuery:
|
||||||
files_api=MagicMock(),
|
files_api=MagicMock(),
|
||||||
)
|
)
|
||||||
|
|
||||||
vector_store_ids = ["db1", "db2"]
|
vector_db_ids = ["db1", "db2"]
|
||||||
|
|
||||||
# Fake chunks from each DB
|
# Fake chunks from each DB
|
||||||
chunk_metadata1 = ChunkMetadata(
|
chunk_metadata1 = ChunkMetadata(
|
||||||
|
|
@ -101,7 +101,7 @@ class TestRagQuery:
|
||||||
)
|
)
|
||||||
chunk1 = Chunk(
|
chunk1 = Chunk(
|
||||||
content="chunk from db1",
|
content="chunk from db1",
|
||||||
metadata={"vector_store_id": "db1", "document_id": "doc1"},
|
metadata={"vector_db_id": "db1", "document_id": "doc1"},
|
||||||
stored_chunk_id="c1",
|
stored_chunk_id="c1",
|
||||||
chunk_metadata=chunk_metadata1,
|
chunk_metadata=chunk_metadata1,
|
||||||
)
|
)
|
||||||
|
|
@ -114,7 +114,7 @@ class TestRagQuery:
|
||||||
)
|
)
|
||||||
chunk2 = Chunk(
|
chunk2 = Chunk(
|
||||||
content="chunk from db2",
|
content="chunk from db2",
|
||||||
metadata={"vector_store_id": "db2", "document_id": "doc2"},
|
metadata={"vector_db_id": "db2", "document_id": "doc2"},
|
||||||
stored_chunk_id="c2",
|
stored_chunk_id="c2",
|
||||||
chunk_metadata=chunk_metadata2,
|
chunk_metadata=chunk_metadata2,
|
||||||
)
|
)
|
||||||
|
|
@ -126,13 +126,13 @@ class TestRagQuery:
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
|
|
||||||
result = await rag_tool.query(content="test", vector_store_ids=vector_store_ids)
|
result = await rag_tool.query(content="test", vector_db_ids=vector_db_ids)
|
||||||
returned_chunks = result.metadata["chunks"]
|
returned_chunks = result.metadata["chunks"]
|
||||||
returned_scores = result.metadata["scores"]
|
returned_scores = result.metadata["scores"]
|
||||||
returned_doc_ids = result.metadata["document_ids"]
|
returned_doc_ids = result.metadata["document_ids"]
|
||||||
returned_vector_store_ids = result.metadata["vector_store_ids"]
|
returned_vector_db_ids = result.metadata["vector_db_ids"]
|
||||||
|
|
||||||
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
assert returned_chunks == ["chunk from db1", "chunk from db2"]
|
||||||
assert returned_scores == (0.9, 0.8)
|
assert returned_scores == (0.9, 0.8)
|
||||||
assert returned_doc_ids == ["doc1", "doc2"]
|
assert returned_doc_ids == ["doc1", "doc2"]
|
||||||
assert returned_vector_store_ids == ["db1", "db2"]
|
assert returned_vector_db_ids == ["db1", "db2"]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue