mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 04:50:39 +00:00
Change line width and remove __del__
This commit is contained in:
parent
74c8504f50
commit
24cc7a777c
1 changed files with 90 additions and 112 deletions
|
@ -11,8 +11,10 @@ import re
|
|||
import uuid
|
||||
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names,
|
||||
# so we import fully-qualified names
|
||||
import llama_models.sku_list
|
||||
|
||||
# These vLLM modules contain names that overlap with Llama Stack names, so we import
|
||||
# fully-qualified names
|
||||
import vllm.entrypoints.openai.protocol
|
||||
import vllm.sampling_params
|
||||
|
||||
|
@ -86,8 +88,8 @@ from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_
|
|||
# Constants go here
|
||||
|
||||
# Map from Hugging Face model architecture name to appropriate tool parser.
|
||||
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers
|
||||
# for the full list of available parsers.
|
||||
# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
|
||||
# available parsers.
|
||||
# TODO: Expand this list
|
||||
CONFIG_TYPE_TO_TOOL_PARSER = {
|
||||
"GraniteConfig": "granite",
|
||||
|
@ -135,33 +137,28 @@ def _response_format_to_guided_decoding_params(
|
|||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
) -> vllm.sampling_params.GuidedDecodingParams:
|
||||
"""
|
||||
Translate constrained decoding parameters from Llama Stack's
|
||||
format to vLLM's format.
|
||||
Translate constrained decoding parameters from Llama Stack's format to vLLM's format.
|
||||
|
||||
:param response_format: Llama Stack version of constrained decoding
|
||||
info. Can be ``None``, indicating no constraints.
|
||||
:returns: The equivalent dataclass object for the low-level inference
|
||||
layer of vLLM.
|
||||
:param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
|
||||
indicating no constraints.
|
||||
:returns: The equivalent dataclass object for the low-level inference layer of vLLM.
|
||||
"""
|
||||
if response_format is None:
|
||||
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
|
||||
# returns an invalid value that crashes the executor on some code
|
||||
# paths. Use ``None`` instead.
|
||||
# As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
|
||||
# value that crashes the executor on some code paths. Use ``None`` instead.
|
||||
return None
|
||||
|
||||
# Llama Stack currently implements fewer types of constrained
|
||||
# decoding than vLLM does. Translate the types that exist and
|
||||
# detect if Llama Stack adds new ones.
|
||||
# Llama Stack currently implements fewer types of constrained decoding than vLLM does.
|
||||
# Translate the types that exist and detect if Llama Stack adds new ones.
|
||||
if isinstance(response_format, JsonSchemaResponseFormat):
|
||||
return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
|
||||
elif isinstance(response_format, GrammarResponseFormat):
|
||||
# BNF grammar.
|
||||
# Llama Stack uses the parse tree of the grammar, while vLLM
|
||||
# uses the string representation of the grammar.
|
||||
# Llama Stack uses the parse tree of the grammar, while vLLM uses the string
|
||||
# representation of the grammar.
|
||||
raise TypeError(
|
||||
"Constrained decoding with BNF grammars is not "
|
||||
"currently implemented, because the reference "
|
||||
"implementation does not implement it."
|
||||
"Constrained decoding with BNF grammars is not currently implemented, because the "
|
||||
"reference implementation does not implement it."
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
|
||||
|
@ -172,11 +169,10 @@ def _convert_sampling_params(
|
|||
response_format: Optional[ResponseFormat], # type: ignore
|
||||
log_prob_config: Optional[LogProbConfig],
|
||||
) -> vllm.SamplingParams:
|
||||
"""Convert sampling and constrained decoding configuration from
|
||||
Llama Stack's format to vLLM's format."""
|
||||
# In the absence of provided config values, use Llama Stack defaults
|
||||
# a encoded in the Llama Stack dataclasses. These defaults are
|
||||
# different from vLLM's defaults.
|
||||
"""Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
|
||||
format."""
|
||||
# In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
|
||||
# Stack dataclasses. These defaults are different from vLLM's defaults.
|
||||
if sampling_params is None:
|
||||
sampling_params = SamplingParams()
|
||||
if log_prob_config is None:
|
||||
|
@ -220,11 +216,9 @@ def _convert_sampling_params(
|
|||
|
||||
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||
"""
|
||||
vLLM-based inference model adapter for Llama Stack with support for multiple
|
||||
models.
|
||||
vLLM-based inference model adapter for Llama Stack with support for multiple models.
|
||||
|
||||
Requires the configuration parameters documented in the
|
||||
:class:`VllmConfig2` class.
|
||||
Requires the configuration parameters documented in the :class:`VllmConfig2` class.
|
||||
"""
|
||||
|
||||
config: VLLMConfig
|
||||
|
@ -249,37 +243,28 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
self.engine = None
|
||||
self.chat = None
|
||||
|
||||
def __del__(self):
|
||||
self._shutdown()
|
||||
|
||||
###########################################################################
|
||||
# METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS
|
||||
# METHODS INHERITED FROM IMPLICIT BASE CLASS.
|
||||
# TODO: Make this class inherit from the new base class ProviderBase once that class exists.
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""
|
||||
Callback that is invoked through many levels of indirection during
|
||||
provider class instantiation, sometime after when __init__() is called
|
||||
and before any model registration methods or methods connected to a
|
||||
REST API are called.
|
||||
Callback that is invoked through many levels of indirection during provider class
|
||||
instantiation, sometime after when __init__() is called and before any model registration
|
||||
methods or methods connected to a REST API are called.
|
||||
|
||||
It's not clear what assumptions the class can make about the platform's
|
||||
initialization state here that can't be made during __init__(), and
|
||||
vLLM can't be started until we know what model it's supposed to be
|
||||
serving, so nothing happens here currently.
|
||||
It's not clear what assumptions the class can make about the platform's initialization
|
||||
state here that can't be made during __init__(), and vLLM can't be started until we know
|
||||
what model it's supposed to be serving, so nothing happens here currently.
|
||||
"""
|
||||
pass
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
"""
|
||||
Callback that apparently is invoked when shutting down the Llama
|
||||
Stack server. Not sure how to shut down a Llama Stack server in such
|
||||
a way as to trigger this callback.
|
||||
Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
|
||||
to shut down a Llama Stack server in such a way as to trigger this callback.
|
||||
"""
|
||||
_info("Shutting down inline vLLM inference provider.")
|
||||
self._shutdown()
|
||||
|
||||
def _shutdown(self) -> None:
|
||||
"""Internal non-async version of self.shutdown(). Idempotent."""
|
||||
_info(f"Shutting down inline vLLM inference provider {self}.")
|
||||
if self.engine is not None:
|
||||
self.engine.shutdown_background_loop()
|
||||
self.engine = None
|
||||
|
@ -293,14 +278,14 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
# Note that the return type of the superclass method is WRONG
|
||||
async def register_model(self, model: Model) -> Model:
|
||||
"""
|
||||
Callback that is called when the server associates an inference endpoint
|
||||
with an inference provider.
|
||||
Callback that is called when the server associates an inference endpoint with an
|
||||
inference provider.
|
||||
|
||||
:param model: Object that encapsulates parameters necessary for identifying
|
||||
a specific LLM.
|
||||
:param model: Object that encapsulates parameters necessary for identifying a specific
|
||||
LLM.
|
||||
|
||||
:returns: The input ``Model`` object. It may or may not be permissible
|
||||
to change fields before returning this object.
|
||||
:returns: The input ``Model`` object. It may or may not be permissible to change fields
|
||||
before returning this object.
|
||||
"""
|
||||
_debug(f"In register_model({model})")
|
||||
|
||||
|
@ -318,14 +303,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
if self.resolved_model_id is not None:
|
||||
if resolved_model_id != self.resolved_model_id:
|
||||
raise ValueError(
|
||||
f"Attempted to serve two LLMs (ids "
|
||||
f"'{self.resolved_model_id}') and "
|
||||
f"'{resolved_model_id}') from one copy of "
|
||||
f"provider '{self}'. Use multiple "
|
||||
f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
|
||||
f"'{resolved_model_id}') from one copy of provider '{self}'. Use multiple "
|
||||
f"copies of the provider instead."
|
||||
)
|
||||
else:
|
||||
# Model already loaded
|
||||
self.model_ids.add(model.model_id)
|
||||
return model
|
||||
|
||||
_info(f"Preloading model: {resolved_model_id}")
|
||||
|
@ -386,24 +370,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
async def unregister_model(self, model_id: str) -> None:
|
||||
"""
|
||||
Callback that is called when the server removes an inference endpoint
|
||||
from an inference provider.
|
||||
Callback that is called when the server removes an inference endpoint from an inference
|
||||
provider.
|
||||
|
||||
:param model_id: The same external ID that the higher layers of the
|
||||
stack previously passed to :func:`register_model()`
|
||||
:param model_id: The same external ID that the higher layers of the stack previously passed
|
||||
to :func:`register_model()`
|
||||
"""
|
||||
if model_id not in self.model_ids:
|
||||
raise ValueError(
|
||||
f"Attempted to unregister model ID '{model_id}', "
|
||||
f"but that ID is not registered to this provider."
|
||||
f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
|
||||
)
|
||||
self.model_ids.remove(model_id)
|
||||
|
||||
if len(self.model_ids) == 0:
|
||||
# Last model was just unregistered. Shut down the connection
|
||||
# to vLLM and free up resources.
|
||||
# Note that this operation may cause in-flight chat completion
|
||||
# requests on the now-unregistered model to return errors.
|
||||
# Last model was just unregistered. Shut down the connection to vLLM and free up
|
||||
# resources.
|
||||
# Note that this operation may cause in-flight chat completion requests on the
|
||||
# now-unregistered model to return errors.
|
||||
self.resolved_model_id = None
|
||||
self.chat = None
|
||||
self.engine.shutdown_background_loop()
|
||||
|
@ -449,21 +432,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
async def _streaming_completion(
|
||||
self, content: str, sampling_params: vllm.SamplingParams
|
||||
) -> AsyncIterator[CompletionResponseStreamChunk]:
|
||||
"""Internal implementation of :func:`completion()` API for the streaming
|
||||
case. Assumes that arguments have been validated upstream.
|
||||
"""Internal implementation of :func:`completion()` API for the streaming case. Assumes
|
||||
that arguments have been validated upstream.
|
||||
|
||||
:param content: Must be a string
|
||||
:param sampling_params: Paramters from public API's ``response_format``
|
||||
and ``sampling_params`` arguments, converted to VLLM format
|
||||
"""
|
||||
# We run agains the vLLM generate() call directly instead of using the
|
||||
# OpenAI-compatible layer, because doing so simplifies the code here.
|
||||
# We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
|
||||
# layer, because doing so simplifies the code here.
|
||||
|
||||
# The vLLM engine requires a unique identifier for each call to generate()
|
||||
request_id = _random_uuid_str()
|
||||
|
||||
# The vLLM generate() API is streaming-only and returns an async generator.
|
||||
# The generator returns objects of type vllm.RequestOutput
|
||||
# The generator returns objects of type vllm.RequestOutput.
|
||||
results_generator = self.engine.generate(content, sampling_params, request_id)
|
||||
|
||||
# Need to know the model's EOS token ID for the conversion code below.
|
||||
|
@ -477,10 +460,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
# This case also should never happen
|
||||
raise ValueError("Inference produced empty result")
|
||||
|
||||
# If we get here, then request_output contains the final output of the
|
||||
# generate() call.
|
||||
# The result may include multiple alternate outputs, but Llama Stack APIs
|
||||
# only allow us to return one.
|
||||
# If we get here, then request_output contains the final output of the generate() call.
|
||||
# The result may include multiple alternate outputs, but Llama Stack APIs only allow
|
||||
# us to return one.
|
||||
output: vllm.CompletionOutput = request_output.outputs[0]
|
||||
completion_string = output.text
|
||||
|
||||
|
@ -490,8 +472,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
for logprob_dict in output.logprobs
|
||||
]
|
||||
|
||||
# The final output chunk should be labeled with the reason that the
|
||||
# overall generate() call completed.
|
||||
# The final output chunk should be labeled with the reason that the overall generate()
|
||||
# call completed.
|
||||
stop_reason_str = output.stop_reason
|
||||
if stop_reason_str is None:
|
||||
stop_reason = None # Still going
|
||||
|
@ -502,15 +484,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
else:
|
||||
raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
|
||||
|
||||
# vLLM's protocol outputs the stop token, then sets end of message
|
||||
# on the next step for some reason.
|
||||
# vLLM's protocol outputs the stop token, then sets end of message on the next step for
|
||||
# some reason.
|
||||
if request_output.outputs[-1].token_ids[-1] == eos_token_id:
|
||||
stop_reason = StopReason.end_of_message
|
||||
|
||||
yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)
|
||||
|
||||
# Llama Stack requires that the last chunk have a stop reason, but
|
||||
# vLLM doesn't always provide one if it runs out of tokens.
|
||||
# Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
|
||||
# provide one if it runs out of tokens.
|
||||
if stop_reason is None:
|
||||
yield CompletionResponseStreamChunk(
|
||||
delta=completion_string,
|
||||
|
@ -542,10 +524,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
|
||||
)
|
||||
|
||||
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest
|
||||
# dataclass.
|
||||
# Note that this dataclass has the same name as a similar dataclass in
|
||||
# Llama Stack.
|
||||
# Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
|
||||
# Note that this dataclass has the same name as a similar dataclass in Llama Stack.
|
||||
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
|
||||
ChatCompletionRequest(
|
||||
model=self.resolved_model_id,
|
||||
|
@ -573,7 +553,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
if not isinstance(vllm_result, AsyncGenerator):
|
||||
raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
|
||||
# vLLM client returns a stream of strings, which need to be parsed.
|
||||
# Stream comes in the form of an async generator
|
||||
# Stream comes in the form of an async generator.
|
||||
return self._convert_streaming_results(vllm_result)
|
||||
else:
|
||||
if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
|
||||
|
@ -587,17 +567,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
|
||||
) -> ChatCompletionResponse:
|
||||
"""
|
||||
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible
|
||||
API into an equivalent Llama Stack object.
|
||||
Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
|
||||
equivalent Llama Stack object.
|
||||
|
||||
The result from vLLM's non-streaming API is a dataclass with
|
||||
the same name as the Llama Stack ChatCompletionResponse dataclass,
|
||||
but with more and different field names. We ignore the fields that
|
||||
aren't currently present in the Llama Stack dataclass.
|
||||
The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
|
||||
Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
|
||||
the fields that aren't currently present in the Llama Stack dataclass.
|
||||
"""
|
||||
|
||||
# There may be multiple responses, but we can only pass through the
|
||||
# first one.
|
||||
# There may be multiple responses, but we can only pass through the first one.
|
||||
if len(vllm_result.choices) == 0:
|
||||
raise ValueError("Don't know how to convert response object without any responses")
|
||||
vllm_message = vllm_result.choices[0].message
|
||||
|
@ -634,13 +612,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
:param vllm_result: Stream of strings that need to be parsed
|
||||
"""
|
||||
# Tool calls come in pieces, but Llama Stack expects them in bigger
|
||||
# chunks. We build up those chunks and output them at the end.
|
||||
# Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
|
||||
# those chunks and output them at the end.
|
||||
# This data structure holds the current set of partial tool calls.
|
||||
index_to_tool_call: Dict[int, Dict] = dict()
|
||||
|
||||
# The Llama Stack event stream must always start with a start event.
|
||||
# Use an empty one to simplify logic below
|
||||
# The Llama Stack event stream must always start with a start event. Use an empty one to
|
||||
# simplify logic below
|
||||
yield ChatCompletionResponseStreamChunk(
|
||||
event=ChatCompletionResponseEvent(
|
||||
event_type=ChatCompletionResponseEventType.start,
|
||||
|
@ -651,8 +629,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
converted_stop_reason = None
|
||||
async for chunk_str in vllm_result:
|
||||
# Due to OpenAI compatibility, each event in the stream
|
||||
# will start with "data: " and end with "\n\n".
|
||||
# Due to OpenAI compatibility, each event in the stream will start with "data: " and
|
||||
# end with "\n\n".
|
||||
_prefix = "data: "
|
||||
_suffix = "\n\n"
|
||||
if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
|
||||
|
@ -677,8 +655,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
_debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")
|
||||
|
||||
# The result may contain multiple completions, but Llama Stack APIs
|
||||
# only support returning one.
|
||||
# The result may contain multiple completions, but Llama Stack APIs only support
|
||||
# returning one.
|
||||
first_choice = parsed_chunk["choices"][0]
|
||||
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
|
||||
delta_record = first_choice["delta"]
|
||||
|
@ -693,8 +671,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
)
|
||||
elif "tool_calls" in delta_record:
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return
|
||||
# partial tool calls, so buffer until we get a "tool calls" stop reason
|
||||
# Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
|
||||
# calls, so buffer until we get a "tool calls" stop reason
|
||||
for tc in delta_record["tool_calls"]:
|
||||
index = tc["index"]
|
||||
if index not in index_to_tool_call:
|
||||
|
@ -716,8 +694,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
|
||||
if first_choice["finish_reason"] == "tool_calls":
|
||||
# Special OpenAI code for "tool calls complete".
|
||||
# Output the buffered tool calls. Llama Stack requires a separate
|
||||
# event per tool call.
|
||||
# Output the buffered tool calls. Llama Stack requires a separate event per tool
|
||||
# call.
|
||||
for tool_call_record in index_to_tool_call.values():
|
||||
# Arguments come in as a string. Parse the completed string.
|
||||
tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
|
||||
|
@ -731,6 +709,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
|||
)
|
||||
)
|
||||
|
||||
# If we get here, we've lost the connection with the vLLM event stream
|
||||
# before it ended normally.
|
||||
# If we get here, we've lost the connection with the vLLM event stream before it ended
|
||||
# normally.
|
||||
raise ValueError("vLLM event stream ended without [DONE] message.")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue