Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-03 17:29:01 +00:00
Fix regressions in inline vLLM provider

parent b74a05114f
commit 9d23c063d5

1 changed file with 27 additions and 8 deletions
@@ -50,7 +50,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self):
-        log.info("Initializing vLLM inference adapter")
+        log.info("Initializing vLLM inference provider.")

         # Disable usage stats reporting. This would be a surprising thing for most
         # people to find out was on by default.
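The comment in the context lines above refers to vLLM's usage-stats telemetry, which is on by default. The hunk does not show how this provider actually turns it off, so the following is only a minimal sketch, assuming vLLM's documented opt-out environment variables; the helper name is made up for illustration.

import os

def _disable_vllm_usage_stats() -> None:
    # vLLM checks these environment variables at engine start-up, so they must
    # be set before the engine is constructed. The variable names are vLLM's
    # own opt-out switches; this helper is illustrative, not the commit's code.
    os.environ.setdefault("VLLM_NO_USAGE_STATS", "1")
    os.environ.setdefault("DO_NOT_TRACK", "1")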
@@ -79,14 +79,33 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

     async def shutdown(self):
         """Shutdown the vLLM inference adapter."""
-        log.info("Shutting down vLLM inference adapter")
+        log.info("Shutting down vLLM inference provider.")
         if self.engine:
             self.engine.shutdown_background_loop()

-    async def register_model(self, model: Model) -> None:
-        raise ValueError(
-            "You cannot dynamically add a model to a running vllm instance"
-        )
+    # Note that the return type of the superclass method is WRONG
+    async def register_model(self, model: Model) -> Model:
+        """
+        Callback that is called when the server associates an inference endpoint
+        with an inference provider.
+
+        :param model: Object that encapsulates parameters necessary for identifying
+            a specific LLM.
+
+        :returns: The input ``Model`` object. It may or may not be permissible
+            to change fields before returning this object.
+        """
+        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+        # The current version of this provider is hard-coded to serve only
+        # the model specified in the YAML config file.
+        configured_model = resolve_model(self.config.model)
+        registered_model = resolve_model(model.model_id)
+
+        if configured_model.core_model_id != registered_model.core_model_id:
+            raise ValueError(f"Requested model '{model.identifier}' is different from "
+                             f"model '{self.config.model}' that this provider "
+                             f"is configured to serve")
+        return model

     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
         if sampling_params is None:
@@ -160,7 +179,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()

-        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        prompt = chat_completion_request_to_prompt(request, self.config.model, self.formatter)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
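results_generator in the final context lines is the async iterator returned by vLLM's AsyncLLMEngine.generate(). For readers unfamiliar with that API, here is a minimal sketch of how such a generator is typically drained; the helper name is illustrative, and the cumulative-output behavior described is vLLM's, not something added by this commit.

async def collect_final_text(results_generator) -> str:
    # AsyncLLMEngine.generate() yields RequestOutput objects whose text is
    # cumulative, so the last item yielded carries the complete generation.
    final_output = None
    async for request_output in results_generator:
        final_output = request_output
    return final_output.outputs[0].text if final_output else ""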