Change line width and remove __del__

Fred Reiss 2025-01-29 18:42:18 -08:00 committed by Ashwin Bharambe
parent 74c8504f50
commit 24cc7a777c


@@ -11,8 +11,10 @@ import re
import uuid
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
-# These vLLM modules contain names that overlap with Llama Stack names,
-# so we import fully-qualified names
import llama_models.sku_list
+# These vLLM modules contain names that overlap with Llama Stack names, so we import
+# fully-qualified names
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
@@ -86,8 +88,8 @@ from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_
# Constants go here
# Map from Hugging Face model architecture name to appropriate tool parser.
-# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers
-# for the full list of available parsers.
+# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
+# available parsers.
# TODO: Expand this list
CONFIG_TYPE_TO_TOOL_PARSER = {
"GraniteConfig": "granite",
@@ -135,33 +137,28 @@ def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat], # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
-Translate constrained decoding parameters from Llama Stack's
-format to vLLM's format.
+Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
-:param response_format: Llama Stack version of constrained decoding
-info. Can be ``None``, indicating no constraints.
-:returns: The equivalent dataclass object for the low-level inference
-layer of vLLM.
+:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
+indicating no constraints.
+:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
"""
if response_format is None:
-# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
-# returns an invalid value that crashes the executor on some code
-# paths. Use ``None`` instead.
+# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
+# value that crashes the executor on some code paths. Use ``None`` instead.
return None
-# Llama Stack currently implements fewer types of constrained
-# decoding than vLLM does. Translate the types that exist and
-# detect if Llama Stack adds new ones.
+# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
+# Translate the types that exist and detect if Llama Stack adds new ones.
if isinstance(response_format, JsonSchemaResponseFormat):
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
elif isinstance(response_format, GrammarResponseFormat):
# BNF grammar.
-# Llama Stack uses the parse tree of the grammar, while vLLM
-# uses the string representation of the grammar.
+# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
+# representation of the grammar.
raise TypeError(
"Constrained decoding with BNF grammars is not "
"currently implemented, because the reference "
"implementation does not implement it."
"Constrained decoding with BNF grammars is not currently implemented, because the "
"reference implementation does not implement it."
)
else:
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
@@ -172,11 +169,10 @@ def _convert_sampling_params(
response_format: Optional[ResponseFormat], # type: ignore
log_prob_config: Optional[LogProbConfig],
) -> vllm.SamplingParams:
"""Convert sampling and constrained decoding configuration from
Llama Stack's format to vLLM's format."""
# In the absence of provided config values, use Llama Stack defaults
# a encoded in the Llama Stack dataclasses. These defaults are
# different from vLLM's defaults.
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
format."""
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
# Stack dataclasses. These defaults are different from vLLM's defaults.
if sampling_params is None:
sampling_params = SamplingParams()
if log_prob_config is None:
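Filling in defaults before translating, as the comment above describes, might look roughly like the sketch below. The Llama Stack field names (`temperature`, `top_p`, `max_tokens`) are assumptions here, since the rest of the function body is not shown in this hunk:

```python
import vllm

def convert_sampling_params_sketch(sampling_params=None) -> vllm.SamplingParams:
    """Apply Llama-Stack-style defaults, then build the vLLM equivalent."""
    temperature = 0.0   # assumed Llama Stack defaults, not vLLM's
    top_p = 1.0
    max_tokens = 512
    if sampling_params is not None:
        temperature = getattr(sampling_params, "temperature", temperature)
        top_p = getattr(sampling_params, "top_p", top_p)
        max_tokens = getattr(sampling_params, "max_tokens", max_tokens)
    return vllm.SamplingParams(temperature=temperature, top_p=top_p, max_tokens=max_tokens)
```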
@@ -220,11 +216,9 @@ def _convert_sampling_params(
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
"""
-vLLM-based inference model adapter for Llama Stack with support for multiple
-models.
+vLLM-based inference model adapter for Llama Stack with support for multiple models.
-Requires the configuration parameters documented in the
-:class:`VllmConfig2` class.
+Requires the configuration parameters documented in the :class:`VllmConfig2` class.
"""
config: VLLMConfig
@@ -249,37 +243,28 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self.engine = None
self.chat = None
-def __del__(self):
-self._shutdown()
###########################################################################
-# METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS
+# METHODS INHERITED FROM IMPLICIT BASE CLASS.
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
async def initialize(self) -> None:
"""
-Callback that is invoked through many levels of indirection during
-provider class instantiation, sometime after when __init__() is called
-and before any model registration methods or methods connected to a
-REST API are called.
+Callback that is invoked through many levels of indirection during provider class
+instantiation, sometime after __init__() is called and before any model registration
+methods or methods connected to a REST API are called.
-It's not clear what assumptions the class can make about the platform's
-initialization state here that can't be made during __init__(), and
-vLLM can't be started until we know what model it's supposed to be
-serving, so nothing happens here currently.
+It's not clear what assumptions the class can make about the platform's initialization
+state here that can't be made during __init__(), and vLLM can't be started until we know
+what model it's supposed to be serving, so nothing happens here currently.
"""
pass
async def shutdown(self) -> None:
"""
-Callback that apparently is invoked when shutting down the Llama
-Stack server. Not sure how to shut down a Llama Stack server in such
-a way as to trigger this callback.
+Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
+to shut down a Llama Stack server in such a way as to trigger this callback.
"""
_info("Shutting down inline vLLM inference provider.")
self._shutdown()
def _shutdown(self) -> None:
"""Internal non-async version of self.shutdown(). Idempotent."""
_info(f"Shutting down inline vLLM inference provider {self}.")
if self.engine is not None:
self.engine.shutdown_background_loop()
self.engine = None
@@ -293,14 +278,14 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# Note that the return type of the superclass method is WRONG
async def register_model(self, model: Model) -> Model:
"""
-Callback that is called when the server associates an inference endpoint
-with an inference provider.
+Callback that is called when the server associates an inference endpoint with an
+inference provider.
-:param model: Object that encapsulates parameters necessary for identifying
-a specific LLM.
+:param model: Object that encapsulates parameters necessary for identifying a specific
+LLM.
-:returns: The input ``Model`` object. It may or may not be permissible
-to change fields before returning this object.
+:returns: The input ``Model`` object. It may or may not be permissible to change fields
+before returning this object.
"""
_debug(f"In register_model({model})")
@@ -318,14 +303,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if self.resolved_model_id is not None:
if resolved_model_id != self.resolved_model_id:
raise ValueError(
f"Attempted to serve two LLMs (ids "
f"'{self.resolved_model_id}') and "
f"'{resolved_model_id}') from one copy of "
f"provider '{self}'. Use multiple "
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
f"'{resolved_model_id}') from one copy of provider '{self}'. Use multiple "
f"copies of the provider instead."
)
else:
# Model already loaded
self.model_ids.add(model.model_id)
return model
_info(f"Preloading model: {resolved_model_id}")
@@ -386,24 +370,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def unregister_model(self, model_id: str) -> None:
"""
-Callback that is called when the server removes an inference endpoint
-from an inference provider.
+Callback that is called when the server removes an inference endpoint from an inference
+provider.
-:param model_id: The same external ID that the higher layers of the
-stack previously passed to :func:`register_model()`
+:param model_id: The same external ID that the higher layers of the stack previously passed
+to :func:`register_model()`
"""
if model_id not in self.model_ids:
raise ValueError(
f"Attempted to unregister model ID '{model_id}', "
f"but that ID is not registered to this provider."
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
)
self.model_ids.remove(model_id)
if len(self.model_ids) == 0:
-# Last model was just unregistered. Shut down the connection
-# to vLLM and free up resources.
-# Note that this operation may cause in-flight chat completion
-# requests on the now-unregistered model to return errors.
+# Last model was just unregistered. Shut down the connection to vLLM and free up
+# resources.
+# Note that this operation may cause in-flight chat completion requests on the
+# now-unregistered model to return errors.
self.resolved_model_id = None
self.chat = None
self.engine.shutdown_background_loop()
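The unregistration path above, in the same simplified style: removing the last external ID releases the engine, which is why in-flight requests can fail. `shutdown_background_loop()` is the `AsyncLLMEngine` call already used in the diff; everything else here is illustrative:

```python
def unregister_sketch(provider, model_id: str) -> None:
    """Simplified mirror of unregister_model(); 'provider' is duck-typed."""
    if model_id not in provider.model_ids:
        raise ValueError(f"Model ID '{model_id}' is not registered to this provider.")
    provider.model_ids.remove(model_id)
    if len(provider.model_ids) == 0:
        # Last model gone: drop references and stop the vLLM background loop.
        provider.resolved_model_id = None
        provider.chat = None
        if provider.engine is not None:
            provider.engine.shutdown_background_loop()
            provider.engine = None
```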
@@ -449,21 +432,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
async def _streaming_completion(
self, content: str, sampling_params: vllm.SamplingParams
) -> AsyncIterator[CompletionResponseStreamChunk]:
"""Internal implementation of :func:`completion()` API for the streaming
case. Assumes that arguments have been validated upstream.
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
that arguments have been validated upstream.
:param content: Must be a string
:param sampling_params: Parameters from public API's ``response_format``
and ``sampling_params`` arguments, converted to VLLM format
"""
-# We run agains the vLLM generate() call directly instead of using the
-# OpenAI-compatible layer, because doing so simplifies the code here.
+# We run against the vLLM generate() call directly instead of using the OpenAI-compatible
+# layer, because doing so simplifies the code here.
# The vLLM engine requires a unique identifier for each call to generate()
request_id = _random_uuid_str()
# The vLLM generate() API is streaming-only and returns an async generator.
-# The generator returns objects of type vllm.RequestOutput
+# The generator returns objects of type vllm.RequestOutput.
results_generator = self.engine.generate(content, sampling_params, request_id)
# Need to know the model's EOS token ID for the conversion code below.
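Consuming that generator typically looks like the sketch below: `AsyncLLMEngine.generate()` takes a prompt, `SamplingParams`, and a unique request ID, and yields `vllm.RequestOutput` objects whose `outputs[0].text` is the cumulative completion so far. This is a generic vLLM usage sketch, not the provider's exact code:

```python
import uuid

import vllm

async def stream_completion_sketch(engine: "vllm.AsyncLLMEngine", prompt: str,
                                   params: vllm.SamplingParams):
    """Yield the cumulative completion text after each engine step."""
    request_id = str(uuid.uuid4())  # generate() requires a unique ID per call
    async for request_output in engine.generate(prompt, params, request_id):
        # Each RequestOutput may hold several alternate completions; take the first.
        yield request_output.outputs[0].text
```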
@@ -477,10 +460,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# This case also should never happen
raise ValueError("Inference produced empty result")
-# If we get here, then request_output contains the final output of the
-# generate() call.
-# The result may include multiple alternate outputs, but Llama Stack APIs
-# only allow us to return one.
+# If we get here, then request_output contains the final output of the generate() call.
+# The result may include multiple alternate outputs, but Llama Stack APIs only allow
+# us to return one.
output: vllm.CompletionOutput = request_output.outputs[0]
completion_string = output.text
@@ -490,8 +472,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
for logprob_dict in output.logprobs
]
-# The final output chunk should be labeled with the reason that the
-# overall generate() call completed.
+# The final output chunk should be labeled with the reason that the overall generate()
+# call completed.
stop_reason_str = output.stop_reason
if stop_reason_str is None:
stop_reason = None # Still going
@@ -502,15 +484,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
else:
raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
-# vLLM's protocol outputs the stop token, then sets end of message
-# on the next step for some reason.
+# vLLM's protocol outputs the stop token, then sets end of message on the next step for
+# some reason.
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
stop_reason = StopReason.end_of_message
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
-# Llama Stack requires that the last chunk have a stop reason, but
-# vLLM doesn't always provide one if it runs out of tokens.
+# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
+# provide one if it runs out of tokens.
if stop_reason is None:
yield CompletionResponseStreamChunk(
delta=completion_string,
@@ -542,10 +524,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
)
-# Arguments to the vLLM call must be packaged as a ChatCompletionRequest
-# dataclass.
-# Note that this dataclass has the same name as a similar dataclass in
-# Llama Stack.
+# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
+# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
ChatCompletionRequest(
model=self.resolved_model_id,
@@ -573,7 +553,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if not isinstance(vllm_result, AsyncGenerator):
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
# vLLM client returns a stream of strings, which need to be parsed.
-# Stream comes in the form of an async generator
+# Stream comes in the form of an async generator.
return self._convert_streaming_results(vllm_result)
else:
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
@@ -587,17 +567,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
) -> ChatCompletionResponse:
"""
-Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible
-API into an equivalent Llama Stack object.
+Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
+equivalent Llama Stack object.
-The result from vLLM's non-streaming API is a dataclass with
-the same name as the Llama Stack ChatCompletionResponse dataclass,
-but with more and different field names. We ignore the fields that
-aren't currently present in the Llama Stack dataclass.
+The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
+Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
+the fields that aren't currently present in the Llama Stack dataclass.
"""
-# There may be multiple responses, but we can only pass through the
-# first one.
+# There may be multiple responses, but we can only pass through the first one.
if len(vllm_result.choices) == 0:
raise ValueError("Don't know how to convert response object without any responses")
vllm_message = vllm_result.choices[0].message
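The same first-choice extraction, shown at the plain-dict level to keep the sketch independent of the Llama Stack dataclasses; the key names follow the OpenAI chat completion format, and the function name is invented for illustration:

```python
def first_choice_sketch(openai_style_response: dict) -> dict:
    """Pull out the fields the adapter keeps from an OpenAI-format response."""
    choices = openai_style_response.get("choices", [])
    if not choices:
        raise ValueError("Don't know how to convert a response object without any responses")
    message = choices[0]["message"]
    return {
        "role": message.get("role", "assistant"),
        "content": message.get("content") or "",
        "finish_reason": choices[0].get("finish_reason"),
    }
```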
@@ -634,13 +612,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
:param vllm_result: Stream of strings that need to be parsed
"""
-# Tool calls come in pieces, but Llama Stack expects them in bigger
-# chunks. We build up those chunks and output them at the end.
+# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
+# those chunks and output them at the end.
# This data structure holds the current set of partial tool calls.
index_to_tool_call: Dict[int, Dict] = dict()
-# The Llama Stack event stream must always start with a start event.
-# Use an empty one to simplify logic below
+# The Llama Stack event stream must always start with a start event. Use an empty one to
+# simplify logic below.
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
@@ -651,8 +629,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
converted_stop_reason = None
async for chunk_str in vllm_result:
-# Due to OpenAI compatibility, each event in the stream
-# will start with "data: " and end with "\n\n".
+# Due to OpenAI compatibility, each event in the stream will start with "data: " and
+# end with "\n\n".
_prefix = "data: "
_suffix = "\n\n"
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
@@ -677,8 +655,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
_debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
-# The result may contain multiple completions, but Llama Stack APIs
-# only support returning one.
+# The result may contain multiple completions, but Llama Stack APIs only support
+# returning one.
first_choice = parsed_chunk["choices"][0]
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
delta_record = first_choice["delta"]
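The chunk handling above follows the standard OpenAI server-sent-events convention: each event is a line beginning with "data: " and ending with a blank line, and the stream is closed by a literal "[DONE]" sentinel. A self-contained sketch of just that parsing step (the function name is invented for illustration):

```python
import json
from typing import AsyncIterator

async def parse_openai_sse_sketch(chunks: AsyncIterator[str]) -> AsyncIterator[dict]:
    """Turn 'data: {...}\\n\\n' strings into parsed JSON events, stopping at [DONE]."""
    prefix, suffix = "data: ", "\n\n"
    async for chunk_str in chunks:
        if not chunk_str.startswith(prefix) or not chunk_str.endswith(suffix):
            raise ValueError(f"Can't parse result chunk: {chunk_str!r}")
        payload = chunk_str[len(prefix):-len(suffix)]
        if payload == "[DONE]":
            return  # normal end of stream
        yield json.loads(payload)
```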
@@ -693,8 +671,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
)
elif "tool_calls" in delta_record:
-# Tool call(s). Llama Stack APIs do not have a clear way to return
-# partial tool calls, so buffer until we get a "tool calls" stop reason
+# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
+# calls, so buffer until we get a "tool calls" stop reason.
for tc in delta_record["tool_calls"]:
index = tc["index"]
if index not in index_to_tool_call:
@@ -716,8 +694,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if first_choice["finish_reason"] == "tool_calls":
# Special OpenAI code for "tool calls complete".
-# Output the buffered tool calls. Llama Stack requires a separate
-# event per tool call.
+# Output the buffered tool calls. Llama Stack requires a separate event per tool
+# call.
for tool_call_record in index_to_tool_call.values():
# Arguments come in as a string. Parse the completed string.
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
@@ -731,6 +709,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
)
)
-# If we get here, we've lost the connection with the vLLM event stream
-# before it ended normally.
+# If we get here, we've lost the connection with the vLLM event stream before it ended
+# normally.
raise ValueError("vLLM event stream ended without [DONE] message.")