Use shared code where possible

Fred Reiss 2025-01-24 22:30:53 -08:00 committed by Ashwin Bharambe
parent 25c780802f
commit 29ae2552fd
2 changed files with 220 additions and 146 deletions

View file

@ -0,0 +1,186 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List, Optional
import vllm.entrypoints.openai.protocol
from llama_models.llama3.api.datatypes import BuiltinTool, ToolDefinition
from llama_stack.apis.inference import (
ChatCompletionRequest,
GrammarResponseFormat,
JsonSchemaResponseFormat,
Message,
ToolChoice,
UserMessage,
)
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
)
###############################################################################
# This file contains OpenAI compatibility code that is currently only used
# by the inline vLLM connector. Some or all of this code may be moved to a
# central location at a later date.
def _merge_context_into_content(message: Message) -> Message: # type: ignore
"""
Merge the ``context`` field of a Llama Stack ``Message`` object into
the content field for compatibility with OpenAI-style APIs.
Generates a content string that emulates the current behavior
of ``llama_models.llama3.api.chat_format.encode_message()``.
:param message: Message that may include ``context`` field
:returns: A version of ``message`` with any context merged into the
``content`` field.
"""
if not isinstance(message, UserMessage): # Separate type check for linter
return message
if message.context is None:
return message
return UserMessage(
role=message.role,
# Emulate llama_models.llama3.api.chat_format.encode_message()
content=message.content + "\n\n" + message.context,
context=None,
)
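# A hedged illustration of the merge, using a hypothetical message (not taken
# from the commit):
#     _merge_context_into_content(
#         UserMessage(content="What's the weather?", context="The user is in Boston.")
#     )
# returns a UserMessage whose content is
#     "What's the weather?\n\nThe user is in Boston."
# and whose context is None.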
def _llama_stack_tools_to_openai_tools(
tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.
"""
if tools is None:
return []
result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
else: # if t.parameters is not None
# Convert the "required" flags to a list of required params
required_params = [k for k, v in t.parameters.items() if v.required]
parameters = {
"type": "object", # Mystery value that shows up in OpenAI docs
"properties": {
k: {"type": v.param_type, "description": v.description}
for k, v in t.parameters.items()
},
"required": required_params,
}
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
name=t.tool_name, description=t.description, parameters=parameters
)
# Every tool definition is double-boxed in a ChatCompletionToolsParam
result.append(
vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(
function=function_def
)
)
return result
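# A hedged illustration with a hypothetical tool definition (parameter specs are
# assumed to carry the param_type/description/required fields the loop above reads):
#     ToolDefinition(
#         tool_name="get_weather",
#         description="Look up current weather",
#         parameters={"city": <param spec: param_type="string", description="City name", required=True>},
#     )
# converts to roughly:
#     ChatCompletionToolsParam(function=FunctionDefinition(
#         name="get_weather",
#         description="Look up current weather",
#         parameters={
#             "type": "object",
#             "properties": {"city": {"type": "string", "description": "City name"}},
#             "required": ["city"],
#         },
#     ))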
async def llama_stack_chat_completion_to_openai_chat_completion_dict(
request: ChatCompletionRequest,
) -> dict:
"""
Convert a chat completion request in Llama Stack format into an
equivalent set of arguments to pass to an OpenAI-compatible
chat completions API.
:param request: Bundled request parameters in Llama Stack format.
:returns: Dictionary of key-value pairs to use as an initializer
for a dataclass or to be converted directly to JSON and sent
over the wire.
"""
converted_messages = [
# convert_message_to_openai_dict() is async because it may download content
# (such as images) referenced by the message, which is why this function is async too
await convert_message_to_openai_dict(
_merge_context_into_content(m), download=True
)
for m in request.messages
]
converted_tools = _llama_stack_tools_to_openai_tools(request.tools)
# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if (
request.tool_choice == ToolChoice.auto
and request.tools is not None
and len(request.tools) > 0
):
converted_tool_choice = "auto"
# TODO: Figure out what to do with the tool_prompt_format argument.
# Other connectors appear to drop it quietly.
# Use Llama Stack shared code to translate sampling parameters.
sampling_options = get_sampling_options(request.sampling_params)
# get_sampling_options() translates repetition penalties to an option that
# OpenAI's APIs don't know about.
# vLLM's OpenAI-compatible API also handles repetition penalties wrong.
# For now, translate repetition penalties into a format that vLLM's broken
# API will handle correctly. Two wrongs make a right...
if "repeat_penalty" in sampling_options:
del sampling_options["repeat_penalty"]
if (
request.sampling_params.repetition_penalty is not None
and request.sampling_params.repetition_penalty != 1.0
):
sampling_options["repetition_penalty"] = (
request.sampling_params.repetition_penalty
)
# Convert the single Llama Stack response format into the separate guided
# decoding parameters that vLLM's OpenAI-compatible API expects
guided_decoding_options = dict()
if request.response_format is None:
# Use defaults
pass
elif isinstance(request.response_format, JsonSchemaResponseFormat):
guided_decoding_options["guided_json"] = request.response_format.json_schema
elif isinstance(request.response_format, GrammarResponseFormat):
guided_decoding_options["guided_grammar"] = request.response_format.bnf
else:
raise TypeError(
f"ResponseFormat object is of unexpected "
f"subtype '{type(request.response_format)}'"
)
logprob_options = dict()
if request.logprobs is not None:
logprob_options["logprobs"] = request.logprobs.top_k
# Marshal together all the arguments for a ChatCompletionRequest
request_options = {
"model": request.model,
"messages": converted_messages,
"tools": converted_tools,
"tool_choice": converted_tool_choice,
"stream": request.stream,
}
request_options.update(sampling_options)
request_options.update(guided_decoding_options)
request_options.update(logprob_options)
return request_options
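
For context, a minimal usage sketch of the helper above, mirroring how the inline vLLM connector (the second file in this commit) consumes it. The function name `_to_vllm_request` and the bare argument names are placeholders, not part of the commit:

import vllm.entrypoints.openai.protocol

from llama_stack.apis.inference import ChatCompletionRequest

from .openai_utils import (  # same relative import the connector uses below
    llama_stack_chat_completion_to_openai_chat_completion_dict,
)


async def _to_vllm_request(
    resolved_model_id, messages, sampling_params, response_format, tools, tool_choice, stream, logprobs
):
    # Bundle the Llama Stack-format arguments, convert them into OpenAI-style
    # keyword arguments, then build vLLM's own ChatCompletionRequest dataclass.
    request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
        ChatCompletionRequest(
            model=resolved_model_id,
            messages=messages,
            sampling_params=sampling_params,
            response_format=response_format,
            tools=tools,
            tool_choice=tool_choice,
            stream=stream,
            logprobs=logprobs,
        )
    )
    return vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)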

View file

@ -16,6 +16,19 @@ from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import vllm.entrypoints.openai.protocol
import vllm.sampling_params
############################################################################
# llama_models imports go here
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import (
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_models.llama3.api.tokenizer import Tokenizer
############################################################################
# vLLM imports go here
#
@ -33,6 +46,7 @@ from llama_stack.apis.common.content_types import (
ToolCallDelta,
)
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionResponseEvent,
ChatCompletionResponseEventType,
@ -50,9 +64,6 @@ from llama_stack.apis.inference import (
TokenLogProbs,
ToolCall,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat
@ -64,27 +75,12 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
ModelsProtocolPrivate,
)
from llama_stack.providers.utils.inference.openai_compat import (
GrammarResponseFormat,
Inference,
JsonSchemaResponseFormat,
LogProbConfig,
Message,
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
ResponseFormat,
ToolCall,
ToolChoice,
UserMessage,
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_compat import get_stop_reason
############################################################################
# Package-local imports go here
from .config import VLLMConfig
from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict
############################################################################
# Constants go here
@ -119,54 +115,12 @@ def _random_uuid_str() -> str:
return str(uuid.uuid4().hex)
def _merge_context_into_content(message: Message) -> Message: # type: ignore
"""
Merge the ``context`` field of a Llama Stack ``Message`` object into
the content field for compatibility with OpenAI-style APIs.
Generates a content string that emulates the current behavior
of ``llama_models.llama3.api.chat_format.encode_message()``.
:param message: Message that may include ``context`` field
:returns: A version of ``message`` with any context merged into the
``content`` field.
"""
if not isinstance(message, UserMessage): # Separate type check for linter
return message
if message.context is None:
return message
return UserMessage(
role=message.role,
# Emulate llama_models.llama3.api.chat_format.encode_message()
content=message.content + "\n\n" + message.context,
context=None,
)
def _convert_finish_reason(finish_reason: str | None) -> str | None:
"""Convert an OpenAI "finish_reason" result to the equivalent
Llama Stack result code.
"""
# This conversion is currently a wild guess.
if finish_reason is None:
return None
elif finish_reason == "stop":
return StopReason.end_of_turn
else:
return StopReason.out_of_tokens
def _response_format_to_guided_decoding_params(
response_format: Optional[ResponseFormat], # type: ignore
) -> vllm.sampling_params.GuidedDecodingParams:
"""
Like Llama Stack, vLLM's OpenAI-compatible API also uses the name
"ResponseFormat" to describe the object that is a wrapper around
another object that is a wrapper around another object inside
someone else's constrained decoding library.
Here we translate constrained decoding parameters from Llama Stack's
wrapper format into vLLM's equivalent format.
:param response_format: Llama Stack version of constrained decoding
info. Can be ``None``, indicating no constraints.
@ -244,42 +198,6 @@ def _convert_sampling_params(
return vllm_sampling_params
def _convert_tools(
tools: Optional[List[ToolDefinition]] = None,
) -> List[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]:
"""
Convert the list of available tools from Llama Stack's format to vLLM's
version of OpenAI's format.
"""
if tools is None:
return []
result = []
for t in tools:
if isinstance(t.tool_name, BuiltinTool):
raise NotImplementedError("Built-in tools not yet implemented")
if t.parameters is None:
parameters = None
else: # if t.parameters is not None
# Convert the "required" flags to a list of required params
required_params = [k for k, v in t.parameters.items() if v.required]
parameters = {
"type": "object", # Mystery value that shows up in OpenAI docs
"properties": {
k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items()
},
"required": required_params,
}
function_def = vllm.entrypoints.openai.protocol.FunctionDefinition(
name=t.tool_name, description=t.description, parameters=parameters
)
# Every tool definition is double-boxed in a ChatCompletionToolsParam
result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def))
return result
############################################################################
# Class definitions go here
@ -582,51 +500,20 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# dataclass.
# Note that this dataclass has the same name as a similar dataclass in
# Llama Stack.
converted_messages = [
await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) for m in messages
]
converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs)
converted_tools = _convert_tools(tools)
# Llama will try to use built-in tools with no tool catalog, so don't enable
# tool choice unless at least one tool is enabled.
converted_tool_choice = "none"
if tool_choice == ToolChoice.auto and tools is not None and len(tools) > 0:
converted_tool_choice = "auto"
# TODO: Figure out what to do with the tool_prompt_format argument.
# Other connectors appear to drop it quietly.
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(
model=self.resolved_model_id,
messages=converted_messages,
tools=converted_tools,
tool_choice=converted_tool_choice,
stream=stream,
)
# vLLM's OpenAI-compatible APIs take sampling parameters as multiple
# keyword args instead of a vLLM SamplingParams object. Copy over
# all the parts that we currently convert from Llama Stack format.
for param_name in [
    "max_tokens",
    "temperature",
    "top_p",
    "top_k",
    "repetition_penalty",
]:
    setattr(
        chat_completion_request,
        param_name,
        getattr(converted_sampling_params, param_name),
    )
# Guided decoding parameters are further broken out
if converted_sampling_params.guided_decoding is not None:
    g = converted_sampling_params.guided_decoding
    chat_completion_request.guided_json = g.json
    chat_completion_request.guided_regex = g.regex
    chat_completion_request.guided_grammar = g.grammar
request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(
    ChatCompletionRequest(
        model=self.resolved_model_id,
        messages=messages,
        sampling_params=sampling_params,
        response_format=response_format,
        tools=tools,
        tool_choice=tool_choice,
        tool_prompt_format=tool_prompt_format,
        stream=stream,
        logprobs=logprobs,
    )
)
chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options)
_info(f"Converted request: {chat_completion_request}")
@ -668,12 +555,13 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
if len(vllm_result.choices) == 0:
raise ValueError("Don't know how to convert response object without any responses")
vllm_message = vllm_result.choices[0].message
vllm_finish_reason = vllm_result.choices[0].finish_reason
converted_message = CompletionMessage(
role=vllm_message.role,
# Llama Stack API won't accept None for content field.
content=("" if vllm_message.content is None else vllm_message.content),
stop_reason=_convert_finish_reason(vllm_result.choices[0].finish_reason),
stop_reason=get_stop_reason(vllm_finish_reason),
tool_calls=[
ToolCall(
call_id=t.id,
@ -746,7 +634,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
# The result may contain multiple completions, but Llama Stack APIs
# only support returning one.
first_choice = parsed_chunk["choices"][0]
converted_stop_reason = _convert_finish_reason(first_choice["finish_reason"])
converted_stop_reason = get_stop_reason(first_choice["finish_reason"])
delta_record = first_choice["delta"]
if "content" in delta_record: