diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index d5bdc11da..93f5cb56b 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -11,8 +11,10 @@ import re
 import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

-# These vLLM modules contain names that overlap with Llama Stack names,
-# so we import fully-qualified names
+import llama_models.sku_list
+
+# These vLLM modules contain names that overlap with Llama Stack names, so we import
+# fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params

@@ -86,8 +88,8 @@ from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_
 # Constants go here

 # Map from Hugging Face model architecture name to appropriate tool parser.
-# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers
-# for the full list of available parsers.
+# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
+# available parsers.
 # TODO: Expand this list
 CONFIG_TYPE_TO_TOOL_PARSER = {
     "GraniteConfig": "granite",
@@ -135,33 +137,28 @@ def _response_format_to_guided_decoding_params(
     response_format: Optional[ResponseFormat],  # type: ignore
 ) -> vllm.sampling_params.GuidedDecodingParams:
     """
-    Translate constrained decoding parameters from Llama Stack's
-    format to vLLM's format.
+    Translate constrained decoding parameters from Llama Stack's format to vLLM's format.

-    :param response_format: Llama Stack version of constrained decoding
-        info. Can be ``None``, indicating no constraints.
-    :returns: The equivalent dataclass object for the low-level inference
-        layer of vLLM.
+    :param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
+        indicating no constraints.
+    :returns: The equivalent dataclass object for the low-level inference layer of vLLM.
     """
     if response_format is None:
-        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
-        # returns an invalid value that crashes the executor on some code
-        # paths. Use ``None`` instead.
+        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
+        # value that crashes the executor on some code paths. Use ``None`` instead.
        return None

-    # Llama Stack currently implements fewer types of constrained
-    # decoding than vLLM does. Translate the types that exist and
-    # detect if Llama Stack adds new ones.
+    # Llama Stack currently implements fewer types of constrained decoding than vLLM does.
+    # Translate the types that exist and detect if Llama Stack adds new ones.
     if isinstance(response_format, JsonSchemaResponseFormat):
         return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
     elif isinstance(response_format, GrammarResponseFormat):
         # BNF grammar.
-        # Llama Stack uses the parse tree of the grammar, while vLLM
-        # uses the string representation of the grammar.
+        # Llama Stack uses the parse tree of the grammar, while vLLM uses the string
+        # representation of the grammar.
         raise TypeError(
-            "Constrained decoding with BNF grammars is not "
-            "currently implemented, because the reference "
-            "implementation does not implement it."
+            "Constrained decoding with BNF grammars is not currently implemented, because the "
+            "reference implementation does not implement it."
         )
     else:
         raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
@@ -172,11 +169,10 @@ def _convert_sampling_params(
     response_format: Optional[ResponseFormat],  # type: ignore
     log_prob_config: Optional[LogProbConfig],
 ) -> vllm.SamplingParams:
-    """Convert sampling and constrained decoding configuration from
-    Llama Stack's format to vLLM's format."""
-    # In the absence of provided config values, use Llama Stack defaults
-    # a encoded in the Llama Stack dataclasses. These defaults are
-    # different from vLLM's defaults.
+    """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
+    format."""
+    # In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
+    # Stack dataclasses. These defaults are different from vLLM's defaults.
     if sampling_params is None:
         sampling_params = SamplingParams()
     if log_prob_config is None:
@@ -220,11 +216,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     """
-    vLLM-based inference model adapter for Llama Stack with support for multiple
-    models.
+    vLLM-based inference model adapter for Llama Stack with support for multiple models.

-    Requires the configuration parameters documented in the
-    :class:`VllmConfig2` class.
+    Requires the configuration parameters documented in the :class:`VllmConfig2` class.
     """

     config: VLLMConfig
@@ -249,37 +243,28 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.engine = None
         self.chat = None

-    def __del__(self):
-        self._shutdown()
-
     ###########################################################################
-    # METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS
+    # METHODS INHERITED FROM IMPLICIT BASE CLASS.
+    # TODO: Make this class inherit from the new base class ProviderBase once that class exists.

     async def initialize(self) -> None:
         """
-        Callback that is invoked through many levels of indirection during
-        provider class instantiation, sometime after when __init__() is called
-        and before any model registration methods or methods connected to a
-        REST API are called.
+        Callback that is invoked through many levels of indirection during provider class
+        instantiation, sometime after __init__() is called and before any model registration
+        methods or methods connected to a REST API are called.

-        It's not clear what assumptions the class can make about the platform's
-        initialization state here that can't be made during __init__(), and
-        vLLM can't be started until we know what model it's supposed to be
-        serving, so nothing happens here currently.
+        It's not clear what assumptions the class can make about the platform's initialization
+        state here that can't be made during __init__(), and vLLM can't be started until we know
+        what model it's supposed to be serving, so nothing happens here currently.
         """
         pass

     async def shutdown(self) -> None:
         """
-        Callback that apparently is invoked when shutting down the Llama
-        Stack server. Not sure how to shut down a Llama Stack server in such
-        a way as to trigger this callback.
+        Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
+        to shut down a Llama Stack server in such a way as to trigger this callback.
         """
-        _info("Shutting down inline vLLM inference provider.")
-        self._shutdown()
-
-    def _shutdown(self) -> None:
-        """Internal non-async version of self.shutdown(). Idempotent."""
+        _info(f"Shutting down inline vLLM inference provider {self}.")
         if self.engine is not None:
             self.engine.shutdown_background_loop()
             self.engine = None
@@ -293,14 +278,14 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     # Note that the return type of the superclass method is WRONG
     async def register_model(self, model: Model) -> Model:
         """
-        Callback that is called when the server associates an inference endpoint
-        with an inference provider.
+        Callback that is called when the server associates an inference endpoint with an
+        inference provider.

-        :param model: Object that encapsulates parameters necessary for identifying
-            a specific LLM.
+        :param model: Object that encapsulates parameters necessary for identifying a specific
+            LLM.

-        :returns: The input ``Model`` object. It may or may not be permissible
-            to change fields before returning this object.
+        :returns: The input ``Model`` object. It may or may not be permissible to change fields
+            before returning this object.
         """
         _debug(f"In register_model({model})")

@@ -318,14 +303,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if self.resolved_model_id is not None:
             if resolved_model_id != self.resolved_model_id:
                 raise ValueError(
-                    f"Attempted to serve two LLMs (ids "
-                    f"'{self.resolved_model_id}') and "
-                    f"'{resolved_model_id}') from one copy of "
-                    f"provider '{self}'. Use multiple "
+                    f"Attempted to serve two LLMs (ids '{self.resolved_model_id}' and "
+                    f"'{resolved_model_id}') from one copy of provider '{self}'. Use multiple "
                     f"copies of the provider instead."
                 )
             else:
                 # Model already loaded
+                self.model_ids.add(model.model_id)
                 return model

         _info(f"Preloading model: {resolved_model_id}")
@@ -386,24 +370,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     async def unregister_model(self, model_id: str) -> None:
         """
-        Callback that is called when the server removes an inference endpoint
-        from an inference provider.
+        Callback that is called when the server removes an inference endpoint from an inference
+        provider.

-        :param model_id: The same external ID that the higher layers of the
-            stack previously passed to :func:`register_model()`
+        :param model_id: The same external ID that the higher layers of the stack previously passed
+            to :func:`register_model()`
         """
         if model_id not in self.model_ids:
             raise ValueError(
-                f"Attempted to unregister model ID '{model_id}', "
-                f"but that ID is not registered to this provider."
+                f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
             )
         self.model_ids.remove(model_id)

         if len(self.model_ids) == 0:
-            # Last model was just unregistered. Shut down the connection
-            # to vLLM and free up resources.
-            # Note that this operation may cause in-flight chat completion
-            # requests on the now-unregistered model to return errors.
+            # Last model was just unregistered. Shut down the connection to vLLM and free up
+            # resources.
+            # Note that this operation may cause in-flight chat completion requests on the
+            # now-unregistered model to return errors.
             self.resolved_model_id = None
             self.chat = None
             self.engine.shutdown_background_loop()
@@ -449,21 +432,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     async def _streaming_completion(
         self, content: str, sampling_params: vllm.SamplingParams
     ) -> AsyncIterator[CompletionResponseStreamChunk]:
-        """Internal implementation of :func:`completion()` API for the streaming
-        case. Assumes that arguments have been validated upstream.
+ """Internal implementation of :func:`completion()` API for the streaming case. Assumes + that arguments have been validated upstream. :param content: Must be a string :param sampling_params: Paramters from public API's ``response_format`` and ``sampling_params`` arguments, converted to VLLM format """ - # We run agains the vLLM generate() call directly instead of using the - # OpenAI-compatible layer, because doing so simplifies the code here. + # We run agains the vLLM generate() call directly instead of using the OpenAI-compatible + # layer, because doing so simplifies the code here. # The vLLM engine requires a unique identifier for each call to generate() request_id = _random_uuid_str() # The vLLM generate() API is streaming-only and returns an async generator. - # The generator returns objects of type vllm.RequestOutput + # The generator returns objects of type vllm.RequestOutput. results_generator = self.engine.generate(content, sampling_params, request_id) # Need to know the model's EOS token ID for the conversion code below. @@ -477,10 +460,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): # This case also should never happen raise ValueError("Inference produced empty result") - # If we get here, then request_output contains the final output of the - # generate() call. - # The result may include multiple alternate outputs, but Llama Stack APIs - # only allow us to return one. + # If we get here, then request_output contains the final output of the generate() call. + # The result may include multiple alternate outputs, but Llama Stack APIs only allow + # us to return one. output: vllm.CompletionOutput = request_output.outputs[0] completion_string = output.text @@ -490,8 +472,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): for logprob_dict in output.logprobs ] - # The final output chunk should be labeled with the reason that the - # overall generate() call completed. + # The final output chunk should be labeled with the reason that the overall generate() + # call completed. stop_reason_str = output.stop_reason if stop_reason_str is None: stop_reason = None # Still going @@ -502,15 +484,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): else: raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'") - # vLLM's protocol outputs the stop token, then sets end of message - # on the next step for some reason. + # vLLM's protocol outputs the stop token, then sets end of message on the next step for + # some reason. if request_output.outputs[-1].token_ids[-1] == eos_token_id: stop_reason = StopReason.end_of_message yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs) - # Llama Stack requires that the last chunk have a stop reason, but - # vLLM doesn't always provide one if it runs out of tokens. + # Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always + # provide one if it runs out of tokens. if stop_reason is None: yield CompletionResponseStreamChunk( delta=completion_string, @@ -542,10 +524,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}" ) - # Arguments to the vLLM call must be packaged as a ChatCompletionRequest - # dataclass. - # Note that this dataclass has the same name as a similar dataclass in - # Llama Stack. + # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass. 
+        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
         request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
             ChatCompletionRequest(
                 model=self.resolved_model_id,
@@ -573,7 +553,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             if not isinstance(vllm_result, AsyncGenerator):
                 raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
             # vLLM client returns a stream of strings, which need to be parsed.
-            # Stream comes in the form of an async generator
+            # Stream comes in the form of an async generator.
             return self._convert_streaming_results(vllm_result)
         else:
             if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
@@ -587,17 +567,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
     ) -> ChatCompletionResponse:
         """
-        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible
-        API into an equivalent Llama Stack object.
+        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
+        equivalent Llama Stack object.

-        The result from vLLM's non-streaming API is a dataclass with
-        the same name as the Llama Stack ChatCompletionResponse dataclass,
-        but with more and different field names. We ignore the fields that
-        aren't currently present in the Llama Stack dataclass.
+        The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
+        Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
+        the fields that aren't currently present in the Llama Stack dataclass.
         """

