Fix regressions in inline vLLM provider

Fred Reiss 2024-12-19 11:28:15 -08:00
parent b74a05114f
commit 9d23c063d5


@@ -50,7 +50,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self):
log.info("Initializing vLLM inference adapter")
log.info("Initializing vLLM inference provider.")
# Disable usage stats reporting. This would be a surprising thing for most
# people to find out was on by default.
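For context, vLLM's usage-stats reporting is normally switched off through an environment variable before the engine is constructed. The lines that do this are not shown in the hunk above, so the snippet below is only a sketch of the usual opt-out mechanism (VLLM_NO_USAGE_STATS and DO_NOT_TRACK are the opt-out variables vLLM documents; whether this provider uses exactly this mechanism is an assumption):

import os

# Opt out of vLLM usage-stats collection before AsyncLLMEngine is created.
# vLLM also honors DO_NOT_TRACK=1 as an equivalent opt-out.
os.environ["VLLM_NO_USAGE_STATS"] = "1"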
@@ -79,14 +79,33 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def shutdown(self):
"""Shutdown the vLLM inference adapter."""
log.info("Shutting down vLLM inference adapter")
log.info("Shutting down vLLM inference provider.")
if self.engine:
self.engine.shutdown_background_loop()
async def register_model(self, model: Model) -> None:
raise ValueError(
"You cannot dynamically add a model to a running vllm instance"
)
# Note that the return type of the superclass method is WRONG
async def register_model(self, model: Model) -> Model:
"""
Callback that is called when the server associates an inference endpoint
with an inference provider.
:param model: Object that encapsulates parameters necessary for identifying
a specific LLM.
:returns: The input ``Model`` object. It may or may not be permissible
to change fields before returning this object.
"""
log.info(f"Registering model {model.identifier} with vLLM inference provider.")
# The current version of this provider is hard-coded to serve only
# the model specified in the YAML config file.
configured_model = resolve_model(self.config.model)
registered_model = resolve_model(model.model_id)
if configured_model.core_model_id != registered_model.core_model_id:
raise ValueError(f"Requested model '{model.identifier}' is different from "
f"model '{self.config.model}' that this provider "
f"is configured to serve")
return model
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
if sampling_params is None:
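For readability, here is the new registration check from the hunk above as a stand-alone sketch. It assumes resolve_model is imported from llama_models.sku_list and that resolved models expose core_model_id, as the diff suggests; the helper name check_model_matches is hypothetical:

from llama_models.sku_list import resolve_model  # assumed import path

def check_model_matches(configured: str, requested: str) -> None:
    # Resolve both names to canonical descriptors so that different aliases
    # of the same underlying model compare equal via core_model_id.
    # (resolve_model returns None for unknown names; like the diff, this
    # sketch does not guard against that case.)
    configured_model = resolve_model(configured)
    registered_model = resolve_model(requested)
    if configured_model.core_model_id != registered_model.core_model_id:
        raise ValueError(
            f"Requested model '{requested}' is different from model "
            f"'{configured}' that this provider is configured to serve"
        )

Any requested model that resolves to the same core_model_id as the configured one is accepted and returned unchanged; anything else is rejected with a ValueError.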
@@ -160,7 +179,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
log.info("Sampling params: %s", sampling_params)
request_id = _random_uuid()
prompt = chat_completion_request_to_prompt(request, self.formatter)
prompt = chat_completion_request_to_prompt(request, self.config.model, self.formatter)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
@@ -216,7 +235,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
stream, self.formatter
):
yield chunk
async def embeddings(
self, model_id: str, contents: list[InterleavedTextMedia]
) -> EmbeddingsResponse: