diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 037e99819..1243881b9 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -31,7 +31,7 @@ from llama_stack.apis.inference import ( ToolChoice, UserMessage, ) -from llama_stack.apis.models import Model +from llama_stack.apis.models import ListModelsResponse, Model from .utils import group_chunks @@ -92,12 +92,13 @@ class TestInference: async def test_model_list(self, inference_model, inference_stack): _, models_impl = inference_stack response = await models_impl.list_models() - assert isinstance(response, list) - assert len(response) >= 1 - assert all(isinstance(model, Model) for model in response) + assert isinstance(response, ListModelsResponse) + assert isinstance(response.data, list) + assert len(response.data) >= 1 + assert all(isinstance(model, Model) for model in response.data) model_def = None - for model in response: + for model in response.data: if model.identifier == inference_model: model_def = model break