mirror of https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
Add completion API
This commit is contained in:
parent 80c357f434
commit fcb87faa36

1 changed file with 98 additions and 12 deletions
@@ -8,6 +8,7 @@ import datetime
 import json
 import logging
 import re
+import uuid
 from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

 # These vLLM modules contain names that overlap with Llama Stack names,
@@ -114,6 +115,10 @@ def _info(msg: str):
     # logger.info(msg)


+def _random_uuid_str() -> str:
+    return str(uuid.uuid4().hex)
+
+
 def _merge_context_into_content(message: Message) -> Message:  # type: ignore
     """
     Merge the ``context`` field of a Llama Stack ``Message`` object into
@@ -169,7 +174,10 @@ def _response_format_to_guided_decoding_params(
     layer of vLLM.
     """
     if response_format is None:
-        return vllm.sampling_params.GuidedDecodingParams()
+        # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams()
+        # returns an invalid value that crashes the executor on some code
+        # paths. Use ``None`` instead.
+        return None

     # Llama Stack currently implements fewer types of constrained
     # decoding than vLLM does. Translate the types that exist and
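For context on the change above: returning ``None`` simply leaves constrained decoding disabled. Below is a minimal, illustrative sketch (not part of this commit) of how such a value can be threaded into vLLM sampling parameters; the helper name and the exact fields passed are assumptions.

# Illustrative sketch only: how a None / GuidedDecodingParams value might be
# forwarded to vLLM. The real _convert_sampling_params in this file may differ.
from typing import Optional

import vllm
from vllm.sampling_params import GuidedDecodingParams


def build_vllm_sampling_params_sketch(
    max_tokens: int,
    temperature: float,
    guided_decoding: Optional[GuidedDecodingParams],
) -> vllm.SamplingParams:
    # Passing guided_decoding=None keeps constrained decoding turned off,
    # which is why the patched function above returns None rather than an
    # empty GuidedDecodingParams() object.
    return vllm.SamplingParams(
        max_tokens=max_tokens,
        temperature=temperature,
        guided_decoding=guided_decoding,
    )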
@@ -440,7 +448,93 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
     ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]:
-        raise NotImplementedError()
+        if model_id not in self.model_ids:
+            raise ValueError(
+                f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
+            )
+        if not isinstance(content, str):
+            raise NotImplementedError("Multimodal input not currently supported")
+        if sampling_params is None:
+            sampling_params = SamplingParams()
+        if logprobs is not None:
+            raise NotImplementedError("logprobs argument not currently implemented")
+
+        converted_sampling_params = _convert_sampling_params(sampling_params, response_format)
+
+        if stream:
+            return self._streaming_completion(content, converted_sampling_params)
+        else:
+            streaming_result = None
+            async for streaming_result in self._streaming_completion(content, converted_sampling_params):
+                pass
+            return CompletionResponse(
+                content=streaming_result.delta,
+                stop_reason=streaming_result.stop_reason,
+                logprobs=streaming_result.logprobs,
+            )
+
+    async def _streaming_completion(
+        self, content: str, sampling_params: vllm.SamplingParams
+    ) -> AsyncIterator[CompletionResponseStreamChunk]:
+        """Internal implementation of :func:`completion()` API for the streaming
+        case. Assumes that arguments have been validated upstream.
+
+        :param content: Must be a string
+        :param sampling_params: Parameters from the public API's ``response_format``
+            and ``sampling_params`` arguments, converted to vLLM format
+        """
+        # We run against the vLLM generate() call directly instead of using the
+        # OpenAI-compatible layer, because doing so simplifies the code here.
+
+        # The vLLM engine requires a unique identifier for each call to generate().
+        request_id = _random_uuid_str()
+
+        # The vLLM generate() API is streaming-only and returns an async generator.
+        # The generator returns objects of type vllm.RequestOutput.
+        results_generator = self.engine.generate(content, sampling_params, request_id)
+
+        # Need to know the model's EOS token ID for the conversion code below.
+        # This information is buried pretty deeply.
+        eos_token_id = self.engine.engine.tokenizer.tokenizer.eos_token_id
+
+        request_output: vllm.RequestOutput = None
+        async for request_output in results_generator:
+            # Check for weird inference failures
+            if request_output.outputs is None or len(request_output.outputs) == 0:
+                # This case also should never happen
+                raise ValueError("Inference produced empty result")
+
+            # If we get here, then request_output contains the final output of the
+            # generate() call. There should be one or more output chunks.
+            completion_string = "".join([output.text for output in request_output.outputs])
+
+            # The final output chunk should be labeled with the reason that the
+            # overall generate() call completed.
+            stop_reason_str = request_output.outputs[-1].stop_reason
+            if stop_reason_str is None:
+                stop_reason = None  # Still going
+            elif stop_reason_str == "stop":
+                stop_reason = StopReason.end_of_turn
+            elif stop_reason_str == "length":
+                stop_reason = StopReason.out_of_tokens
+            else:
+                raise ValueError(f"Unrecognized stop reason '{stop_reason_str}'")
+
+            # _info(f"completion string: {completion_string}")
+            # _info(f"stop reason: {stop_reason_str}")
+            # _info(f"completion tokens: {completion_tokens}")
+
+            # vLLM's protocol outputs the stop token, then sets end of message
+            # on the next step for some reason.
+            if request_output.outputs[-1].token_ids[-1] == eos_token_id:
+                stop_reason = StopReason.end_of_message
+
+            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason)
+
+        # Llama Stack requires that the last chunk have a stop reason, but
+        # vLLM doesn't always provide one if it runs out of tokens.
+        if stop_reason is None:
+            yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=StopReason.out_of_tokens)

     async def embeddings(
         self,
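To make the control flow above concrete, here is a hedged usage sketch for the new ``completion()`` method. The provider instance and model id are placeholders; only the keyword arguments visible in the diff are assumed. Note the design choice: the non-streaming branch reuses the streaming generator and keeps only its last chunk, so both paths share one inference code path.

# Usage sketch (not part of this commit). "provider" stands for a constructed
# VLLMInferenceImpl instance; "my-model" is a placeholder registered model id.
async def demo_completion(provider, model_id: str = "my-model") -> None:
    # Non-streaming path: completion() drains _streaming_completion() internally
    # and returns a single CompletionResponse.
    response = await provider.completion(
        model_id=model_id,
        content="The capital of France is",
        stream=False,
    )
    print(response.content, response.stop_reason)

    # Streaming path: completion() returns the async generator produced by
    # _streaming_completion(); each chunk carries a delta and possibly a stop reason.
    async for chunk in await provider.completion(
        model_id=model_id,
        content="Write one sentence about GPUs.",
        stream=True,
    ):
        print(chunk.delta, end="")

# To run (given a constructed provider): asyncio.run(demo_completion(provider))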
@@ -460,21 +554,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         tool_prompt_format: Optional[ToolPromptFormat] = None,
         stream: Optional[bool] = False,
         logprobs: Optional[LogProbConfig] = None,
-        # model_id: str,
-        # messages: List[Message],  # type: ignore
-        # sampling_params: Optional[SamplingParams] = SamplingParams(),
-        # tools: Optional[List[ToolDefinition]] = None,
-        # tool_choice: Optional[ToolChoice] = ToolChoice.auto,
-        # tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
-        # response_format: Optional[ResponseFormat] = None,
-        # stream: Optional[bool] = False,
-        # logprobs: Optional[LogProbConfig] = None,
     ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
-        sampling_params = sampling_params or SamplingParams()
         if model_id not in self.model_ids:
             raise ValueError(
                 f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}"
             )
+        if logprobs is not None:
+            raise NotImplementedError("logprobs argument not currently implemented")

         # Arguments to the vLLM call must be packaged as a ChatCompletionRequest
         # dataclass.
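The trailing comment above refers to vLLM's OpenAI-compatible request type. As a rough, assumption-laden sketch of what that packaging step can look like (field names taken from vLLM's ``ChatCompletionRequest``; the helper and its argument shapes are hypothetical, not code from this commit):

# Hypothetical sketch of the "package as ChatCompletionRequest" step; the
# actual conversion in this provider may set additional fields.
from typing import Any, Dict, List

from vllm.entrypoints.openai.protocol import ChatCompletionRequest


def package_chat_request_sketch(
    model_id: str,
    converted_messages: List[Dict[str, Any]],
    stream: bool,
) -> ChatCompletionRequest:
    # Messages are assumed to already be plain {"role": ..., "content": ...}
    # dicts rather than Llama Stack Message objects.
    return ChatCompletionRequest(
        model=model_id,
        messages=converted_messages,
        stream=stream,
    )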