-        # There may be multiple responses, but we can only pass through the
-        # first one.
+        # There may be multiple responses, but we can only pass through the first one.
         if len(vllm_result.choices) == 0:
             raise ValueError("Don't know how to convert response object without any responses")
         vllm_message = vllm_result.choices[0].message
@@ -634,13 +612,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         :param vllm_result: Stream of strings that need to be parsed
         """

-        # Tool calls come in pieces, but Llama Stack expects them in bigger
-        # chunks. We build up those chunks and output them at the end.
+        # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
+        # those chunks and output them at the end.
         # This data structure holds the current set of partial tool calls.
         index_to_tool_call: Dict[int, Dict] = dict()

-        # The Llama Stack event stream must always start with a start event.
-        # Use an empty one to simplify logic below
+        # The Llama Stack event stream must always start with a start event. Use an empty one to
+        # simplify logic below
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
@@ -651,8 +629,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         converted_stop_reason = None
         async for chunk_str in vllm_result:
-            # Due to OpenAI compatibility, each event in the stream
-            # will start with "data: " and end with "\n\n".
+            # Due to OpenAI compatibility, each event in the stream will start with "data: " and
+            # end with "\n\n".
_prefix = "data: " _suffix = "\n\n" if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix): @@ -677,8 +655,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): _debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}") - # The result may contain multiple completions, but Llama Stack APIs - # only support returning one. + # The result may contain multiple completions, but Llama Stack APIs only support + # returning one. first_choice = parsed_chunk["choices"][0] converted_stop_reason = get_stop_reason(first_choice["finish_reason"]) delta_record = first_choice["delta"] @@ -693,8 +671,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): ) ) elif "tool_calls" in delta_record: - # Tool call(s). Llama Stack APIs do not have a clear way to return - # partial tool calls, so buffer until we get a "tool calls" stop reason + # Tool call(s). Llama Stack APIs do not have a clear way to return partial tool + # calls, so buffer until we get a "tool calls" stop reason for tc in delta_record["tool_calls"]: index = tc["index"] if index not in index_to_tool_call: @@ -716,8 +694,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): if first_choice["finish_reason"] == "tool_calls": # Special OpenAI code for "tool calls complete". - # Output the buffered tool calls. Llama Stack requires a separate - # event per tool call. + # Output the buffered tool calls. Llama Stack requires a separate event per tool + # call. for tool_call_record in index_to_tool_call.values(): # Arguments come in as a string. Parse the completed string. tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"]) @@ -731,6 +709,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate): ) ) - # If we get here, we've lost the connection with the vLLM event stream - # before it ended normally. + # If we get here, we've lost the connection with the vLLM event stream before it ended + # normally. raise ValueError("vLLM event stream ended without [DONE] message.")