mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-05 18:27:22 +00:00
Merge branch 'main' into add-mongodb-vector_io
This commit is contained in:
commit
d0064fc915
426 changed files with 99110 additions and 62778 deletions
|
|
@ -192,18 +192,18 @@ async def test_create_agent_session_persistence(agents_impl, sample_agent_config
|
|||
assert session_response.session_id is not None
|
||||
|
||||
# Verify the session was stored
|
||||
session = await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
||||
session = await agents_impl.get_agents_session(session_response.session_id, agent_id)
|
||||
assert session.session_name == "test_session"
|
||||
assert session.session_id == session_response.session_id
|
||||
assert session.started_at is not None
|
||||
assert session.turns == []
|
||||
|
||||
# Delete the session
|
||||
await agents_impl.delete_agents_session(agent_id, session_response.session_id)
|
||||
await agents_impl.delete_agents_session(session_response.session_id, agent_id)
|
||||
|
||||
# Verify the session was deleted
|
||||
with pytest.raises(ValueError):
|
||||
await agents_impl.get_agents_session(agent_id, session_response.session_id)
|
||||
await agents_impl.get_agents_session(session_response.session_id, agent_id)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("enable_session_persistence", [True, False])
|
||||
|
|
@ -226,11 +226,11 @@ async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config,
|
|||
assert session2.session_id in session_ids
|
||||
|
||||
# Delete one session
|
||||
await agents_impl.delete_agents_session(agent_id, session1.session_id)
|
||||
await agents_impl.delete_agents_session(session1.session_id, agent_id)
|
||||
|
||||
# Verify the session was deleted
|
||||
with pytest.raises(ValueError):
|
||||
await agents_impl.get_agents_session(agent_id, session1.session_id)
|
||||
await agents_impl.get_agents_session(session1.session_id, agent_id)
|
||||
|
||||
# List sessions again
|
||||
sessions = await agents_impl.list_agent_sessions(agent_id)
|
||||
|
|
|
|||
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
251
tests/unit/providers/nvidia/test_rerank_inference.py
Normal file
|
|
@ -0,0 +1,251 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import aiohttp
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.models import ModelType
|
||||
from llama_stack.providers.remote.inference.nvidia.config import NVIDIAConfig
|
||||
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAInferenceAdapter
|
||||
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
|
||||
|
||||
|
||||
class MockResponse:
|
||||
def __init__(self, status=200, json_data=None, text_data="OK"):
|
||||
self.status = status
|
||||
self._json_data = json_data or {"rankings": []}
|
||||
self._text_data = text_data
|
||||
|
||||
async def json(self):
|
||||
return self._json_data
|
||||
|
||||
async def text(self):
|
||||
return self._text_data
|
||||
|
||||
|
||||
class MockSession:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
self.post_calls = []
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
def post(self, url, **kwargs):
|
||||
self.post_calls.append((url, kwargs))
|
||||
|
||||
class PostContext:
|
||||
def __init__(self, response):
|
||||
self.response = response
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.response
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
return False
|
||||
|
||||
return PostContext(self.response)
|
||||
|
||||
|
||||
def create_adapter(config=None, rerank_endpoints=None):
|
||||
if config is None:
|
||||
config = NVIDIAConfig(api_key="test-key")
|
||||
|
||||
adapter = NVIDIAInferenceAdapter(config=config)
|
||||
|
||||
class MockModel:
|
||||
provider_resource_id = "test-model"
|
||||
metadata = {}
|
||||
|
||||
adapter.model_store = AsyncMock()
|
||||
adapter.model_store.get_model = AsyncMock(return_value=MockModel())
|
||||
|
||||
if rerank_endpoints is not None:
|
||||
adapter.config.rerank_model_to_url = rerank_endpoints
|
||||
|
||||
return adapter
|
||||
|
||||
|
||||
async def test_rerank_basic_functionality():
|
||||
adapter = create_adapter()
|
||||
mock_response = MockResponse(json_data={"rankings": [{"index": 0, "logit": 0.5}]})
|
||||
mock_session = MockSession(mock_response)
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="test query", items=["item1", "item2"])
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.5
|
||||
|
||||
url, kwargs = mock_session.post_calls[0]
|
||||
payload = kwargs["json"]
|
||||
assert payload["model"] == "test-model"
|
||||
assert payload["query"] == {"text": "test query"}
|
||||
assert payload["passages"] == [{"text": "item1"}, {"text": "item2"}]
|
||||
|
||||
|
||||
async def test_missing_rankings_key():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(json_data={}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
assert len(result.data) == 0
|
||||
|
||||
|
||||
async def test_hosted_with_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"test-model": "https://model.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert url == "https://model.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_hosted_without_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), # This creates hosted config (integrate.api.nvidia.com).
|
||||
rerank_endpoints={}, # No endpoint mapping for test-model
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
|
||||
|
||||
async def test_hosted_model_not_in_endpoint_mapping():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(api_key="key"), rerank_endpoints={"other-model": "https://other.endpoint/rerank"}
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "https://integrate.api.nvidia.com" in url
|
||||
assert url != "https://other.endpoint/rerank"
|
||||
|
||||
|
||||
async def test_self_hosted_ignores_endpoint():
|
||||
adapter = create_adapter(
|
||||
config=NVIDIAConfig(url="http://localhost:8000", api_key=None),
|
||||
rerank_endpoints={"test-model": "https://model.endpoint/rerank"}, # This should be ignored for self-hosted.
|
||||
)
|
||||
mock_session = MockSession(MockResponse())
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
url, _ = mock_session.post_calls[0]
|
||||
assert "http://localhost:8000" in url
|
||||
assert "model.endpoint/rerank" not in url
|
||||
|
||||
|
||||
async def test_max_num_results():
|
||||
adapter = create_adapter()
|
||||
rankings = [{"index": 0, "logit": 0.8}, {"index": 1, "logit": 0.6}]
|
||||
mock_session = MockSession(MockResponse(json_data={"rankings": rankings}))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
result = await adapter.rerank(model="test-model", query="q", items=["a", "b"], max_num_results=1)
|
||||
|
||||
assert len(result.data) == 1
|
||||
assert result.data[0].index == 0
|
||||
assert result.data[0].relevance_score == 0.8
|
||||
|
||||
|
||||
async def test_http_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = MockSession(MockResponse(status=500, text_data="Server Error"))
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="status 500.*Server Error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_client_error():
|
||||
adapter = create_adapter()
|
||||
mock_session = AsyncMock()
|
||||
mock_session.__aenter__.side_effect = aiohttp.ClientError("Network error")
|
||||
|
||||
with patch("aiohttp.ClientSession", return_value=mock_session):
|
||||
with pytest.raises(ConnectionError, match="Failed to connect.*Network error"):
|
||||
await adapter.rerank(model="test-model", query="q", items=["a"])
|
||||
|
||||
|
||||
async def test_list_models_includes_configured_rerank_models():
|
||||
"""Test that list_models adds rerank models to the dynamic model list."""
|
||||
adapter = create_adapter()
|
||||
adapter.__provider_id__ = "nvidia"
|
||||
adapter.__provider_spec__ = MagicMock()
|
||||
|
||||
dynamic_ids = ["llm-1", "embedding-1"]
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
result = await adapter.list_models()
|
||||
|
||||
assert result is not None
|
||||
|
||||
# Check that the rerank models are added
|
||||
model_ids = [m.identifier for m in result]
|
||||
assert "nv-rerank-qa-mistral-4b:1" in model_ids
|
||||
assert "nvidia/nv-rerankqa-mistral-4b-v3" in model_ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in model_ids
|
||||
|
||||
rerank_models = [m for m in result if m.model_type == ModelType.rerank]
|
||||
|
||||
assert len(rerank_models) == 3
|
||||
|
||||
for m in rerank_models:
|
||||
assert m.provider_id == "nvidia"
|
||||
assert m.model_type == ModelType.rerank
|
||||
assert m.metadata == {}
|
||||
assert m.identifier in adapter._model_cache
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_has_no_duplicates():
|
||||
adapter = create_adapter()
|
||||
|
||||
dynamic_ids = [
|
||||
"llm-1",
|
||||
"nvidia/nv-rerankqa-mistral-4b-v3", # overlaps configured rerank ids
|
||||
"embedding-1",
|
||||
"llm-1",
|
||||
]
|
||||
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(return_value=dynamic_ids)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
assert len(ids) == len(set(ids))
|
||||
assert ids.count("nvidia/nv-rerankqa-mistral-4b-v3") == 1
|
||||
assert "nv-rerank-qa-mistral-4b:1" in ids
|
||||
assert "nvidia/llama-3.2-nv-rerankqa-1b-v2" in ids
|
||||
|
||||
|
||||
async def test_list_provider_model_ids_uses_configured_on_dynamic_failure():
|
||||
adapter = create_adapter()
|
||||
|
||||
# Simulate dynamic listing failure
|
||||
with patch.object(OpenAIMixin, "list_provider_model_ids", new=AsyncMock(side_effect=Exception)):
|
||||
ids = list(await adapter.list_provider_model_ids())
|
||||
|
||||
# Should still return configured rerank ids
|
||||
configured_ids = list(adapter.config.rerank_model_to_url.keys())
|
||||
assert set(ids) == set(configured_ids)
|
||||
|
|
@ -455,8 +455,8 @@ class TestOpenAIMixinAllowedModels:
|
|||
"""Test cases for allowed_models filtering functionality"""
|
||||
|
||||
async def test_list_models_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test that list_models filters models based on allowed_models set"""
|
||||
mixin.allowed_models = {"some-mock-model-id", "another-mock-model-id"}
|
||||
"""Test that list_models filters models based on allowed_models"""
|
||||
mixin.config.allowed_models = ["some-mock-model-id", "another-mock-model-id"]
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.list_models()
|
||||
|
|
@ -470,8 +470,18 @@ class TestOpenAIMixinAllowedModels:
|
|||
assert "final-mock-model-id" not in model_ids
|
||||
|
||||
async def test_list_models_with_empty_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test that empty allowed_models set allows all models"""
|
||||
assert len(mixin.allowed_models) == 0
|
||||
"""Test that empty allowed_models allows no models"""
|
||||
mixin.config.allowed_models = []
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.list_models()
|
||||
|
||||
assert result is not None
|
||||
assert len(result) == 0 # No models should be included
|
||||
|
||||
async def test_list_models_with_omitted_allowed_models(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test that omitted allowed_models allows all models"""
|
||||
assert mixin.config.allowed_models is None
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
result = await mixin.list_models()
|
||||
|
|
@ -488,7 +498,7 @@ class TestOpenAIMixinAllowedModels:
|
|||
self, mixin, mock_client_with_models, mock_client_context
|
||||
):
|
||||
"""Test that check_model_availability respects allowed_models"""
|
||||
mixin.allowed_models = {"final-mock-model-id"}
|
||||
mixin.config.allowed_models = ["final-mock-model-id"]
|
||||
|
||||
with mock_client_context(mixin, mock_client_with_models):
|
||||
assert await mixin.check_model_availability("final-mock-model-id")
|
||||
|
|
@ -536,7 +546,7 @@ class TestOpenAIMixinModelRegistration:
|
|||
|
||||
async def test_register_model_with_allowed_models_filter(self, mixin, mock_client_with_models, mock_client_context):
|
||||
"""Test model registration with allowed_models filtering"""
|
||||
mixin.allowed_models = {"some-mock-model-id"}
|
||||
mixin.config.allowed_models = ["some-mock-model-id"]
|
||||
|
||||
# Test with allowed model
|
||||
allowed_model = Model(
|
||||
|
|
@ -690,7 +700,7 @@ class TestOpenAIMixinCustomListProviderModelIds:
|
|||
mixin = CustomListProviderModelIdsImplementation(
|
||||
config=config, custom_model_ids=["model-1", "model-2", "model-3"]
|
||||
)
|
||||
mixin.allowed_models = ["model-1"]
|
||||
mixin.config.allowed_models = ["model-1"]
|
||||
|
||||
result = await mixin.list_models()
|
||||
|
||||
|
|
|
|||
|
|
@ -43,9 +43,15 @@ def embedding_dimension() -> int:
|
|||
@pytest.fixture(scope="session")
|
||||
def sample_chunks():
|
||||
"""Generates chunks that force multiple batches for a single document to expose ID conflicts."""
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
n, k = 10, 3
|
||||
sample = [
|
||||
Chunk(content=f"Sentence {i} from document {j}", metadata={"document_id": f"document-{j}"})
|
||||
Chunk(
|
||||
content=f"Sentence {i} from document {j}",
|
||||
chunk_id=generate_chunk_id(f"document-{j}", f"Sentence {i} from document {j}"),
|
||||
metadata={"document_id": f"document-{j}"},
|
||||
)
|
||||
for j in range(k)
|
||||
for i in range(n)
|
||||
]
|
||||
|
|
@ -53,6 +59,7 @@ def sample_chunks():
|
|||
[
|
||||
Chunk(
|
||||
content=f"Sentence {i} from document {j + k}",
|
||||
chunk_id=f"document-{j}-chunk-{i}",
|
||||
chunk_metadata=ChunkMetadata(
|
||||
document_id=f"document-{j + k}",
|
||||
chunk_id=f"document-{j}-chunk-{i}",
|
||||
|
|
@ -73,6 +80,7 @@ def sample_chunks_with_metadata():
|
|||
sample = [
|
||||
Chunk(
|
||||
content=f"Sentence {i} from document {j}",
|
||||
chunk_id=f"document-{j}-chunk-{i}",
|
||||
metadata={"document_id": f"document-{j}"},
|
||||
chunk_metadata=ChunkMetadata(
|
||||
document_id=f"document-{j}",
|
||||
|
|
|
|||
|
|
@ -49,9 +49,21 @@ def vector_store_id():
|
|||
|
||||
@pytest.fixture
|
||||
def sample_chunks():
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
return [
|
||||
Chunk(content="MOCK text content 1", mime_type="text/plain", metadata={"document_id": "mock-doc-1"}),
|
||||
Chunk(content="MOCK text content 1", mime_type="text/plain", metadata={"document_id": "mock-doc-2"}),
|
||||
Chunk(
|
||||
content="MOCK text content 1",
|
||||
chunk_id=generate_chunk_id("mock-doc-1", "MOCK text content 1"),
|
||||
mime_type="text/plain",
|
||||
metadata={"document_id": "mock-doc-1"},
|
||||
),
|
||||
Chunk(
|
||||
content="MOCK text content 1",
|
||||
chunk_id=generate_chunk_id("mock-doc-2", "MOCK text content 1"),
|
||||
mime_type="text/plain",
|
||||
metadata={"document_id": "mock-doc-2"},
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -434,9 +434,15 @@ async def test_query_chunks_hybrid_tie_breaking(
|
|||
sqlite_vec_index, sample_embeddings, embedding_dimension, tmp_path_factory
|
||||
):
|
||||
"""Test tie-breaking and determinism when scores are equal."""
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
# Create two chunks with the same content and embedding
|
||||
chunk1 = Chunk(content="identical", metadata={"document_id": "docA"})
|
||||
chunk2 = Chunk(content="identical", metadata={"document_id": "docB"})
|
||||
chunk1 = Chunk(
|
||||
content="identical", chunk_id=generate_chunk_id("docA", "identical"), metadata={"document_id": "docA"}
|
||||
)
|
||||
chunk2 = Chunk(
|
||||
content="identical", chunk_id=generate_chunk_id("docB", "identical"), metadata={"document_id": "docB"}
|
||||
)
|
||||
chunks = [chunk1, chunk2]
|
||||
# Use the same embedding for both chunks to ensure equal scores
|
||||
same_embedding = sample_embeddings[0]
|
||||
|
|
|
|||
|
|
@ -135,10 +135,24 @@ async def test_insert_chunks_with_missing_document_id(vector_io_adapter):
|
|||
vector_io_adapter.cache["db1"] = fake_index
|
||||
|
||||
# Various document_id scenarios that shouldn't crash
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
chunks = [
|
||||
Chunk(content="has doc_id in metadata", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="no doc_id anywhere", metadata={"source": "test"}),
|
||||
Chunk(content="doc_id in chunk_metadata", chunk_metadata=ChunkMetadata(document_id="doc-3")),
|
||||
Chunk(
|
||||
content="has doc_id in metadata",
|
||||
chunk_id=generate_chunk_id("doc-1", "has doc_id in metadata"),
|
||||
metadata={"document_id": "doc-1"},
|
||||
),
|
||||
Chunk(
|
||||
content="no doc_id anywhere",
|
||||
chunk_id=generate_chunk_id("unknown", "no doc_id anywhere"),
|
||||
metadata={"source": "test"},
|
||||
),
|
||||
Chunk(
|
||||
content="doc_id in chunk_metadata",
|
||||
chunk_id=generate_chunk_id("doc-3", "doc_id in chunk_metadata"),
|
||||
chunk_metadata=ChunkMetadata(document_id="doc-3"),
|
||||
),
|
||||
]
|
||||
|
||||
# Should work without KeyError
|
||||
|
|
@ -151,7 +165,9 @@ async def test_document_id_with_invalid_type_raises_error():
|
|||
from llama_stack.apis.vector_io import Chunk
|
||||
|
||||
# Integer document_id should raise TypeError
|
||||
chunk = Chunk(content="test", metadata={"document_id": 12345})
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
chunk = Chunk(content="test", chunk_id=generate_chunk_id("test", "test"), metadata={"document_id": 12345})
|
||||
with pytest.raises(TypeError) as exc_info:
|
||||
_ = chunk.document_id
|
||||
assert "metadata['document_id'] must be a string" in str(exc_info.value)
|
||||
|
|
@ -159,7 +175,9 @@ async def test_document_id_with_invalid_type_raises_error():
|
|||
|
||||
|
||||
async def test_query_chunks_calls_underlying_index_and_returns(vector_io_adapter):
|
||||
expected = QueryChunksResponse(chunks=[Chunk(content="c1")], scores=[0.1])
|
||||
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
||||
|
||||
expected = QueryChunksResponse(chunks=[Chunk(content="c1", chunk_id=generate_chunk_id("test", "c1"))], scores=[0.1])
|
||||
fake_index = AsyncMock(query_chunks=AsyncMock(return_value=expected))
|
||||
vector_io_adapter.cache["db1"] = fake_index
|
||||
|
||||
|
|
|
|||
|
|
@ -18,13 +18,12 @@ from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
|
|||
|
||||
|
||||
def test_generate_chunk_id():
|
||||
chunks = [
|
||||
Chunk(content="test", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test ", metadata={"document_id": "doc-1"}),
|
||||
Chunk(content="test 3", metadata={"document_id": "doc-1"}),
|
||||
]
|
||||
"""Test that generate_chunk_id produces expected hashes."""
|
||||
chunk_id1 = generate_chunk_id("doc-1", "test")
|
||||
chunk_id2 = generate_chunk_id("doc-1", "test ")
|
||||
chunk_id3 = generate_chunk_id("doc-1", "test 3")
|
||||
|
||||
chunk_ids = sorted([chunk.chunk_id for chunk in chunks])
|
||||
chunk_ids = sorted([chunk_id1, chunk_id2, chunk_id3])
|
||||
assert chunk_ids == [
|
||||
"31d1f9a3-c8d2-66e7-3c37-af2acd329778",
|
||||
"d07dade7-29c0-cda7-df29-0249a1dcbc3e",
|
||||
|
|
@ -33,42 +32,49 @@ def test_generate_chunk_id():
|
|||
|
||||
|
||||
def test_generate_chunk_id_with_window():
|
||||
chunk = Chunk(content="test", metadata={"document_id": "doc-1"})
|
||||
"""Test that generate_chunk_id with chunk_window produces different IDs."""
|
||||
# Create a chunk object to match the original test behavior (passing object to generate_chunk_id)
|
||||
chunk = Chunk(content="test", chunk_id="placeholder", metadata={"document_id": "doc-1"})
|
||||
chunk_id1 = generate_chunk_id("doc-1", chunk, chunk_window="0-1")
|
||||
chunk_id2 = generate_chunk_id("doc-1", chunk, chunk_window="1-2")
|
||||
assert chunk_id1 == "8630321a-d9cb-2bb6-cd28-ebf68dafd866"
|
||||
assert chunk_id2 == "13a1c09a-cbda-b61a-2d1a-7baa90888685"
|
||||
# Verify that different windows produce different IDs
|
||||
assert chunk_id1 != chunk_id2
|
||||
assert len(chunk_id1) == 36 # Valid UUID format
|
||||
assert len(chunk_id2) == 36 # Valid UUID format
|
||||
|
||||
|
||||
def test_chunk_id():
|
||||
# Test with existing chunk ID
|
||||
chunk_with_id = Chunk(content="test", metadata={"document_id": "existing-id"})
|
||||
assert chunk_with_id.chunk_id == "11704f92-42b6-61df-bf85-6473e7708fbd"
|
||||
|
||||
# Test with document ID in metadata
|
||||
chunk_with_doc_id = Chunk(content="test", metadata={"document_id": "doc-1"})
|
||||
assert chunk_with_doc_id.chunk_id == generate_chunk_id("doc-1", "test")
|
||||
|
||||
# Test chunks with ChunkMetadata
|
||||
chunk_with_metadata = Chunk(
|
||||
def test_chunk_creation_with_explicit_id():
|
||||
"""Test that chunks can be created with explicit chunk_id."""
|
||||
chunk_id = generate_chunk_id("doc-1", "test")
|
||||
chunk = Chunk(
|
||||
content="test",
|
||||
metadata={"document_id": "existing-id", "chunk_id": "chunk-id-1"},
|
||||
chunk_id=chunk_id,
|
||||
metadata={"document_id": "doc-1"},
|
||||
)
|
||||
assert chunk.chunk_id == chunk_id
|
||||
assert chunk.chunk_id == "31d1f9a3-c8d2-66e7-3c37-af2acd329778"
|
||||
|
||||
|
||||
def test_chunk_with_metadata():
|
||||
"""Test chunks with ChunkMetadata."""
|
||||
chunk_id = "chunk-id-1"
|
||||
chunk = Chunk(
|
||||
content="test",
|
||||
chunk_id=chunk_id,
|
||||
metadata={"document_id": "existing-id"},
|
||||
chunk_metadata=ChunkMetadata(document_id="document_1"),
|
||||
)
|
||||
assert chunk_with_metadata.chunk_id == "chunk-id-1"
|
||||
|
||||
# Test with no ID or document ID
|
||||
chunk_without_id = Chunk(content="test")
|
||||
generated_id = chunk_without_id.chunk_id
|
||||
assert isinstance(generated_id, str) and len(generated_id) == 36 # Should be a valid UUID
|
||||
assert chunk.chunk_id == "chunk-id-1"
|
||||
assert chunk.document_id == "existing-id" # metadata takes precedence
|
||||
|
||||
|
||||
def test_stored_chunk_id_alias():
|
||||
# Test with existing chunk ID alias
|
||||
chunk_with_alias = Chunk(content="test", metadata={"document_id": "existing-id", "chunk_id": "chunk-id-1"})
|
||||
assert chunk_with_alias.chunk_id == "chunk-id-1"
|
||||
serialized_chunk = chunk_with_alias.model_dump()
|
||||
assert serialized_chunk["stored_chunk_id"] == "chunk-id-1"
|
||||
# showing chunk_id is not serialized (i.e., a computed field)
|
||||
assert "chunk_id" not in serialized_chunk
|
||||
assert chunk_with_alias.stored_chunk_id == "chunk-id-1"
|
||||
def test_chunk_serialization():
|
||||
"""Test that chunk_id is properly serialized."""
|
||||
chunk = Chunk(
|
||||
content="test",
|
||||
chunk_id="test-chunk-id",
|
||||
metadata={"document_id": "doc-1"},
|
||||
)
|
||||
serialized_chunk = chunk.model_dump()
|
||||
assert serialized_chunk["chunk_id"] == "test-chunk-id"
|
||||
assert "chunk_id" in serialized_chunk
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue