Add completion API

Fred Reiss 2025-01-24 10:56:47 -08:00 committed by Ashwin Bharambe
parent 80c357f434
commit fcb87faa36


@@ -8,6 +8,7 @@ import datetime
import json
import logging
import re
import uuid
from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

# These vLLM modules contain names that overlap with Llama Stack names,
@@ -114,6 +115,10 @@ def _info(msg: str):
    # logger.info(msg)


def _random_uuid_str() -> str:
    return str(uuid.uuid4().hex)


def _merge_context_into_content(message: Message) -> Message:  # type: ignore
    """
    Merge the ``context`` field of a Llama Stack ``Message`` object into
@@ -169,7 +174,10 @@ def _response_format_to_guided_decoding_params(
    layer of vLLM.
    """
    if response_format is None:
        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
        # returns an invalid value that crashes the executor on some code
        # paths. Use ``None`` instead.
        return None

    # Llama Stack currently implements fewer types of constrained
    # decoding than vLLM does. Translate the types that exist and
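
As an aside, here is a minimal sketch (not part of this commit) of what that translation can look like for the JSON-schema case. It assumes vLLM's ``GuidedDecodingParams`` accepts a ``json`` argument and that Llama Stack's ``JsonSchemaResponseFormat`` carries the schema in a ``json_schema`` field; both names are assumptions, not taken from the diff above.

import json
from typing import Optional
from vllm.sampling_params import GuidedDecodingParams

def _sketch_json_schema_to_guided_decoding(response_format) -> Optional[GuidedDecodingParams]:
    # Hypothetical sketch: argument and field names are assumptions.
    if response_format is None:
        return None
    return GuidedDecodingParams(json=json.dumps(response_format.json_schema))
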
@@ -440,7 +448,93 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )
        if not isinstance(content, str):
            raise NotImplementedError("Multimodal input not currently supported")
        if sampling_params is None:
            sampling_params = SamplingParams()
        if logprobs is not None:
            raise NotImplementedError("logprobs argument not currently implemented")

        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)

        if stream:
            return self._streaming_completion(content, converted_sampling_params)
        else:
            streaming_result = None
            async for streaming_result in self._streaming_completion(content, converted_sampling_params):
                pass
            return CompletionResponse(
                content=streaming_result.delta,
                stop_reason=streaming_result.stop_reason,
                logprobs=streaming_result.logprobs,
            )
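
For reference, here is a rough usage sketch of the new completion API (not part of this commit; the provider instance and the model ID below are placeholders). With ``stream=False`` a single ``CompletionResponse`` comes back; with ``stream=True`` the same call returns an async iterator of ``CompletionResponseStreamChunk`` objects.

async def _demo_completion(impl: "VLLMInferenceImpl") -> None:
    # Hypothetical driver; "registered-model-id" stands in for a model ID that
    # was actually registered with this provider.
    response = await impl.completion(
        model_id="registered-model-id",
        content="The capital of France is",
        stream=False,
    )
    print(response.content, response.stop_reason)

    # Streaming: drain the async iterator of chunks.
    chunks = await impl.completion(
        model_id="registered-model-id",
        content="The capital of France is",
        stream=True,
    )
    async for chunk in chunks:
        print(chunk.delta)
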
    async def _streaming_completion(
        self, content: str, sampling_params: vllm.SamplingParams
    ) -> AsyncIterator[CompletionResponseStreamChunk]:
        """Internal implementation of :func:`completion()` API for the streaming
        case. Assumes that arguments have been validated upstream.

        :param content: Must be a string
        :param sampling_params: Parameters from the public API's ``response_format``
            and ``sampling_params`` arguments, converted to vLLM format
        """
        # We run against the vLLM generate() call directly instead of using the
        # OpenAI-compatible layer, because doing so simplifies the code here.

        # The vLLM engine requires a unique identifier for each call to generate()
        request_id = _random_uuid_str()

        # The vLLM generate() API is streaming-only and returns an async generator.
        # The generator returns objects of type vllm.RequestOutput
        results_generator = self.engine.generate(content, sampling_params, request_id)

        # Need to know the model's EOS token ID for the conversion code below.
        # This information is buried pretty deeply.
        eos_token_id = self.engine.engine.tokenizer.tokenizer.eos_token_id
        request_output: vllm.RequestOutput = None
        async for request_output in results_generator:
            # Check for weird inference failures
            if request_output.outputs is None or len(request_output.outputs) == 0:
                # This case also should never happen
                raise ValueError("Inference produced empty result")

            # If we get here, then request_output contains the final output of the
            # generate() call. There should be one or more output chunks.
            completion_string = "".join([output.text for output in request_output.outputs])

            # The final output chunk should be labeled with the reason that the
            # overall generate() call completed.
            stop_reason_str = request_output.outputs[-1].stop_reason
            if stop_reason_str is None:
                stop_reason = None  # Still going
            elif stop_reason_str == "stop":
                stop_reason = StopReason.end_of_turn
            elif stop_reason_str == "length":
                stop_reason = StopReason.out_of_tokens
            else:
                raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")

            # _info(f"completion string: {completion_string}")
            # _info(f"stop reason: {stop_reason_str}")
            # _info(f"completion tokens: {completion_tokens}")

            # vLLM's protocol outputs the stop token, then sets end of message
            # on the next step for some reason.
            if request_output.outputs[-1].token_ids[-1] == eos_token_id:
                stop_reason = StopReason.end_of_message

            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason)

        # Llama Stack requires that the last chunk have a stop reason, but
        # vLLM doesn't always provide one if it runs out of tokens.
        if stop_reason is None:
            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=StopReason.out_of_tokens)
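
The stop-reason handling above boils down to a small mapping; the following standalone helper (a hypothetical refactoring, not part of this commit) restates it, assuming the same ``StopReason`` enum this module already imports.

def _sketch_map_stop_reason(stop_reason_str, last_token_id, eos_token_id):
    # Hypothetical restatement of the translation performed in the loop above.
    if stop_reason_str is None:
        stop_reason = None  # Generation may still be in progress
    elif stop_reason_str == "stop":
        stop_reason = StopReason.end_of_turn
    elif stop_reason_str == "length":
        stop_reason = StopReason.out_of_tokens
    else:
        raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
    # vLLM emits the stop token itself and only flags end of message on the
    # next step, so an EOS token at the end of the output overrides the above.
    if last_token_id == eos_token_id:
        stop_reason = StopReason.end_of_message
    return stop_reason
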

    async def embeddings(
        self,

@@ -460,21 +554,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
        sampling_params = sampling_params or SamplingParams()
        if model_id not in self.model_ids:
            raise ValueError(
                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
            )
        if logprobs is not None:
            raise NotImplementedError("logprobs argument not currently implemented")

        # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
        # dataclass.
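
For orientation, a rough sketch (not from this commit) of what packaging the arguments as a ``ChatCompletionRequest`` can look like, assuming the model lives at ``vllm.entrypoints.openai.protocol`` in the vLLM version in use and using a placeholder model name:

from vllm.entrypoints.openai.protocol import ChatCompletionRequest

request = ChatCompletionRequest(
    model="registered-model-id",  # placeholder
    messages=[{"role": "user", "content": "Say hello."}],
    stream=True,
)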