diff --git a/tests/unit/core/test_stack_validation.py b/tests/unit/core/test_stack_validation.py index b50aff559..fa5348d1c 100644 --- a/tests/unit/core/test_stack_validation.py +++ b/tests/unit/core/test_stack_validation.py @@ -10,7 +10,7 @@ from unittest.mock import AsyncMock import pytest -from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.models import ListModelsResponse, Model, ModelType from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig from llama_stack.core.stack import validate_vector_stores_config from llama_stack.providers.datatypes import Api @@ -32,10 +32,10 @@ class TestVectorStoresValidation: ), ) mock_models = AsyncMock() - mock_models.list_models.return_value = [] + mock_models.list_models.return_value = ListModelsResponse(data=[]) with pytest.raises(ValueError, match="not found"): - await validate_vector_stores_config(run_config, {Api.models: mock_models}) + await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models}) async def test_validate_success(self): """Test validation passes with valid model.""" @@ -52,14 +52,16 @@ class TestVectorStoresValidation: ), ) mock_models = AsyncMock() - mock_models.list_models.return_value = [ - Model( - identifier="p/valid", # Must match provider_id/model_id format - model_type=ModelType.embedding, - metadata={"embedding_dimension": 768}, - provider_id="p", - provider_resource_id="valid", - ) - ] + mock_models.list_models.return_value = ListModelsResponse( + data=[ + Model( + identifier="p/valid", # Must match provider_id/model_id format + model_type=ModelType.embedding, + metadata={"embedding_dimension": 768}, + provider_id="p", + provider_resource_id="valid", + ) + ] + ) - await validate_vector_stores_config(run_config, {Api.models: mock_models}) + await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})