Use shared code where possible

Fred Reiss 2025-01-24 22:30:53 -08:00 committed by Ashwin Bharambe
parent 25c780802f
commit 29ae2552fd
2 changed files with 220 additions and 146 deletions

@@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
import vllm
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
JsonSchemaResponseFormat,
Message,
ToolChoice,
UserMessage,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.
def _merge_context_into_content(message: Message) -> Message: # type: ignore
"""
Merge the ``context`` field of a Llama Stack ``Message`` object into
    the content field for compatibility with OpenAI-style APIs.
Generates a content string that emulates the current behavior
of ``llama_models.llama3.api.chat_format.encode_message()``.
:param message: Message that may include ``context`` field
:returns: A version of ``message`` with any context merged into the
``content`` field.
"""
if not isinstance(message, UserMessage): # Separate type check for linter
return message
if message.context is None:
return message
return UserMessage(
role=message.role,
        # Emulate llama_models.llama3.api.chat_format.encode_message()
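        # Example: content "What is 2 + 2?" with context "Answer with a number"
        # becomes content "What is 2 + 2?\n\nAnswer with a number".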
content=message.content + "\n\n" + message.context,
context=None,
)
def _llama_stack_tools_to_openai_tools(
tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.
"""
if tools is None:
return []
result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
else: # if t.parameters is not None
# Convert the "required" flags to a list of required params
required_params = [k for k, v in t.parameters.items() if v.required]
parameters = {
"type": "object", # Mystery value that shows up in OpenAI docs
"properties": {
k: {"type": v.param_type, "description": v.description}
for k, v in t.parameters.items()
},
"required": required_params,
}
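            # Example result for a hypothetical tool with one required "city"
            # parameter of type "string":
            #   {"type": "object",
            #    "properties": {"city": {"type": "string", "description": "City name"}},
            #    "required": ["city"]}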
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
name=t.tool_name, description=t.description, parameters=parameters
)
# Every tool definition is double-boxed in a ChatCompletionToolsParam
result.append(
vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(
function=function_def
)
)
return result
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
request: ChatCompletionRequest,
) -> dict:
"""
Convert a chat completion request in Llama Stack format into an
equivalent set of arguments to pass to an OpenAI-compatible
chat completions API.
:param request: Bundled request parameters in Llama Stack format.
:returns: Dictionary of key-value pairs to use as an initializer
for a dataclass or to be converted directly to JSON and sent
over the wire.
"""
converted_messages = [
        # convert_message_to_openai_dict() is async, so this function must be async too
await convert_message_to_openai_dict(
_merge_context_into_content(m), download=True
)
for m in request.messages
]
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if (
request.tool_choice == ToolChoice.auto
and request.tools is not None
and len(request.tools) > 0
):
converted_tool_choice = "auto"
# TODO: Figure out what to do with the tool_prompt_format argument.
# Other connectors appear to drop it quietly.
# Use Llama Stack shared code to translate sampling parameters.
sampling_options = get_sampling_options(request.sampling_params)
# get_sampling_options() translates repetition penalties to an option that
# OpenAI's APIs don't know about.
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
# For now, translate repetition penalties into a format that vLLM's broken
# API will handle correctly. Two wrongs make a right...
if "repeat_penalty" in sampling_options:
del sampling_options["repeat_penalty"]
if (
request.sampling_params.repetition_penalty is not None
and request.sampling_params.repetition_penalty != 1.0
):
sampling_options["repetition_penalty"] = (
request.sampling_params.repetition_penalty
)
    # Convert the single response_format field into the corresponding guided
    # decoding parameter of vLLM's OpenAI-compatible API
guided_decoding_options = dict()
if request.response_format is None:
# Use defaults
pass
elif isinstance(request.response_format, JsonSchemaResponseFormat):
guided_decoding_options["guided_json"] = request.response_format.json_schema
elif isinstance(request.response_format, GrammarResponseFormat):
guided_decoding_options["guided_grammar"] = request.response_format.bnf
else:
raise TypeError(
f"ResponseFormat object is of unexpected "
f"subtype '{type(request.response_format)}'"
)
logprob_options = dict()
if request.logprobs is not None:
logprob_options["logprobs"] = request.logprobs.top_k
    # Marshal together all the arguments for a ChatCompletionRequest
request_options = {
"model": request.model,
"messages": converted_messages,
"tools": converted_tools,
"tool_choice": converted_tool_choice,
"stream": request.stream,
}
request_options.update(sampling_options)
request_options.update(guided_decoding_options)
request_options.update(logprob_options)
return request_options
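For orientation, here is a minimal sketch of how the new helper is meant to be driven. It is illustrative only: the model ID and prompt are made up, and the relative import assumes the caller lives in the same package as the module above.

import vllm.entrypoints.openai.protocol

from llama_stack.apis.inference import ChatCompletionRequest, UserMessage

from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict


async def _build_vllm_request() -> vllm.entrypoints.openai.protocol.ChatCompletionRequest:
    # Build a Llama Stack request; the context field is merged into content
    # during conversion.
    request = ChatCompletionRequest(
        model="example-model-id",  # hypothetical model identifier
        messages=[
            UserMessage(
                content="What is the capital of France?",
                context="Answer in one word.",
            )
        ],
        stream=False,
    )
    request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request)
    # The returned dict is splatted directly into vLLM's request dataclass.
    return vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)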

@@ -16,6 +16,19 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
 import vllm.entrypoints.openai.protocol
 import vllm.sampling_params
 
+############################################################################
+# llama_models imports go here
+from llama_models.llama3.api.chat_format import ChatFormat
+from llama_models.llama3.api.datatypes import (
+    SamplingParams,
+    StopReason,
+    ToolDefinition,
+    ToolPromptFormat,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
+)
+from llama_models.llama3.api.tokenizer import Tokenizer
+
 ############################################################################
 # vLLM imports go here
 #
@@ -33,6 +46,7 @@ from llama_stack.apis.common.content_types import (
     ToolCallDelta,
 )
 from llama_stack.apis.inference import (
+    ChatCompletionRequest,
     ChatCompletionResponse,
     ChatCompletionResponseEvent,
     ChatCompletionResponseEventType,
@@ -50,9 +64,6 @@ from llama_stack.apis.inference import (
     TokenLogProbs,
     ToolCall,
     ToolChoice,
-    ToolConfig,
-    ToolDefinition,
-    ToolPromptFormat,
 )
 from llama_stack.apis.models import Model
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
@@ -64,27 +75,12 @@ from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     ModelsProtocolPrivate,
 )
-from llama_stack.providers.utils.inference.openai_compat import (
-    GrammarResponseFormat,
-    Inference,
-    JsonSchemaResponseFormat,
-    LogProbConfig,
-    Message,
-    OpenAICompatCompletionChoice,
-    OpenAICompatCompletionResponse,
-    ResponseFormat,
-    ToolCall,
-    ToolChoice,
-    UserMessage,
-    convert_message_to_openai_dict,
-    get_sampling_options,
-    process_chat_completion_response,
-    process_chat_completion_stream_response,
-)
+from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
 
 ############################################################################
 # Package-local imports go here
 from .config import VLLMConfig
+from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
 
 ############################################################################
 # Constants go here
@@ -119,54 +115,12 @@ def _random_uuid_str() -> str:
     return str(uuid.uuid4().hex)
 
-def _merge_context_into_content(message: Message) -> Message:  # type: ignore
-    """
-    Merge the ``context`` field of a Llama Stack ``Message`` object into
-    the content field for compabilitiy with OpenAI-style APIs.
-
-    Generates a content string that emulates the current behavior
-    of ``llama_models.llama3.api.chat_format.encode_message()``.
-
-    :param message: Message that may include ``context`` field
-
-    :returns: A version of ``message`` with any context merged into the
-        ``content`` field.
-    """
-    if not isinstance(message, UserMessage):  # Separate type check for linter
-        return message
-    if message.context is None:
-        return message
-    return UserMessage(
-        role=message.role,
-        # Emumate llama_models.llama3.api.chat_format.encode_message()
-        content=message.content + "\n\n" + message.context,
-        context=None,
-    )
-
-
-def _convert_finish_reason(finish_reason: str | None) -> str | None:
-    """Convert an OpenAI "finish_reason" result to the equivalent
-    Llama Stack result code.
-    """
-    # This conversion is currently a wild guess.
-    if finish_reason is None:
-        return None
-    elif finish_reason == "stop":
-        return StopReason.end_of_turn
-    else:
-        return StopReason.out_of_tokens
-
-
 def _response_format_to_guided_decoding_params(
     response_format: Optional[ResponseFormat],  # type: ignore
 ) -> vllm.sampling_params.GuidedDecodingParams:
     """
-    Like Llama Stack, vLLM's OpenAI-compatible API also uses the name
-    "ResponseFormat" to describe the object that is a wrapper around
-    another object that is a wrapper around another object inside
-    someone else's constrained decoding library.
-    Here we translate from Llama Stack's wrapper code to vLLM's code
-    that does the same.
+    Translate constrained decoding parameters from Llama Stack's
+    format to vLLM's format.
 
     :param response_format: Llama Stack version of constrained decoding
      info. Can be ``None``, indicating no constraints.
@@ -244,42 +198,6 @@ def _convert_sampling_params(
     return vllm_sampling_params
 
 
-def _convert_tools(
-    tools: Optional[List[ToolDefinition]] = None,
-) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
-    """
-    Convert the list of available tools from Llama Stack's format to vLLM's
-    version of OpenAI's format.
-    """
-    if tools is None:
-        return []
-    result = []
-    for t in tools:
-        if isinstance(t.tool_name, BuiltinTool):
-            raise NotImplementedError("Built-in tools not yet implemented")
-        if t.parameters is None:
-            parameters = None
-        else:  # if t.parameters is not None
-            # Convert the "required" flags to a list of required params
-            required_params = [k for k, v in t.parameters.items() if v.required]
-            parameters = {
-                "type": "object",  # Mystery value that shows up in OpenAI docs
-                "properties": {
-                    k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
-                },
-                "required": required_params,
-            }
-        function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
-            name=t.tool_name, description=t.description, parameters=parameters
-        )
-        # Every tool definition is double-boxed in a ChatCompletionToolsParam
-        result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
-    return result
-
-
 ############################################################################
 # Class definitions go here
@@ -582,51 +500,20 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         # dataclass.
         # Note that this dataclass has the same name as a similar dataclass in
         # Llama Stack.
-        converted_messages = [
-            await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
-        ]
-        converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
-        converted_tools = _convert_tools(tools)
-
-        # Llama will try to use built-in tools with no tool catalog, so don't enable
-        # tool choice unless at least one tool is enabled.
-        converted_tool_choice = "none"
-        if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
-            converted_tool_choice = "auto"
-
-        # TODO: Figure out what to do with the tool_prompt_format argument.
-        # Other connectors appear to drop it quietly.
-
-        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
-            model=self.resolved_model_id,
-            messages=converted_messages,
-            tools=converted_tools,
-            tool_choice=converted_tool_choice,
-            stream=stream,
-        )
-
-        # vLLM's OpenAI-compatible APIs take sampling parameters as multiple
-        # keyword args instead of a vLLM SamplingParams object. Copy over
-        # all the parts that we currently convert from Llama Stack format.
-        for param_name in [
-            "max_tokens",
-            "temperature",
-            "top_p",
-            "top_k",
-            "repetition_penalty",
-        ]:
-            setattr(
-                chat_completion_request,
-                param_name,
-                getattr(converted_sampling_params, param_name),
-            )
-
-        # Guided decoding parameters are further broken out
-        if converted_sampling_params.guided_decoding is not None:
-            g = converted_sampling_params.guided_decoding
-            chat_completion_request.guided_json = g.json
-            chat_completion_request.guided_regex = g.regex
-            chat_completion_request.guided_grammar = g.grammar
+        request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
+            ChatCompletionRequest(
+                model=self.resolved_model_id,
+                messages=messages,
+                sampling_params=sampling_params,
+                response_format=response_format,
+                tools=tools,
+                tool_choice=tool_choice,
+                tool_prompt_format=tool_prompt_format,
+                stream=stream,
+                logprobs=logprobs,
+            )
+        )
+        chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
 
         _info(f"Converted request: {chat_completion_request}")
@@ -668,12 +555,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         if len(vllm_result.choices) == 0:
             raise ValueError("Don't know how to convert response object without any responses")
         vllm_message = vllm_result.choices[0].message
+        vllm_finish_reason = vllm_result.choices[0].finish_reason
 
         converted_message = CompletionMessage(
             role=vllm_message.role,
             # Llama Stack API won't accept None for content field.
             content=("" if vllm_message.content is None else vllm_message.content),
-            stop_reason=_convert_finish_reason(vllm_result.choices[0].finish_reason),
+            stop_reason=get_stop_reason(vllm_finish_reason),
            tool_calls=[
                 ToolCall(
                     call_id=t.id,
@@ -746,7 +634,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
         # The result may contain multiple completions, but Llama Stack APIs
         # only support returning one.
         first_choice = parsed_chunk["choices"][0]
-        converted_stop_reason = _convert_finish_reason(first_choice["finish_reason"])
+        converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
         delta_record = first_choice["delta"]
 
         if "content" in delta_record: