Change line width and remove __del__

Fred Reiss 2025-01-29 18:42:18 -08:00 committed by Ashwin Bharambe
parent 74c8504f50
commit 24cc7a777c


@@ -11,8 +11,10 @@ import re
 import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 
-# These vLLM modules contain names that overlap with Llama Stack names,
-# so we import fully-qualified names
+import llama_models.sku_list
+
+# These vLLM modules contain names that overlap with Llama Stack names, so we import
+# fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
@@ -86,8 +88,8 @@ from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_
 # Constants go here
 
 # Map from Hugging Face model architecture name to appropriate tool parser.
-# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers
-# for the full list of available parsers.
+# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
+# available parsers.
 # TODO: Expand this list
 CONFIG_TYPE_TO_TOOL_PARSER = {
     "GraniteConfig": "granite",
@@ -135,33 +137,28 @@ def _response_format_to_guided_decoding_params(
     response_format: Optional[ResponseFormat],  # type: ignore
 ) -> vllm.sampling_params.GuidedDecodingParams:
     """
-    Translate constrained decoding parameters from Llama Stack's
-    format to vLLM's format.
+    Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
 
-    :param response_format: Llama Stack version of constrained decoding
-        info. Can be ``None``, indicating no constraints.
-    :returns: The equivalent dataclass object for the low-level inference
-        layer of vLLM.
+    :param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
+        indicating no constraints.
+    :returns: The equivalent dataclass object for the low-level inference layer of vLLM.
     """
     if response_format is None:
-        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
-        # returns an invalid value that crashes the executor on some code
-        # paths. Use ``None`` instead.
+        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
+        # value that crashes the executor on some code paths. Use ``None`` instead.
         return None
 
-    # Llama Stack currently implements fewer types of constrained
-    # decoding than vLLM does. Translate the types that exist and
-    # detect if Llama Stack adds new ones.
+    # Llama Stack currently implements fewer types of constrained decoding than vLLM does.
+    # Translate the types that exist and detect if Llama Stack adds new ones.
     if isinstance(response_format, JsonSchemaResponseFormat):
         return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
     elif isinstance(response_format, GrammarResponseFormat):
         # BNF grammar.
-        # Llama Stack uses the parse tree of the grammar, while vLLM
-        # uses the string representation of the grammar.
+        # Llama Stack uses the parse tree of the grammar, while vLLM uses the string
+        # representation of the grammar.
         raise TypeError(
-            "Constrained decoding with BNF grammars is not "
-            "currently implemented, because the reference "
-            "implementation does not implement it."
+            "Constrained decoding with BNF grammars is not currently implemented, because the "
+            "reference implementation does not implement it."
         )
     else:
         raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
@@ -172,11 +169,10 @@ def _convert_sampling_params(
     response_format: Optional[ResponseFormat],  # type: ignore
     log_prob_config: Optional[LogProbConfig],
 ) -> vllm.SamplingParams:
-    """Convert sampling and constrained decoding configuration from
-    Llama Stack's format to vLLM's format."""
-    # In the absence of provided config values, use Llama Stack defaults
-    # a encoded in the Llama Stack dataclasses. These defaults are
-    # different from vLLM's defaults.
+    """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
+    format."""
+    # In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
+    # Stack dataclasses. These defaults are different from vLLM's defaults.
     if sampling_params is None:
         sampling_params = SamplingParams()
     if log_prob_config is None:
@@ -220,11 +216,9 @@ def _convert_sampling_params(
 class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     """
-    vLLM-based inference model adapter for Llama Stack with support for multiple
-    models.
+    vLLM-based inference model adapter for Llama Stack with support for multiple models.
 
-    Requires the configuration parameters documented in the
-    :class:`VllmConfig2` class.
+    Requires the configuration parameters documented in the :class:`VllmConfig2` class.
     """
 
     config: VLLMConfig
@@ -249,37 +243,28 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.engine = None
         self.chat = None
 
-    def __del__(self):
-        self._shutdown()
-
     ###########################################################################
-    # METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS
+    # METHODS INHERITED FROM IMPLICIT BASE CLASS.
+    # TODO: Make this class inherit from the new base class ProviderBase once that class exists.
 
     async def initialize(self) -> None:
         """
-        Callback that is invoked through many levels of indirection during
-        provider class instantiation, sometime after when __init__() is called
-        and before any model registration methods or methods connected to a
-        REST API are called.
+        Callback that is invoked through many levels of indirection during provider class
+        instantiation, sometime after when __init__() is called and before any model registration
+        methods or methods connected to a REST API are called.
 
-        It's not clear what assumptions the class can make about the platform's
-        initialization state here that can't be made during __init__(), and
-        vLLM can't be started until we know what model it's supposed to be
-        serving, so nothing happens here currently.
+        It's not clear what assumptions the class can make about the platform's initialization
+        state here that can't be made during __init__(), and vLLM can't be started until we know
+        what model it's supposed to be serving, so nothing happens here currently.
        """
        pass
 
     async def shutdown(self) -> None:
         """
-        Callback that apparently is invoked when shutting down the Llama
-        Stack server. Not sure how to shut down a Llama Stack server in such
-        a way as to trigger this callback.
+        Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
+        to shut down a Llama Stack server in such a way as to trigger this callback.
         """
-        _info("Shutting down inline vLLM inference provider.")
-        self._shutdown()
-
-    def _shutdown(self) -> None:
-        """Internal non-async version of self.shutdown(). Idempotent."""
+        _info(f"Shutting down inline vLLM inference provider {self}.")
         if self.engine is not None:
             self.engine.shutdown_background_loop()
             self.engine = None
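This hunk is the "remove __del__" part of the commit: teardown now happens only through the explicit shutdown() callback, since __del__ runs at unpredictable times (possibly during interpreter teardown, or not at all) and cannot await anything. A minimal sketch of the calling side under that lifecycle model; the function name _serve_and_stop and the surrounding flow are assumptions, not part of this diff:

    async def _serve_and_stop(provider) -> None:
        # Explicit lifecycle management instead of relying on __del__: the platform
        # calls initialize() after construction and shutdown() when the server stops,
        # which frees the vLLM engine deterministically.
        await provider.initialize()
        try:
            ...  # register models, serve requests
        finally:
            await provider.shutdown()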
@@ -293,14 +278,14 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     # Note that the return type of the superclass method is WRONG
     async def register_model(self, model: Model) -> Model:
         """
-        Callback that is called when the server associates an inference endpoint
-        with an inference provider.
+        Callback that is called when the server associates an inference endpoint with an
+        inference provider.
 
-        :param model: Object that encapsulates parameters necessary for identifying
-            a specific LLM.
-        :returns: The input ``Model`` object. It may or may not be permissible
-            to change fields before returning this object.
+        :param model: Object that encapsulates parameters necessary for identifying a specific
+            LLM.
+        :returns: The input ``Model`` object. It may or may not be permissible to change fields
+            before returning this object.
         """
         _debug(f"In register_model({model})")
@@ -318,14 +303,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if self.resolved_model_id is not None:
             if resolved_model_id != self.resolved_model_id:
                 raise ValueError(
-                    f"Attempted to serve two LLMs (ids "
-                    f"'{self.resolved_model_id}') and "
-                    f"'{resolved_model_id}') from one copy of "
-                    f"provider '{self}'. Use multiple "
+                    f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
+                    f"'{resolved_model_id}') from one copy of provider '{self}'. Use multiple "
                     f"copies of the provider instead."
                 )
             else:
                 # Model already loaded
+                self.model_ids.add(model.model_id)
                 return model
 
         _info(f"Preloading model: {resolved_model_id}")
@@ -386,24 +370,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         """
-        Callback that is called when the server removes an inference endpoint
-        from an inference provider.
+        Callback that is called when the server removes an inference endpoint from an inference
+        provider.
 
-        :param model_id: The same external ID that the higher layers of the
-            stack previously passed to :func:`register_model()`
+        :param model_id: The same external ID that the higher layers of the stack previously passed
+            to :func:`register_model()`
         """
         if model_id not in self.model_ids:
             raise ValueError(
-                f"Attempted to unregister model ID '{model_id}', "
-                f"but that ID is not registered to this provider."
+                f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
             )
 
         self.model_ids.remove(model_id)
         if len(self.model_ids) == 0:
-            # Last model was just unregistered. Shut down the connection
-            # to vLLM and free up resources.
-            # Note that this operation may cause in-flight chat completion
-            # requests on the now-unregistered model to return errors.
+            # Last model was just unregistered. Shut down the connection to vLLM and free up
+            # resources.
+            # Note that this operation may cause in-flight chat completion requests on the
+            # now-unregistered model to return errors.
             self.resolved_model_id = None
             self.chat = None
             self.engine.shutdown_background_loop()
@@ -449,21 +432,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     async def _streaming_completion(
         self, content: str, sampling_params: vllm.SamplingParams
     ) -> AsyncIterator[CompletionResponseStreamChunk]:
-        """Internal implementation of :func:`completion()` API for the streaming
-        case. Assumes that arguments have been validated upstream.
+        """Internal implementation of :func:`completion()` API for the streaming case. Assumes
+        that arguments have been validated upstream.
 
         :param content: Must be a string
         :param sampling_params: Paramters from public API's ``response_format``
             and ``sampling_params`` arguments, converted to VLLM format
         """
-        # We run agains the vLLM generate() call directly instead of using the
-        # OpenAI-compatible layer, because doing so simplifies the code here.
+        # We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
+        # layer, because doing so simplifies the code here.
 
         # The vLLM engine requires a unique identifier for each call to generate()
         request_id = _random_uuid_str()
 
         # The vLLM generate() API is streaming-only and returns an async generator.
-        # The generator returns objects of type vllm.RequestOutput
+        # The generator returns objects of type vllm.RequestOutput.
         results_generator = self.engine.generate(content, sampling_params, request_id)
 
         # Need to know the model's EOS token ID for the conversion code below.
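For context, this is the usual consumption pattern for vLLM's AsyncLLMEngine.generate(), which the code above relies on: a streaming-only async generator of RequestOutput snapshots whose last item carries the final text. A standalone sketch; the engine is assumed to be constructed elsewhere, and _generate_once is a hypothetical helper, not code from this commit:

    import uuid

    import vllm
    from vllm.engine.async_llm_engine import AsyncLLMEngine

    async def _generate_once(engine: AsyncLLMEngine, prompt: str) -> str:
        sampling_params = vllm.SamplingParams(max_tokens=64)
        request_id = str(uuid.uuid4())  # each generate() call needs a unique ID
        final_output = None
        # generate() is streaming-only; each item is a vllm.RequestOutput snapshot.
        async for request_output in engine.generate(prompt, sampling_params, request_id):
            final_output = request_output
        if final_output is None:
            raise ValueError("Inference produced empty result")
        # Multiple alternate completions are possible; Llama Stack returns only the first.
        return final_output.outputs[0].text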
@@ -477,10 +460,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             # This case also should never happen
             raise ValueError("Inference produced empty result")
 
-        # If we get here, then request_output contains the final output of the
-        # generate() call.
-        # The result may include multiple alternate outputs, but Llama Stack APIs
-        # only allow us to return one.
+        # If we get here, then request_output contains the final output of the generate() call.
+        # The result may include multiple alternate outputs, but Llama Stack APIs only allow
+        # us to return one.
         output: vllm.CompletionOutput = request_output.outputs[0]
         completion_string = output.text
@@ -490,8 +472,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             for logprob_dict in output.logprobs
         ]
 
-        # The final output chunk should be labeled with the reason that the
-        # overall generate() call completed.
+        # The final output chunk should be labeled with the reason that the overall generate()
+        # call completed.
         stop_reason_str = output.stop_reason
         if stop_reason_str is None:
             stop_reason = None  # Still going
@@ -502,15 +484,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         else:
             raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
 
-        # vLLM's protocol outputs the stop token, then sets end of message
-        # on the next step for some reason.
+        # vLLM's protocol outputs the stop token, then sets end of message on the next step for
+        # some reason.
         if request_output.outputs[-1].token_ids[-1] == eos_token_id:
             stop_reason = StopReason.end_of_message
 
         yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
 
-        # Llama Stack requires that the last chunk have a stop reason, but
-        # vLLM doesn't always provide one if it runs out of tokens.
+        # Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
+        # provide one if it runs out of tokens.
         if stop_reason is None:
             yield CompletionResponseStreamChunk(
                 delta=completion_string,
@@ -542,10 +524,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
             )
 
-        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
-        # dataclass.
-        # Note that this dataclass has the same name as a similar dataclass in
-        # Llama Stack.
+        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
+        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
         request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
             ChatCompletionRequest(
                 model=self.resolved_model_id,
@@ -573,7 +553,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             if not isinstance(vllm_result, AsyncGenerator):
                 raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
             # vLLM client returns a stream of strings, which need to be parsed.
-            # Stream comes in the form of an async generator
+            # Stream comes in the form of an async generator.
             return self._convert_streaming_results(vllm_result)
         else:
             if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
@@ -587,17 +567,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
     ) -> ChatCompletionResponse:
         """
-        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible
-        API into an equivalent Llama Stack object.
+        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
+        equivalent Llama Stack object.
 
-        The result from vLLM's non-streaming API is a dataclass with
-        the same name as the Llama Stack ChatCompletionResponse dataclass,
-        but with more and different field names. We ignore the fields that
-        aren't currently present in the Llama Stack dataclass.
+        The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
+        Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
+        the fields that aren't currently present in the Llama Stack dataclass.
         """
-        # There may be multiple responses, but we can only pass through the
-        # first one.
+        # There may be multiple responses, but we can only pass through the first one.
         if len(vllm_result.choices) == 0:
             raise ValueError("Don't know how to convert response object without any responses")
         vllm_message = vllm_result.choices[0].message
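The conversion described above reduces to reading choices[0] and keeping only the fields Llama Stack models. A minimal sketch over a plain OpenAI-style response dict; the returned dict shape and helper name are illustrative, not the actual Llama Stack ChatCompletionResponse:

    def _first_choice_summary(openai_response: dict) -> dict:
        """Pick out the fields that survive the conversion from an OpenAI-style response."""
        choices = openai_response.get("choices", [])
        if len(choices) == 0:
            raise ValueError("Response object contains no choices")
        message = choices[0]["message"]
        return {
            "content": message.get("content"),
            "tool_calls": message.get("tool_calls", []),
            "finish_reason": choices[0].get("finish_reason"),
        }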
@@ -634,13 +612,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         :param vllm_result: Stream of strings that need to be parsed
         """
 
-        # Tool calls come in pieces, but Llama Stack expects them in bigger
-        # chunks. We build up those chunks and output them at the end.
+        # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
+        # those chunks and output them at the end.
         # This data structure holds the current set of partial tool calls.
         index_to_tool_call: Dict[int, Dict] = dict()
 
-        # The Llama Stack event stream must always start with a start event.
-        # Use an empty one to simplify logic below
+        # The Llama Stack event stream must always start with a start event. Use an empty one to
+        # simplify logic below
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
@@ -651,8 +629,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         converted_stop_reason = None
         async for chunk_str in vllm_result:
-            # Due to OpenAI compatibility, each event in the stream
-            # will start with "data: " and end with "\n\n".
+            # Due to OpenAI compatibility, each event in the stream will start with "data: " and
+            # end with "\n\n".
             _prefix = "data: "
             _suffix = "\n\n"
             if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
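Each chunk handled here is a Server-Sent Events record of the form "data: {json}\n\n", with a literal "data: [DONE]\n\n" sentinel at the end of the stream. A standalone sketch of that parsing step; _parse_sse_chunk is a hypothetical helper, not the code in this file:

    import json
    from typing import Optional

    def _parse_sse_chunk(chunk_str: str) -> Optional[dict]:
        """Parse one OpenAI-compatible SSE chunk; return None for the [DONE] sentinel."""
        _prefix = "data: "
        _suffix = "\n\n"
        if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
            raise ValueError(f"Cannot parse stream chunk: '{chunk_str}'")
        payload = chunk_str[len(_prefix) : -len(_suffix)]
        if payload == "[DONE]":
            return None
        return json.loads(payload)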
@@ -677,8 +655,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             _debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
 
-            # The result may contain multiple completions, but Llama Stack APIs
-            # only support returning one.
+            # The result may contain multiple completions, but Llama Stack APIs only support
+            # returning one.
             first_choice = parsed_chunk["choices"][0]
             converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
             delta_record = first_choice["delta"]
@@ -693,8 +671,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                     )
                 )
             elif "tool_calls" in delta_record:
-                # Tool call(s). Llama Stack APIs do not have a clear way to return
-                # partial tool calls, so buffer until we get a "tool calls" stop reason
+                # Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
+                # calls, so buffer until we get a "tool calls" stop reason
                 for tc in delta_record["tool_calls"]:
                     index = tc["index"]
                     if index not in index_to_tool_call:
@@ -716,8 +694,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             if first_choice["finish_reason"] == "tool_calls":
                 # Special OpenAI code for "tool calls complete".
-                # Output the buffered tool calls. Llama Stack requires a separate
-                # event per tool call.
+                # Output the buffered tool calls. Llama Stack requires a separate event per tool
+                # call.
                 for tool_call_record in index_to_tool_call.values():
                     # Arguments come in as a string. Parse the completed string.
                     tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
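Streamed OpenAI-style tool calls deliver the arguments string in fragments across many chunks, keyed by index; the buffering above concatenates those fragments and parses the JSON once the "tool_calls" finish reason arrives. A standalone sketch of that accumulation; the record layout and the helper name _merge_tool_call_deltas are assumptions, not the exact code here:

    import json
    from typing import Dict, List

    def _merge_tool_call_deltas(deltas: List[dict]) -> Dict[int, dict]:
        """Fold streamed OpenAI-style tool-call deltas into complete tool calls."""
        index_to_tool_call: Dict[int, dict] = {}
        for tc in deltas:
            record = index_to_tool_call.setdefault(tc["index"], {"arguments_str": ""})
            if "id" in tc:
                record["call_id"] = tc["id"]
            func = tc.get("function", {})
            if "name" in func:
                record["name"] = func["name"]
            # Argument text arrives as partial JSON fragments; concatenate them.
            record["arguments_str"] += func.get("arguments", "") or ""
        # Once the stream reports finish_reason == "tool_calls", the buffered
        # strings form complete JSON documents and can be parsed.
        for record in index_to_tool_call.values():
            record["arguments"] = json.loads(record["arguments_str"])
        return index_to_tool_call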
@@ -731,6 +709,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 )
             )
 
-        # If we get here, we've lost the connection with the vLLM event stream
-        # before it ended normally.
+        # If we get here, we've lost the connection with the vLLM event stream before it ended
+        # normally.
         raise ValueError("vLLM event stream ended without [DONE] message.")