Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-12 04:50:39 +00:00)

Commit 24cc7a777c: "Change line width and remove __del__"
Parent: 74c8504f50

1 changed file with 90 additions and 112 deletions
@@ -11,8 +11,10 @@ import re
 import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

-# These vLLM modules contain names that overlap with Llama Stack names,
-# so we import fully-qualified names
+import llama_models.sku_list
+
+# These vLLM modules contain names that overlap with Llama Stack names, so we import
+# fully-qualified names
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params

@@ -86,8 +88,8 @@ from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
 # Constants go here

 # Map from Hugging Face model architecture name to appropriate tool parser.
-# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers
-# for the full list of available parsers.
+# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of
+# available parsers.
 # TODO: Expand this list
 CONFIG_TYPE_TO_TOOL_PARSER = {
     "GraniteConfig": "granite",
@@ -135,33 +137,28 @@ def _response_format_to_guided_decoding_params(
     response_format: Optional[ResponseFormat],  # type: ignore
 ) -> vllm.sampling_params.GuidedDecodingParams:
     """
-    Translate constrained decoding parameters from Llama Stack's
-    format to vLLM's format.
+    Translate constrained decoding parameters from Llama Stack's format to vLLM's format.

-    :param response_format: Llama Stack version of constrained decoding
-        info. Can be ``None``, indicating no constraints.
-    :returns: The equivalent dataclass object for the low-level inference
-        layer of vLLM.
+    :param response_format: Llama Stack version of constrained decoding info. Can be ``None``,
+        indicating no constraints.
+    :returns: The equivalent dataclass object for the low-level inference layer of vLLM.
     """
     if response_format is None:
-        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
-        # returns an invalid value that crashes the executor on some code
-        # paths. Use ``None`` instead.
+        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid
+        # value that crashes the executor on some code paths. Use ``None`` instead.
         return None

-    # Llama Stack currently implements fewer types of constrained
-    # decoding than vLLM does. Translate the types that exist and
-    # detect if Llama Stack adds new ones.
+    # Llama Stack currently implements fewer types of constrained decoding than vLLM does.
+    # Translate the types that exist and detect if Llama Stack adds new ones.
     if isinstance(response_format, JsonSchemaResponseFormat):
         return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)
     elif isinstance(response_format, GrammarResponseFormat):
         # BNF grammar.
-        # Llama Stack uses the parse tree of the grammar, while vLLM
-        # uses the string representation of the grammar.
+        # Llama Stack uses the parse tree of the grammar, while vLLM uses the string
+        # representation of the grammar.
         raise TypeError(
-            "Constrained decoding with BNF grammars is not "
-            "currently implemented, because the reference "
-            "implementation does not implement it."
+            "Constrained decoding with BNF grammars is not currently implemented, because the "
+            "reference implementation does not implement it."
         )
     else:
         raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'")
@@ -172,11 +169,10 @@ def _convert_sampling_params(
     response_format: Optional[ResponseFormat],  # type: ignore
     log_prob_config: Optional[LogProbConfig],
 ) -> vllm.SamplingParams:
-    """Convert sampling and constrained decoding configuration from
-    Llama Stack's format to vLLM's format."""
-    # In the absence of provided config values, use Llama Stack defaults
-    # a encoded in the Llama Stack dataclasses. These defaults are
-    # different from vLLM's defaults.
+    """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's
+    format."""
+    # In the absence of provided config values, use Llama Stack defaults as encoded in the Llama
+    # Stack dataclasses. These defaults are different from vLLM's defaults.
     if sampling_params is None:
         sampling_params = SamplingParams()
     if log_prob_config is None:
@@ -220,11 +216,9 @@ def _convert_sampling_params(

 class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     """
-    vLLM-based inference model adapter for Llama Stack with support for multiple
-    models.
+    vLLM-based inference model adapter for Llama Stack with support for multiple models.

-    Requires the configuration parameters documented in the
-    :class:`VllmConfig2` class.
+    Requires the configuration parameters documented in the :class:`VllmConfig2` class.
     """

     config: VLLMConfig
@@ -249,37 +243,28 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self.engine = None
         self.chat = None

-    def __del__(self):
-        self._shutdown()
-
     ###########################################################################
-    # METHODS INHERITED FROM UNDOCUMENTED IMPLICIT MYSTERY BASE CLASS
+    # METHODS INHERITED FROM IMPLICIT BASE CLASS.
+    # TODO: Make this class inherit from the new base class ProviderBase once that class exists.

     async def initialize(self) -> None:
         """
-        Callback that is invoked through many levels of indirection during
-        provider class instantiation, sometime after when __init__() is called
-        and before any model registration methods or methods connected to a
-        REST API are called.
+        Callback that is invoked through many levels of indirection during provider class
+        instantiation, sometime after when __init__() is called and before any model registration
+        methods or methods connected to a REST API are called.

-        It's not clear what assumptions the class can make about the platform's
-        initialization state here that can't be made during __init__(), and
-        vLLM can't be started until we know what model it's supposed to be
-        serving, so nothing happens here currently.
+        It's not clear what assumptions the class can make about the platform's initialization
+        state here that can't be made during __init__(), and vLLM can't be started until we know
+        what model it's supposed to be serving, so nothing happens here currently.
         """
         pass

     async def shutdown(self) -> None:
         """
-        Callback that apparently is invoked when shutting down the Llama
-        Stack server. Not sure how to shut down a Llama Stack server in such
-        a way as to trigger this callback.
+        Callback that apparently is invoked when shutting down the Llama Stack server. Not sure how
+        to shut down a Llama Stack server in such a way as to trigger this callback.
         """
-        _info("Shutting down inline vLLM inference provider.")
-        self._shutdown()
-
-    def _shutdown(self) -> None:
-        """Internal non-async version of self.shutdown(). Idempotent."""
+        _info(f"Shutting down inline vLLM inference provider {self}.")
         if self.engine is not None:
             self.engine.shutdown_background_loop()
             self.engine = None
@@ -293,14 +278,14 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     # Note that the return type of the superclass method is WRONG
     async def register_model(self, model: Model) -> Model:
         """
-        Callback that is called when the server associates an inference endpoint
-        with an inference provider.
+        Callback that is called when the server associates an inference endpoint with an
+        inference provider.

-        :param model: Object that encapsulates parameters necessary for identifying
-            a specific LLM.
+        :param model: Object that encapsulates parameters necessary for identifying a specific
+            LLM.

-        :returns: The input ``Model`` object. It may or may not be permissible
-            to change fields before returning this object.
+        :returns: The input ``Model`` object. It may or may not be permissible to change fields
+            before returning this object.
         """
         _debug(f"In register_model({model})")

@@ -318,14 +303,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if self.resolved_model_id is not None:
             if resolved_model_id != self.resolved_model_id:
                 raise ValueError(
-                    f"Attempted to serve two LLMs (ids "
-                    f"'{self.resolved_model_id}') and "
-                    f"'{resolved_model_id}') from one copy of "
-                    f"provider '{self}'. Use multiple "
+                    f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and "
+                    f"'{resolved_model_id}') from one copy of provider '{self}'. Use multiple "
                     f"copies of the provider instead."
                 )
             else:
                 # Model already loaded
+                self.model_ids.add(model.model_id)
                 return model

         _info(f"Preloading model: {resolved_model_id}")
@@ -386,24 +370,23 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

     async def unregister_model(self, model_id: str) -> None:
         """
-        Callback that is called when the server removes an inference endpoint
-        from an inference provider.
+        Callback that is called when the server removes an inference endpoint from an inference
+        provider.

-        :param model_id: The same external ID that the higher layers of the
-            stack previously passed to :func:`register_model()`
+        :param model_id: The same external ID that the higher layers of the stack previously passed
+            to :func:`register_model()`
         """
         if model_id not in self.model_ids:
             raise ValueError(
-                f"Attempted to unregister model ID '{model_id}', "
-                f"but that ID is not registered to this provider."
+                f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider."
             )
         self.model_ids.remove(model_id)

         if len(self.model_ids) == 0:
-            # Last model was just unregistered. Shut down the connection
-            # to vLLM and free up resources.
-            # Note that this operation may cause in-flight chat completion
-            # requests on the now-unregistered model to return errors.
+            # Last model was just unregistered. Shut down the connection to vLLM and free up
+            # resources.
+            # Note that this operation may cause in-flight chat completion requests on the
+            # now-unregistered model to return errors.
             self.resolved_model_id = None
             self.chat = None
             self.engine.shutdown_background_loop()
@@ -449,21 +432,21 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
     async def _streaming_completion(
         self, content: str, sampling_params: vllm.SamplingParams
     ) -> AsyncIterator[CompletionResponseStreamChunk]:
-        """Internal implementation of :func:`completion()` API for the streaming
-        case. Assumes that arguments have been validated upstream.
+        """Internal implementation of :func:`completion()` API for the streaming case. Assumes
+        that arguments have been validated upstream.

         :param content: Must be a string
         :param sampling_params: Paramters from public API's ``response_format``
             and ``sampling_params`` arguments, converted to VLLM format
         """
-        # We run agains the vLLM generate() call directly instead of using the
-        # OpenAI-compatible layer, because doing so simplifies the code here.
+        # We run agains the vLLM generate() call directly instead of using the OpenAI-compatible
+        # layer, because doing so simplifies the code here.

         # The vLLM engine requires a unique identifier for each call to generate()
         request_id = _random_uuid_str()

         # The vLLM generate() API is streaming-only and returns an async generator.
-        # The generator returns objects of type vllm.RequestOutput
+        # The generator returns objects of type vllm.RequestOutput.
         results_generator = self.engine.generate(content, sampling_params, request_id)

         # Need to know the model's EOS token ID for the conversion code below.
@@ -477,10 +460,9 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             # This case also should never happen
             raise ValueError("Inference produced empty result")

-        # If we get here, then request_output contains the final output of the
-        # generate() call.
-        # The result may include multiple alternate outputs, but Llama Stack APIs
-        # only allow us to return one.
+        # If we get here, then request_output contains the final output of the generate() call.
+        # The result may include multiple alternate outputs, but Llama Stack APIs only allow
+        # us to return one.
         output: vllm.CompletionOutput = request_output.outputs[0]
         completion_string = output.text

@@ -490,8 +472,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             for logprob_dict in output.logprobs
         ]

-        # The final output chunk should be labeled with the reason that the
-        # overall generate() call completed.
+        # The final output chunk should be labeled with the reason that the overall generate()
+        # call completed.
         stop_reason_str = output.stop_reason
         if stop_reason_str is None:
             stop_reason = None  # Still going
@@ -502,15 +484,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         else:
             raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")

-        # vLLM's protocol outputs the stop token, then sets end of message
-        # on the next step for some reason.
+        # vLLM's protocol outputs the stop token, then sets end of message on the next step for
+        # some reason.
         if request_output.outputs[-1].token_ids[-1] == eos_token_id:
             stop_reason = StopReason.end_of_message

         yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs)

-        # Llama Stack requires that the last chunk have a stop reason, but
-        # vLLM doesn't always provide one if it runs out of tokens.
+        # Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always
+        # provide one if it runs out of tokens.
         if stop_reason is None:
             yield CompletionResponseStreamChunk(
                 delta=completion_string,
@@ -542,10 +524,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
             )

-        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
-        # dataclass.
-        # Note that this dataclass has the same name as a similar dataclass in
-        # Llama Stack.
+        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass.
+        # Note that this dataclass has the same name as a similar dataclass in Llama Stack.
         request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
             ChatCompletionRequest(
                 model=self.resolved_model_id,
@@ -573,7 +553,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
             if not isinstance(vllm_result, AsyncGenerator):
                 raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call")
             # vLLM client returns a stream of strings, which need to be parsed.
-            # Stream comes in the form of an async generator
+            # Stream comes in the form of an async generator.
             return self._convert_streaming_results(vllm_result)
         else:
             if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse):
@@ -587,17 +567,15 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse
     ) -> ChatCompletionResponse:
         """
-        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible
-        API into an equivalent Llama Stack object.
+        Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an
+        equivalent Llama Stack object.

-        The result from vLLM's non-streaming API is a dataclass with
-        the same name as the Llama Stack ChatCompletionResponse dataclass,
-        but with more and different field names. We ignore the fields that
-        aren't currently present in the Llama Stack dataclass.
+        The result from vLLM's non-streaming API is a dataclass with the same name as the Llama
+        Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore
+        the fields that aren't currently present in the Llama Stack dataclass.
         """

-        # There may be multiple responses, but we can only pass through the
-        # first one.
+        # There may be multiple responses, but we can only pass through the first one.
         if len(vllm_result.choices) == 0:
             raise ValueError("Don't know how to convert response object without any responses")
         vllm_message = vllm_result.choices[0].message
@@ -634,13 +612,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

         :param vllm_result: Stream of strings that need to be parsed
         """
-        # Tool calls come in pieces, but Llama Stack expects them in bigger
-        # chunks. We build up those chunks and output them at the end.
+        # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up
+        # those chunks and output them at the end.
         # This data structure holds the current set of partial tool calls.
         index_to_tool_call: Dict[int, Dict] = dict()

-        # The Llama Stack event stream must always start with a start event.
-        # Use an empty one to simplify logic below
+        # The Llama Stack event stream must always start with a start event. Use an empty one to
+        # simplify logic below
         yield ChatCompletionResponseStreamChunk(
             event=ChatCompletionResponseEvent(
                 event_type=ChatCompletionResponseEventType.start,
@@ -651,8 +629,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

         converted_stop_reason = None
         async for chunk_str in vllm_result:
-            # Due to OpenAI compatibility, each event in the stream
-            # will start with "data: " and end with "\n\n".
+            # Due to OpenAI compatibility, each event in the stream will start with "data: " and
+            # end with "\n\n".
             _prefix = "data: "
             _suffix = "\n\n"
             if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix):
@@ -677,8 +655,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

             _debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}")

-            # The result may contain multiple completions, but Llama Stack APIs
-            # only support returning one.
+            # The result may contain multiple completions, but Llama Stack APIs only support
+            # returning one.
             first_choice = parsed_chunk["choices"][0]
             converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
             delta_record = first_choice["delta"]
@@ -693,8 +671,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                     )
                 )
             elif "tool_calls" in delta_record:
-                # Tool call(s). Llama Stack APIs do not have a clear way to return
-                # partial tool calls, so buffer until we get a "tool calls" stop reason
+                # Tool call(s). Llama Stack APIs do not have a clear way to return partial tool
+                # calls, so buffer until we get a "tool calls" stop reason
                 for tc in delta_record["tool_calls"]:
                     index = tc["index"]
                     if index not in index_to_tool_call:
@@ -716,8 +694,8 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):

             if first_choice["finish_reason"] == "tool_calls":
                 # Special OpenAI code for "tool calls complete".
-                # Output the buffered tool calls. Llama Stack requires a separate
-                # event per tool call.
+                # Output the buffered tool calls. Llama Stack requires a separate event per tool
+                # call.
                 for tool_call_record in index_to_tool_call.values():
                     # Arguments come in as a string. Parse the completed string.
                     tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"])
@@ -731,6 +709,6 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                     )
                 )

-        # If we get here, we've lost the connection with the vLLM event stream
-        # before it ended normally.
+        # If we get here, we've lost the connection with the vLLM event stream before it ended
+        # normally.
         raise ValueError("vLLM event stream ended without [DONE] message.")
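For reference, a minimal usage sketch of the guided-decoding translation that the reformatted helper above performs. This is not part of the commit: the import path for JsonSchemaResponseFormat and its constructor keyword are assumptions, while the GuidedDecodingParams(json=...) call mirrors the provider code shown in the diff.

# Illustrative sketch only. Assumption: JsonSchemaResponseFormat is importable from
# llama_stack.apis.inference and accepts a json_schema keyword; adjust for your version.
from llama_stack.apis.inference import JsonSchemaResponseFormat

import vllm.sampling_params


def guided_params_for_schema(schema: dict) -> vllm.sampling_params.GuidedDecodingParams:
    """Build vLLM guided-decoding params from a JSON schema, as the provider's helper does."""
    response_format = JsonSchemaResponseFormat(json_schema=schema)
    # Same translation as the JsonSchemaResponseFormat branch in the diff above.
    return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema)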