fix tests

This commit is contained in:
Ashwin Bharambe 2025-10-20 14:16:53 -07:00
parent 9c9f5f059a
commit 85143e7316

View file

@ -10,7 +10,7 @@ from unittest.mock import AsyncMock
import pytest import pytest
from llama_stack.apis.models import Model, ModelType from llama_stack.apis.models import ListModelsResponse, Model, ModelType
from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig from llama_stack.core.datatypes import QualifiedModel, StackRunConfig, StorageConfig, VectorStoresConfig
from llama_stack.core.stack import validate_vector_stores_config from llama_stack.core.stack import validate_vector_stores_config
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -32,10 +32,10 @@ class TestVectorStoresValidation:
), ),
) )
mock_models = AsyncMock() mock_models = AsyncMock()
mock_models.list_models.return_value = [] mock_models.list_models.return_value = ListModelsResponse(data=[])
with pytest.raises(ValueError, match="not found"): with pytest.raises(ValueError, match="not found"):
await validate_vector_stores_config(run_config, {Api.models: mock_models}) await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})
async def test_validate_success(self): async def test_validate_success(self):
"""Test validation passes with valid model.""" """Test validation passes with valid model."""
@ -52,14 +52,16 @@ class TestVectorStoresValidation:
), ),
) )
mock_models = AsyncMock() mock_models = AsyncMock()
mock_models.list_models.return_value = [ mock_models.list_models.return_value = ListModelsResponse(
Model( data=[
identifier="p/valid", # Must match provider_id/model_id format Model(
model_type=ModelType.embedding, identifier="p/valid", # Must match provider_id/model_id format
metadata={"embedding_dimension": 768}, model_type=ModelType.embedding,
provider_id="p", metadata={"embedding_dimension": 768},
provider_resource_id="valid", provider_id="p",
) provider_resource_id="valid",
] )
]
)
await validate_vector_stores_config(run_config, {Api.models: mock_models}) await validate_vector_stores_config(run_config.vector_stores, {Api.models: mock_models})