diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d9c0cb188..e935cd8c2 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) +from llama_stack.providers.inline.inference.vllm import VLLMConfig from llama_stack.providers.remote.inference.bedrock import BedrockConfig from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig @@ -104,6 +105,26 @@ def inference_ollama(inference_model) -> ProviderFixture: ) +@pytest.fixture(scope="session") +def inference_vllm(inference_model) -> ProviderFixture: + inference_model = ( + [inference_model] if isinstance(inference_model, str) else inference_model + ) + return ProviderFixture( + providers=[ + Provider( + provider_id=f"vllm-{i}", + provider_type="inline::vllm", + config=VLLMConfig( + model=m, + enforce_eager=True, # Make test run faster + ).model_dump(), + ) + for i, m in enumerate(inference_model) + ] + ) + + @pytest.fixture(scope="session") def inference_vllm_remote() -> ProviderFixture: return ProviderFixture( @@ -222,6 +243,7 @@ INFERENCE_FIXTURES = [ "ollama", "fireworks", "together", + "vllm", "vllm_remote", "remote", "bedrock", @@ -254,4 +276,8 @@ async def inference_stack(request, inference_model): ], ) - return test_stack.impls[Api.inference], test_stack.impls[Api.models] + # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended + yield test_stack.impls[Api.inference], test_stack.impls[Api.models] + + # Cleanup code that runs after test case completion + await test_stack.impls[Api.inference].shutdown() diff --git a/llama_stack/providers/tests/inference/test_text_inference.py 
b/llama_stack/providers/tests/inference/test_text_inference.py index 99a62ac08..55a3bd0a8 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -67,7 +67,9 @@ def sample_tool_definition(): class TestInference: - @pytest.mark.asyncio + # Session scope for asyncio because the tests in this class all + # share the same provider instance. + @pytest.mark.asyncio(loop_scope="session") async def test_model_list(self, inference_model, inference_stack): _, models_impl = inference_stack response = await models_impl.list_models() @@ -83,7 +85,7 @@ class TestInference: assert model_def is not None - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_completion(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -128,7 +130,7 @@ class TestInference: last = chunks[-1] assert last.stop_reason == StopReason.out_of_tokens - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_completion_logprobs(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -183,7 +185,7 @@ class TestInference: else: # no token, no logprobs assert not chunk.logprobs, "Logprobs should be empty" - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") @pytest.mark.skip("This test is not quite robust") async def test_completion_structured_output(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -227,7 +229,7 @@ class TestInference: assert answer.year_born == "1963" assert answer.year_retired == "2003" - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_non_streaming( self, inference_model, inference_stack, common_params, sample_messages ): @@ -244,7 +246,7 @@ class TestInference: assert isinstance(response.completion_message.content, str) assert len(response.completion_message.content) > 0 - @pytest.mark.asyncio + 
@pytest.mark.asyncio(loop_scope="session") async def test_structured_output( self, inference_model, inference_stack, common_params ): @@ -314,7 +316,7 @@ class TestInference: with pytest.raises(ValidationError): AnswerFormat.model_validate_json(response.completion_message.content) - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_streaming( self, inference_model, inference_stack, common_params, sample_messages ): @@ -341,7 +343,7 @@ class TestInference: end = grouped[ChatCompletionResponseEventType.complete][0] assert end.event.stop_reason == StopReason.end_of_turn - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_with_tool_calling( self, inference_model, @@ -380,7 +382,7 @@ class TestInference: assert "location" in call.arguments assert "San Francisco" in call.arguments["location"] - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_with_tool_calling_streaming( self, inference_model,