From 9d23c063d52db2c9dcb99c95f99fc950670c981d Mon Sep 17 00:00:00 2001
From: Fred Reiss
Date: Thu, 19 Dec 2024 11:28:15 -0800
Subject: [PATCH] Fix regressions in inline vLLM provider

---
 .../providers/inline/inference/vllm/vllm.py | 35 ++++++++++++++-----
 1 file changed, 27 insertions(+), 8 deletions(-)

diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index 0e7ba872c..72aa2200a 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -50,7 +50,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.formatter = ChatFormat(Tokenizer.get_instance())
 
     async def initialize(self):
-        log.info("Initializing vLLM inference adapter")
+        log.info("Initializing vLLM inference provider.")
 
         # Disable usage stats reporting. This would be a surprising thing for most
         # people to find out was on by default.
@@ -79,14 +79,33 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
 
     async def shutdown(self):
         """Shutdown the vLLM inference adapter."""
-        log.info("Shutting down vLLM inference adapter")
+        log.info("Shutting down vLLM inference provider.")
         if self.engine:
             self.engine.shutdown_background_loop()
 
-    async def register_model(self, model: Model) -> None:
-        raise ValueError(
-            "You cannot dynamically add a model to a running vllm instance"
-        )
+    # Note that the return type of the superclass method is WRONG
+    async def register_model(self, model: Model) -> Model:
+        """
+        Callback that is called when the server associates an inference endpoint
+        with an inference provider.
+
+        :param model: Object that encapsulates parameters necessary for identifying
+            a specific LLM.
+
+        :returns: The input ``Model`` object. It may or may not be permissible
+            to change fields before returning this object.
+        """
+        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+        # The current version of this provider is hard-coded to serve only
+        # the model specified in the YAML config file.
+        configured_model = resolve_model(self.config.model)
+        registered_model = resolve_model(model.model_id)
+
+        if configured_model.core_model_id != registered_model.core_model_id:
+            raise ValueError(f"Requested model '{model.identifier}' is different from "
+                             f"model '{self.config.model}' that this provider "
+                             f"is configured to serve")
+        return model
 
     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
         if sampling_params is None:
@@ -160,7 +179,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()
 
-        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        prompt = chat_completion_request_to_prompt(request, self.config.model, self.formatter)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
@@ -216,7 +235,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             stream, self.formatter
         ):
            yield chunk
-    
+
     async def embeddings(
         self, model_id: str, contents: list[InterleavedTextMedia]
     ) -> EmbeddingsResponse:
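
For context on the register_model() change above, the sketch below mirrors the validation rule the patched method enforces: a registration request is accepted only when the requested model resolves to the same core model ID as the model named in the provider's YAML config. This is a minimal, self-contained sketch, not the provider's actual code path: ModelDescriptor, the alias table, resolve_model(), and check_registration() here are hypothetical stand-ins for the real Llama SKU resolution machinery, included only so the check can run in isolation.

from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelDescriptor:
    # Hypothetical stand-in for the object returned by the real resolve_model();
    # only the core_model_id field matters for this check.
    core_model_id: str


# Hypothetical alias table; the real provider resolves identifiers against the
# Llama SKU list rather than a hard-coded dict.
_KNOWN_MODELS = {
    "Llama3.1-8B-Instruct": ModelDescriptor("llama3_1_8b_instruct"),
    "meta-llama/Llama-3.1-8B-Instruct": ModelDescriptor("llama3_1_8b_instruct"),
    "Llama3.1-70B-Instruct": ModelDescriptor("llama3_1_70b_instruct"),
}


def resolve_model(name: str) -> Optional[ModelDescriptor]:
    # Map a model identifier or alias to a canonical descriptor, or None if unknown.
    return _KNOWN_MODELS.get(name)


def check_registration(configured_name: str, requested_name: str) -> None:
    # Accept the registration only if both names resolve to the same core model,
    # mirroring the comparison in the patched register_model().
    configured = resolve_model(configured_name)
    requested = resolve_model(requested_name)
    if (
        configured is None
        or requested is None
        or configured.core_model_id != requested.core_model_id
    ):
        raise ValueError(
            f"Requested model '{requested_name}' is different from "
            f"model '{configured_name}' that this provider is configured to serve"
        )


# Accepted: both identifiers resolve to the same underlying model.
check_registration("Llama3.1-8B-Instruct", "meta-llama/Llama-3.1-8B-Instruct")

# Rejected: the provider only serves the model named in its config file.
try:
    check_registration("Llama3.1-8B-Instruct", "Llama3.1-70B-Instruct")
except ValueError as err:
    print(err)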