diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 12c6c0370..1caae9687 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -88,23 +88,25 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         """
         Callback that is called when the server associates an inference endpoint
         with an inference provider.
-        
+
         :param model: Object that encapsulates parameters necessary for identifying
          a specific LLM.
-        
+
         :returns: The input ``Model`` object. It may or may not be permissible to
          change fields before returning this object.
         """
-        log.info(f"Registering model {model.identifier} with vLLM inference provider.") 
-        # The current version of this provided is hard-coded to serve only 
+        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+        # The current version of this provided is hard-coded to serve only
         # the model specified in the YAML config file.
         configured_model = resolve_model(self.config.model)
         registered_model = resolve_model(model.model_id)
-        
+
         if configured_model.core_model_id != registered_model.core_model_id:
-            raise ValueError(f"Requested model '{model.identifier}' is different from "
-                             f"model '{self.config.model}' that this provider "
-                             f"is configured to serve")
+            raise ValueError(
+                f"Requested model '{model.identifier}' is different from "
+                f"model '{self.config.model}' that this provider "
+                f"is configured to serve"
+            )
         return model
 
     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
@@ -169,8 +171,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()
 
-        prompt = await chat_completion_request_to_prompt(request, self.config.model,
-                                                         self.formatter)
+        prompt = await chat_completion_request_to_prompt(
+            request, self.config.model, self.formatter
+        )
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
@@ -226,7 +229,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             stream, self.formatter
         ):
             yield chunk
-            
+
     async def embeddings(
         self, model_id: str, contents: List[InterleavedContent]
     ) -> EmbeddingsResponse:
diff --git a/llama_stack/providers/tests/inference/fixtures.py b/llama_stack/providers/tests/inference/fixtures.py
index f3c7df404..524bc69db 100644
--- a/llama_stack/providers/tests/inference/fixtures.py
+++ b/llama_stack/providers/tests/inference/fixtures.py
@@ -292,6 +292,6 @@ async def inference_stack(request, inference_model):
 
     # Pytest yield fixture; see https://docs.pytest.org/en/stable/how-to/fixtures.html#yield-fixtures-recommended
     yield test_stack.impls[Api.inference], test_stack.impls[Api.models]
-    
+
     # Cleanup code that runs after test case completion
     await test_stack.impls[Api.inference].shutdown()
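
Aside from whitespace and line-wrapping cleanup, the substantive logic these hunks touch is the registration guard in `register_model`: the provider resolves both the configured model name and the requested one, and rejects the registration when their core model ids differ. Below is a minimal standalone sketch of that pattern, with toy stand-ins for `resolve_model` and the resolved-model type (hypothetical names and lookup table for illustration, not the real `llama_stack` / `llama_models` APIs):

```python
from dataclasses import dataclass

# Toy lookup table standing in for the real model registry consulted by
# resolve_model(); both an alias and a canonical name map to one core id.
_CORE_MODEL_IDS = {
    "Llama3.1-8B-Instruct": "llama3_1_8b_instruct",
    "meta-llama/Llama-3.1-8B-Instruct": "llama3_1_8b_instruct",
    "Llama3.1-70B-Instruct": "llama3_1_70b_instruct",
}


@dataclass
class ResolvedModel:
    core_model_id: str


def resolve_model(name: str) -> ResolvedModel:
    """Stand-in resolver: map a known model name or alias to its core model id."""
    return ResolvedModel(core_model_id=_CORE_MODEL_IDS[name])


def register_model(configured_model: str, requested_model: str) -> str:
    """Sketch of the guard: accept only models that resolve to the configured core id."""
    configured = resolve_model(configured_model)
    requested = resolve_model(requested_model)
    if configured.core_model_id != requested.core_model_id:
        raise ValueError(
            f"Requested model '{requested_model}' is different from "
            f"model '{configured_model}' that this provider is configured to serve"
        )
    return requested_model


if __name__ == "__main__":
    # Aliases of the configured core model pass the check...
    print(register_model("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct"))
    # ...while a different core model is rejected with ValueError.
    try:
        register_model("Llama3.1-8B-Instruct", "Llama3.1-70B-Instruct")
    except ValueError as err:
        print(f"rejected: {err}")
```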