From 8b2376bfb319c64edb0b96b804ae71f780f9ed5b Mon Sep 17 00:00:00 2001
From: Fred Reiss
Date: Fri, 10 Jan 2025 16:35:16 -0800
Subject: [PATCH] Add inline vLLM inference provider to regression tests and
 fix regressions (#662)

# What does this PR do?

This PR adds the inline vLLM inference provider to the regression tests for
inference providers. The PR also fixes some regressions in that inference
provider in order to make the tests pass.

## Test Plan

Command to run the new tests (from root of project):

```
pytest \
    -vvv \
    llama_stack/providers/tests/inference/test_text_inference.py \
    --providers inference=vllm \
    --inference-model meta-llama/Llama-3.2-3B-Instruct
```

Output of the above command after these changes:

```
/mnt/datadisk1/freiss/llama/env/lib/python3.12/site-packages/pytest_asyncio/plugin.py:207: PytestDeprecationWarning: The configuration option "asyncio_default_fixture_loop_scope" is unset.
The event loop scope for asynchronous fixtures will default to the fixture caching scope. Future versions of pytest-asyncio will default the loop scope for asynchronous fixtures to function scope. Set the default fixture loop scope explicitly in order to avoid unexpected behavior in the future. Valid fixture loop scopes are: "function", "class", "module", "package", "session"
  warnings.warn(PytestDeprecationWarning(_DEFAULT_FIXTURE_LOOP_SCOPE_UNSET))
=================================================================== test session starts ===================================================================
platform linux -- Python 3.12.7, pytest-8.3.4, pluggy-1.5.0 -- /mnt/datadisk1/freiss/llama/env/bin/python3.12
cachedir: .pytest_cache
rootdir: /mnt/datadisk1/freiss/llama/llama-stack
configfile: pyproject.toml
plugins: asyncio-0.25.0, anyio-4.6.2.post1
asyncio: mode=Mode.STRICT, asyncio_default_fixture_loop_scope=None
collected 9 items

llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_model_list[-vllm] PASSED                                          [ 11%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion[-vllm] SKIPPED (Other inference providers don't support completion() yet)                       [ 22%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_logprobs[-vllm] SKIPPED (Other inference providers don't support completion() yet)              [ 33%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_completion_structured_output[-vllm] SKIPPED (This test is not quite robust)                                [ 44%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_non_streaming[-vllm] PASSED                       [ 55%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_structured_output[-vllm] SKIPPED (Other inference providers don't support structured output yet)           [ 66%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_streaming[-vllm] PASSED                           [ 77%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling[-vllm] PASSED                   [ 88%]
llama_stack/providers/tests/inference/test_text_inference.py::TestInference::test_chat_completion_with_tool_calling_streaming[-vllm] PASSED         [100%]

======================================================== 5 passed, 4 skipped, 2 warnings in 25.56s ========================================================
Task was destroyed but it is pending!
task: cb=[_log_task_completion(error_callback=>)() at /mnt/datadisk1/freiss/llama/env/lib/python3.12/site-packages/vllm/engine/async_llm_engine.py:45, shield.._inner_done_callback() at /mnt/datadisk1/freiss/llama/env/lib/python3.12/asyncio/tasks.py:905]>
[rank0]:[W1219 11:38:34.689424319 ProcessGroupNCCL.cpp:1250] Warning: WARNING: process group has NOT been destroyed before we destruct ProcessGroupNCCL. On normal program exit, the application should call destroy_process_group to ensure that any pending NCCL operations have finished in this process. In rare cases this process can exit before this point and block the progress of another member of the process group. This constraint has always been present, but this warning has only been added since PyTorch 2.4 (function operator())
```

The warning about "asyncio_default_fixture_loop_scope" appears to be due to my environment having a newer version of pytest-asyncio.

The warning about a pending task appears to be due to a bug in `vllm.AsyncLLMEngine.shutdown_background_loop()`. It looks like that method returns without stopping a pending task. I will look into that issue separately.

## Sources

## Before submitting

- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [X] Ran pre-commit to handle lint / formatting issues.
- [X] Read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md), Pull Request section?
- [ ] Updated relevant documentation.
- [X] Wrote necessary unit or integration tests.
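
For reference, a rough, hypothetical sketch of the kind of cleanup that would silence the pending-task warning above, if the provider ends up having to work around `shutdown_background_loop()` rather than wait for a fix in vLLM. The helper name below is made up and is not part of this patch; it only uses standard `asyncio` calls.

```
import asyncio


async def _drain_pending_tasks() -> None:
    """Hypothetical helper: cancel any tasks still pending after
    AsyncLLMEngine.shutdown_background_loop() returns, so the event loop can
    close without the "Task was destroyed but it is pending!" warning.
    """
    # Collect every task on the running loop except the one executing this coroutine.
    pending = [t for t in asyncio.all_tasks() if t is not asyncio.current_task()]
    for task in pending:
        task.cancel()
    # Wait for the cancellations to be delivered; collect the resulting
    # CancelledError instances instead of letting them propagate.
    await asyncio.gather(*pending, return_exceptions=True)
```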
+ """ + log.info(f"Registering model {model.identifier} with vLLM inference provider.") + # The current version of this provided is hard-coded to serve only + # the model specified in the YAML config file. + configured_model = resolve_model(self.config.model) + registered_model = resolve_model(model.model_id) + + if configured_model.core_model_id != registered_model.core_model_id: + raise ValueError( + f"Requested model '{model.identifier}' is different from " + f"model '{self.config.model}' that this provider " + f"is configured to serve" + ) + return model def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams: if sampling_params is None: @@ -163,7 +184,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): log.info("Sampling params: %s", sampling_params) request_id = _random_uuid() - prompt = await chat_completion_request_to_prompt(request, self.formatter) + prompt = await chat_completion_request_to_prompt( + request, self.config.model, self.formatter + ) vllm_sampling_params = self._sampling_params(request.sampling_params) results_generator = self.engine.generate( prompt, vllm_sampling_params, request_id diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py index d956caa93..b6653b65d 100644 --- a/llama_stack/providers/tests/inference/fixtures.py +++ b/llama_stack/providers/tests/inference/fixtures.py @@ -15,6 +15,7 @@ from llama_stack.distribution.datatypes import Api, Provider from llama_stack.providers.inline.inference.meta_reference import ( MetaReferenceInferenceConfig, ) +from llama_stack.providers.inline.inference.vllm import VLLMConfig from llama_stack.providers.remote.inference.bedrock import BedrockConfig from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig @@ -105,6 +106,26 @@ def inference_ollama(inference_model) -> ProviderFixture: ) +@pytest_asyncio.fixture(scope="session") +def inference_vllm(inference_model) -> ProviderFixture: + inference_model = ( + [inference_model] if isinstance(inference_model, str) else inference_model + ) + return ProviderFixture( + providers=[ + Provider( + provider_id=f"vllm-{i}", + provider_type="inline::vllm", + config=VLLMConfig( + model=m, + enforce_eager=True, # Make test run faster + ).model_dump(), + ) + for i, m in enumerate(inference_model) + ] + ) + + @pytest.fixture(scope="session") def inference_vllm_remote() -> ProviderFixture: return ProviderFixture( @@ -253,6 +274,7 @@ INFERENCE_FIXTURES = [ "ollama", "fireworks", "together", + "vllm", "groq", "vllm_remote", "remote", @@ -286,4 +308,8 @@ async def inference_stack(request, inference_model): ], ) - return test_stack.impls[Api.inference], test_stack.impls[Api.models] + # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended + yield test_stack.impls[Api.inference], test_stack.impls[Api.models] + + # Cleanup code that runs after test case completion + await test_stack.impls[Api.inference].shutdown() diff --git a/llama_stack/providers/tests/inference/test_text_inference.py b/llama_stack/providers/tests/inference/test_text_inference.py index 7776c7959..e2c939914 100644 --- a/llama_stack/providers/tests/inference/test_text_inference.py +++ b/llama_stack/providers/tests/inference/test_text_inference.py @@ -86,7 +86,9 @@ def sample_tool_definition(): class TestInference: - @pytest.mark.asyncio + # Session scope for asyncio because the tests in this class all + # share the same provider instance. 
+ @pytest.mark.asyncio(loop_scope="session") async def test_model_list(self, inference_model, inference_stack): _, models_impl = inference_stack response = await models_impl.list_models() @@ -102,7 +104,7 @@ class TestInference: assert model_def is not None - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_completion(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -147,7 +149,7 @@ class TestInference: last = chunks[-1] assert last.stop_reason == StopReason.out_of_tokens - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_completion_logprobs(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -202,7 +204,7 @@ class TestInference: else: # no token, no logprobs assert not chunk.logprobs, "Logprobs should be empty" - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") @pytest.mark.skip("This test is not quite robust") async def test_completion_structured_output(self, inference_model, inference_stack): inference_impl, _ = inference_stack @@ -247,7 +249,7 @@ class TestInference: assert answer.year_born == "1963" assert answer.year_retired == "2003" - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_non_streaming( self, inference_model, inference_stack, common_params, sample_messages ): @@ -264,7 +266,7 @@ class TestInference: assert isinstance(response.completion_message.content, str) assert len(response.completion_message.content) > 0 - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_structured_output( self, inference_model, inference_stack, common_params ): @@ -335,7 +337,7 @@ class TestInference: with pytest.raises(ValidationError): AnswerFormat.model_validate_json(response.completion_message.content) - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_streaming( self, inference_model, inference_stack, common_params, sample_messages ): @@ -362,7 +364,7 @@ class TestInference: end = grouped[ChatCompletionResponseEventType.complete][0] assert end.event.stop_reason == StopReason.end_of_turn - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_with_tool_calling( self, inference_model, @@ -409,7 +411,7 @@ class TestInference: assert "location" in call.arguments assert "San Francisco" in call.arguments["location"] - @pytest.mark.asyncio + @pytest.mark.asyncio(loop_scope="session") async def test_chat_completion_with_tool_calling_streaming( self, inference_model,
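
As a footnote on the test changes: the patch relies on pytest-asyncio running the shared fixture and every test in `TestInference` on one session-scoped event loop, which is what the `loop_scope="session"` markers above opt into. A minimal, self-contained sketch of that pattern follows; `shared_engine` and `test_engine_is_running` are illustrative names only, not code from this repository, and it assumes pytest-asyncio >= 0.24 (the test run above used 0.25.0).

```
import pytest
import pytest_asyncio


@pytest_asyncio.fixture(scope="session")
async def shared_engine():
    # Stand-in for an expensive, session-wide resource such as an inline vLLM engine.
    engine = {"running": True}
    yield engine                 # every test in the session reuses this object
    engine["running"] = False    # cleanup runs once, after the last test


@pytest.mark.asyncio(loop_scope="session")
async def test_engine_is_running(shared_engine):
    assert shared_engine["running"]
```

Keeping the fixture and the tests on the same session-scoped loop matters because objects created on one event loop generally cannot be awaited from another, which is why every test in the class carries the same marker.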