Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-08-12 13:00:39 +00:00
Use shared code where possible
This commit is contained in:
parent 25c780802f
commit 29ae2552fd
2 changed files with 220 additions and 146 deletions
186  llama_stack/providers/inline/inference/vllm/openai_utils.py  Normal file
@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List, Optional

import vllm

from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition

from llama_stack.apis.inference import (
    ChatCompletionRequest,
    GrammarResponseFormat,
    JsonSchemaResponseFormat,
    Message,
    ToolChoice,
    UserMessage,
)
from llama_stack.providers.utils.inference.openai_compat import (
    convert_message_to_openai_dict,
    get_sampling_options,
)


###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.
def _merge_context_into_content(message: Message) -> Message:  # type: ignore
    """
    Merge the ``context`` field of a Llama Stack ``Message`` object into
    the content field for compatibility with OpenAI-style APIs.

    Generates a content string that emulates the current behavior
    of ``llama_models.llama3.api.chat_format.encode_message()``.

    :param message: Message that may include ``context`` field

    :returns: A version of ``message`` with any context merged into the
        ``content`` field.
    """
    if not isinstance(message, UserMessage):  # Separate type check for linter
        return message
    if message.context is None:
        return message
    return UserMessage(
        role=message.role,
        # Emulate llama_models.llama3.api.chat_format.encode_message()
        content=message.content + "\n\n" + message.context,
        context=None,
    )
def _llama_stack_tools_to_openai_tools(
    tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
    """
    Convert the list of available tools from Llama Stack's format to vLLM's
    version of OpenAI's format.
    """
    if tools is None:
        return []

    result = []
    for t in tools:
        if isinstance(t.tool_name, BuiltinTool):
            raise NotImplementedError("Built-in tools not yet implemented")
        if t.parameters is None:
            parameters = None
        else:  # if t.parameters is not None
            # Convert the "required" flags to a list of required params
            required_params = [k for k, v in t.parameters.items() if v.required]
            parameters = {
                "type": "object",  # Mystery value that shows up in OpenAI docs
                "properties": {
                    k: {"type": v.param_type, "description": v.description}
                    for k, v in t.parameters.items()
                },
                "required": required_params,
            }

        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
            name=t.tool_name, description=t.description, parameters=parameters
        )

        # Every tool definition is double-boxed in a ChatCompletionToolsParam
        result.append(
            vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(
                function=function_def
            )
        )
    return result
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
    request: ChatCompletionRequest,
) -> dict:
    """
    Convert a chat completion request in Llama Stack format into an
    equivalent set of arguments to pass to an OpenAI-compatible
    chat completions API.

    :param request: Bundled request parameters in Llama Stack format.

    :returns: Dictionary of key-value pairs to use as an initializer
        for a dataclass or to be converted directly to JSON and sent
        over the wire.
    """

    converted_messages = [
        # This mystery async call makes the parent function also be async
        await convert_message_to_openai_dict(
            _merge_context_into_content(m), download=True
        )
        for m in request.messages
    ]
    converted_tools = _llama_stack_tools_to_openai_tools(request.tools)

    # Llama will try to use built-in tools with no tool catalog, so don't enable
    # tool choice unless at least one tool is enabled.
    converted_tool_choice = "none"
    if (
        request.tool_choice == ToolChoice.auto
        and request.tools is not None
        and len(request.tools) > 0
    ):
        converted_tool_choice = "auto"

    # TODO: Figure out what to do with the tool_prompt_format argument.
    # Other connectors appear to drop it quietly.

    # Use Llama Stack shared code to translate sampling parameters.
    sampling_options = get_sampling_options(request.sampling_params)

    # get_sampling_options() translates repetition penalties to an option that
    # OpenAI's APIs don't know about.
    # vLLM's OpenAI-compatible API also handles repetition penalties wrong.
    # For now, translate repetition penalties into a format that vLLM's broken
    # API will handle correctly. Two wrongs make a right...
    if "repeat_penalty" in sampling_options:
        del sampling_options["repeat_penalty"]
    if (
        request.sampling_params.repetition_penalty is not None
        and request.sampling_params.repetition_penalty != 1.0
    ):
        sampling_options["repetition_penalty"] = (
            request.sampling_params.repetition_penalty
        )

    # Convert a single response format into four different parameters, per
    # the OpenAI spec
    guided_decoding_options = dict()
    if request.response_format is None:
        # Use defaults
        pass
    elif isinstance(request.response_format, JsonSchemaResponseFormat):
        guided_decoding_options["guided_json"] = request.response_format.json_schema
    elif isinstance(request.response_format, GrammarResponseFormat):
        guided_decoding_options["guided_grammar"] = request.response_format.bnf
    else:
        raise TypeError(
            f"ResponseFormat object is of unexpected "
            f"subtype '{type(request.response_format)}'"
        )

    logprob_options = dict()
    if request.logprobs is not None:
        logprob_options["logprobs"] = request.logprobs.top_k

    # Marshal together all the arguments for a ChatCompletionRequest
    request_options = {
        "model": request.model,
        "messages": converted_messages,
        "tools": converted_tools,
        "tool_choice": converted_tool_choice,
        "stream": request.stream,
    }
    request_options.update(sampling_options)
    request_options.update(guided_decoding_options)
    request_options.update(logprob_options)

    return request_options
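For orientation, here is a minimal sketch of how the new helper might be exercised on its own, outside the provider. It is an illustration, not part of the commit: the model ID, message text, and default-constructed SamplingParams are placeholders, and the snippet assumes a working llama_stack installation with the types imported by the new module above.

import asyncio

from llama_models.llama3.api.datatypes import SamplingParams
from llama_stack.apis.inference import ChatCompletionRequest, UserMessage
from llama_stack.providers.inline.inference.vllm.openai_utils import (
    llama_stack_chat_completion_to_openai_chat_completion_dict,
)


async def demo() -> None:
    # Illustrative request; "my-model" and the message text are placeholders.
    request = ChatCompletionRequest(
        model="my-model",
        messages=[UserMessage(content="Hello")],
        sampling_params=SamplingParams(),
        stream=False,
    )
    request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
    # Expect a plain dict with keys such as "model", "messages", "tools",
    # "tool_choice", and "stream", plus any translated sampling options.
    print(request_options)


asyncio.run(demo())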
Second changed file (the package's inline vLLM provider module):

@@ -16,6 +16,19 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
 
+############################################################################
+# llama_models imports go here
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
+from llama_models.llama3.api.tokenizer import Tokenizer
+
 ############################################################################
 # vLLM imports go here
 #
@@ -33,6 +46,7 @@ from llama_stack.apis.common.content_types import (
     ToolCallDelta,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -50,9 +64,6 @@ from llama_stack.apis.inference import (
     TokenLogProbs,
     ToolCall,
     ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
@@ -64,27 +75,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    GrammarResponseFormat,
-    Inference,
-    JsonSchemaResponseFormat,
-    LogProbConfig,
-    Message,
-    OpenAICompatCompletionChoice,
-    OpenAICompatCompletionResponse,
-    ResponseFormat,
-    ToolCall,
-    ToolChoice,
-    UserMessage,
-    convert_message_to_openai_dict,
-    get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
-)
+from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
 
 ############################################################################
 # Package-local imports go here
 from .config import VLLMConfig
+from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
 
 ############################################################################
 # Constants go here
@@ -119,54 +115,12 @@ def _random_uuid_str() -> str:
     return str(uuid.uuid4().hex)
 
 
-def _merge_context_into_content(message: Message) -> Message:  # type: ignore
-    """
-    Merge the ``context`` field of a Llama Stack ``Message`` object into
-    the content field for compatibility with OpenAI-style APIs.
-
-    Generates a content string that emulates the current behavior
-    of ``llama_models.llama3.api.chat_format.encode_message()``.
-
-    :param message: Message that may include ``context`` field
-
-    :returns: A version of ``message`` with any context merged into the
-        ``content`` field.
-    """
-    if not isinstance(message, UserMessage):  # Separate type check for linter
-        return message
-    if message.context is None:
-        return message
-    return UserMessage(
-        role=message.role,
-        # Emulate llama_models.llama3.api.chat_format.encode_message()
-        content=message.content + "\n\n" + message.context,
-        context=None,
-    )
-
-
-def _convert_finish_reason(finish_reason: str | None) -> str | None:
-    """Convert an OpenAI "finish_reason" result to the equivalent
-    Llama Stack result code.
-    """
-    # This conversion is currently a wild guess.
-    if finish_reason is None:
-        return None
-    elif finish_reason == "stop":
-        return StopReason.end_of_turn
-    else:
-        return StopReason.out_of_tokens
-
-
 def _response_format_to_guided_decoding_params(
     response_format: Optional[ResponseFormat],  # type: ignore
 ) -> vllm.sampling_params.GuidedDecodingParams:
     """
-    Like Llama Stack, vLLM's OpenAI-compatible API also uses the name
-    "ResponseFormat" to describe the object that is a wrapper around
-    another object that is a wrapper around another object inside
-    someone else's constrained decoding library.
-    Here we translate from Llama Stack's wrapper code to vLLM's code
-    that does the same.
-
+    Translate constrained decoding parameters from Llama Stack's
+    format to vLLM's format.
+
     :param response_format: Llama Stack version of constrained decoding
         info. Can be ``None``, indicating no constraints.
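The hunk above also deletes the provider's ad hoc _convert_finish_reason() helper (its own docstring calls the mapping a guess); later hunks switch its call sites to the shared get_stop_reason() from llama_stack.providers.utils.inference.openai_compat. A small sketch of the intended equivalence, assuming the shared helper covers at least the cases the removed code handled:

from llama_models.llama3.api.datatypes import StopReason
from llama_stack.providers.utils.inference.openai_compat import get_stop_reason

# The removed helper mapped "stop" to StopReason.end_of_turn and any other
# finish reason to StopReason.out_of_tokens; the shared helper is assumed
# (not verified here) to behave the same way for these inputs.
print(get_stop_reason("stop") == StopReason.end_of_turn)      # expected: True
print(get_stop_reason("length") == StopReason.out_of_tokens)  # expected: True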
@@ -244,42 +198,6 @@ def _convert_sampling_params(
     return vllm_sampling_params
 
 
-def _convert_tools(
-    tools: Optional[List[ToolDefinition]] = None,
-) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
-    """
-    Convert the list of available tools from Llama Stack's format to vLLM's
-    version of OpenAI's format.
-    """
-    if tools is None:
-        return []
-
-    result = []
-    for t in tools:
-        if isinstance(t.tool_name, BuiltinTool):
-            raise NotImplementedError("Built-in tools not yet implemented")
-        if t.parameters is None:
-            parameters = None
-        else:  # if t.parameters is not None
-            # Convert the "required" flags to a list of required params
-            required_params = [k for k, v in t.parameters.items() if v.required]
-            parameters = {
-                "type": "object",  # Mystery value that shows up in OpenAI docs
-                "properties": {
-                    k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
-                },
-                "required": required_params,
-            }
-
-        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
-            name=t.tool_name, description=t.description, parameters=parameters
-        )
-
-        # Every tool definition is double-boxed in a ChatCompletionToolsParam
-        result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
-    return result
-
-
 ############################################################################
 # Class definitions go here
@@ -582,51 +500,20 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         # dataclass.
         # Note that this dataclass has the same name as a similar dataclass in
         # Llama Stack.
-        converted_messages = [
-            await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
-        ]
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
-        converted_tools = _convert_tools(tools)
-
-        # Llama will try to use built-in tools with no tool catalog, so don't enable
-        # tool choice unless at least one tool is enabled.
-        converted_tool_choice = "none"
-        if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
-            converted_tool_choice = "auto"
-
-        # TODO: Figure out what to do with the tool_prompt_format argument.
-        # Other connectors appear to drop it quietly.
-
-        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
-            model=self.resolved_model_id,
-            messages=converted_messages,
-            tools=converted_tools,
-            tool_choice=converted_tool_choice,
-            stream=stream,
-        )
-
-        # vLLM's OpenAI-compatible APIs take sampling parameters as multiple
-        # keyword args instead of a vLLM SamplingParams object. Copy over
-        # all the parts that we currently convert from Llama Stack format.
-        for param_name in [
-            "max_tokens",
-            "temperature",
-            "top_p",
-            "top_k",
-            "repetition_penalty",
-        ]:
-            setattr(
-                chat_completion_request,
-                param_name,
-                getattr(converted_sampling_params, param_name),
-            )
-
-        # Guided decoding parameters are further broken out
-        if converted_sampling_params.guided_decoding is not None:
-            g = converted_sampling_params.guided_decoding
-            chat_completion_request.guided_json = g.json
-            chat_completion_request.guided_regex = g.regex
-            chat_completion_request.guided_grammar = g.grammar
+        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
+            ChatCompletionRequest(
+                model=self.resolved_model_id,
+                messages=messages,
+                sampling_params=sampling_params,
+                response_format=response_format,
+                tools=tools,
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+                stream=stream,
+                logprobs=logprobs,
+            )
+        )
+        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
 
         _info(f"Converted request: {chat_completion_request}")
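Taken with the new module from the first file, the rewritten call site reduces to a two-step pattern: convert the Llama Stack request to a plain dict, then splat that dict into vLLM's Pydantic request model. A condensed sketch of that pattern follows; the _to_vllm_request wrapper is hypothetical and simply names the two lines that the hunk above inlines in the provider method:

import vllm.entrypoints.openai.protocol

from llama_stack.apis.inference import ChatCompletionRequest
from llama_stack.providers.inline.inference.vllm.openai_utils import (
    llama_stack_chat_completion_to_openai_chat_completion_dict,
)


async def _to_vllm_request(
    ls_request: ChatCompletionRequest,
) -> vllm.entrypoints.openai.protocol.ChatCompletionRequest:
    # Hypothetical wrapper: translate the Llama Stack request into keyword
    # arguments, then construct vLLM's OpenAI-compatible request object.
    request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(ls_request)
    return vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)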
@@ -668,12 +555,13 @@
         if len(vllm_result.choices) == 0:
             raise ValueError("Don't know how to convert response object without any responses")
         vllm_message = vllm_result.choices[0].message
+        vllm_finish_reason = vllm_result.choices[0].finish_reason
 
         converted_message = CompletionMessage(
             role=vllm_message.role,
             # Llama Stack API won't accept None for content field.
             content=("" if vllm_message.content is None else vllm_message.content),
-            stop_reason=_convert_finish_reason(vllm_result.choices[0].finish_reason),
+            stop_reason=get_stop_reason(vllm_finish_reason),
             tool_calls=[
                 ToolCall(
                     call_id=t.id,
@@ -746,7 +634,7 @@
         # The result may contain multiple completions, but Llama Stack APIs
         # only support returning one.
         first_choice = parsed_chunk["choices"][0]
-        converted_stop_reason = _convert_finish_reason(first_choice["finish_reason"])
+        converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
         delta_record = first_choice["delta"]
 
         if "content" in delta_record: