Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-02 08:44:44 +00:00
Fix regressions in inline vLLM provider

parent b74a05114f
commit 9d23c063d5

1 changed file with 27 additions and 8 deletions
@@ -50,7 +50,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.formatter = ChatFormat(Tokenizer.get_instance())

     async def initialize(self):
-        log.info("Initializing vLLM inference adapter")
+        log.info("Initializing vLLM inference provider.")

         # Disable usage stats reporting. This would be a surprising thing for most
         # people to find out was on by default.
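The hunk above only shows the comment about disabling usage-stats reporting, not the opt-out itself. For illustration, a minimal sketch of how such an opt-out is typically done before the engine is constructed, assuming vLLM's documented VLLM_NO_USAGE_STATS environment-variable switch (not code taken from this commit):

import os

# Sketch only: opt out of vLLM usage-stats reporting before building the engine,
# assuming vLLM's VLLM_NO_USAGE_STATS switch. The provider's actual opt-out code
# is not part of the hunk shown above.
os.environ.setdefault("VLLM_NO_USAGE_STATS", "1")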
@@ -79,14 +79,33 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

     async def shutdown(self):
         """Shutdown the vLLM inference adapter."""
-        log.info("Shutting down vLLM inference adapter")
+        log.info("Shutting down vLLM inference provider.")
         if self.engine:
             self.engine.shutdown_background_loop()

-    async def register_model(self, model: Model) -> None:
-        raise ValueError(
-            "You cannot dynamically add a model to a running vllm instance"
-        )
+    # Note that the return type of the superclass method is WRONG
+    async def register_model(self, model: Model) -> Model:
+        """
+        Callback that is called when the server associates an inference endpoint
+        with an inference provider.
+
+        :param model: Object that encapsulates parameters necessary for identifying
+            a specific LLM.
+
+        :returns: The input ``Model`` object. It may or may not be permissible
+            to change fields before returning this object.
+        """
+        log.info(f"Registering model {model.identifier} with vLLM inference provider.")
+        # The current version of this provider is hard-coded to serve only
+        # the model specified in the YAML config file.
+        configured_model = resolve_model(self.config.model)
+        registered_model = resolve_model(model.model_id)
+
+        if configured_model.core_model_id != registered_model.core_model_id:
+            raise ValueError(f"Requested model '{model.identifier}' is different from "
+                             f"model '{self.config.model}' that this provider "
+                             f"is configured to serve")
+        return model

     def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
         if sampling_params is None:
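The new register_model resolves both the configured and the requested descriptors and compares their core model IDs. Below is a standalone sketch of that check, assuming resolve_model from llama_models.sku_list (the helper the provider appears to use) and hypothetical descriptor strings; unlike the diff, the sketch also guards against descriptors that resolve_model does not recognize.

from llama_models.sku_list import resolve_model


def check_model_matches_config(configured: str, requested: str) -> None:
    # Sketch of the check added above, not the provider's exact code.
    configured_model = resolve_model(configured)
    registered_model = resolve_model(requested)
    if configured_model is None or registered_model is None:
        # resolve_model() returns None for unknown descriptors; the diff omits this guard.
        raise ValueError("Unknown model descriptor")
    if configured_model.core_model_id != registered_model.core_model_id:
        raise ValueError(
            f"Requested model '{requested}' is different from "
            f"model '{configured}' that this provider is configured to serve"
        )


# Hypothetical descriptors, for illustration only:
check_model_matches_config("Llama3.1-8B-Instruct", "Llama3.1-8B-Instruct")   # passes
check_model_matches_config("Llama3.1-8B-Instruct", "Llama3.1-70B-Instruct")  # raises ValueError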
@@ -160,7 +179,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         log.info("Sampling params: %s", sampling_params)
         request_id = _random_uuid()

-        prompt = chat_completion_request_to_prompt(request, self.formatter)
+        prompt = chat_completion_request_to_prompt(request, self.config.model, self.formatter)
         vllm_sampling_params = self._sampling_params(request.sampling_params)
         results_generator = self.engine.generate(
             prompt, vllm_sampling_params, request_id
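For context on the lines above, here is a rough sketch of how a vLLM AsyncLLMEngine results generator like the one returned by self.engine.generate(...) is typically drained; the engine and prompt arguments are assumed, and this is not the provider's streaming code.

import uuid

from vllm import SamplingParams as VLLMSamplingParams


async def collect_completion(engine, prompt: str) -> str:
    # Sketch: consume an AsyncLLMEngine results generator to get the final text.
    request_id = str(uuid.uuid4())  # stands in for _random_uuid() above
    params = VLLMSamplingParams(max_tokens=64)
    text = ""
    # engine.generate() yields RequestOutput objects as tokens are produced;
    # each carries the cumulative generation so far.
    async for request_output in engine.generate(prompt, params, request_id):
        text = request_output.outputs[0].text
    return text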
@@ -216,7 +235,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             stream, self.formatter
         ):
             yield chunk


     async def embeddings(
         self, model_id: str, contents: list[InterleavedTextMedia]
     ) -> EmbeddingsResponse: