Handle response

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>
Yuan Tang 2025-02-10 15:23:09 -05:00
parent cc3bb0938a
commit b2a86532a2
2 changed files with 82 additions and 43 deletions

llama_stack/providers/remote/inference/vllm/vllm.py

@@ -3,10 +3,11 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_models.llama3.api import StopReason, ToolCall
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import all_registered_models
@@ -30,6 +31,7 @@ from llama_stack.apis.inference import (
ToolConfig,
ToolDefinition,
ToolPromptFormat,
CompletionMessage,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -40,7 +42,6 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
convert_message_to_openai_dict,
get_sampling_options,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
@@ -68,6 +69,73 @@ def build_model_aliases():
]
def convert_to_vllm_tool_calls_in_response(
    tool_calls,
) -> List[ToolCall]:
    if not tool_calls:
        return []

    # Each tool call carries its own JSON-encoded arguments string, so the
    # arguments must be decoded per call rather than once for the whole list.
    return [
        ToolCall(
            call_id=call.id,
            tool_name=call.function.name,
            arguments=json.loads(call.function.arguments),
        )
        for call in tool_calls
    ]
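For orientation, a minimal sketch of the response-side conversion, assuming an OpenAI-style tool call object as returned by the vLLM server (the field layout mirrors the openai client's tool call type; the literal values are hypothetical):

from types import SimpleNamespace

# Hypothetical stand-in for an OpenAI-style tool call; `id`, `function.name`,
# and the JSON-encoded `function.arguments` are the fields the converter reads.
raw_call = SimpleNamespace(
    id="call_abc123",
    function=SimpleNamespace(name="get_weather", arguments='{"city": "Toronto"}'),
)

calls = convert_to_vllm_tool_calls_in_response([raw_call])
# calls[0].tool_name == "get_weather"
# calls[0].arguments == {"city": "Toronto"}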
def convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
if tools is None:
return tools
compat_tools = []
for tool in tools:
properties = {}
compat_required = []
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
if tool_param.description:
properties[tool_key]["description"] = tool_param.description
                if tool_param.default is not None:  # explicit None check keeps falsy defaults like 0 or False
                    properties[tool_key]["default"] = tool_param.default
if tool_param.required:
compat_required.append(tool_key)
compat_tool = {
"type": "function",
"function": {
"name": tool.tool_name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": properties,
"required": compat_required,
},
},
}
compat_tools.append(compat_tool)
    # Normalize an empty tool list to None for the OpenAI client.
    return compat_tools or None
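And a sketch of the request-side conversion, assuming ToolParamDefinition is exported alongside ToolDefinition (the tool and its parameter are hypothetical):

from llama_stack.apis.inference import ToolParamDefinition  # assumed export

weather_tool = ToolDefinition(
    tool_name="get_weather",
    description="Look up the current weather for a city.",
    parameters={
        "city": ToolParamDefinition(
            param_type="string", description="City name", required=True
        ),
    },
)

convert_to_vllm_tools_in_request([weather_tool])
# -> [{"type": "function",
#      "function": {"name": "get_weather",
#                   "description": "Look up the current weather for a city.",
#                   "parameters": {"type": "object",
#                                  "properties": {"city": {"type": "string",
#                                                          "description": "City name"}},
#                                  "required": ["city"]}}}]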
def convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
return {
"stop": StopReason.end_of_turn,
"length": StopReason.out_of_tokens,
"tool_calls": StopReason.end_of_message,
}.get(finish_reason, StopReason.end_of_turn)
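The mapping handles the common OpenAI finish reasons and falls back to end_of_turn for anything else:

convert_to_vllm_finish_reason("stop")        # -> StopReason.end_of_turn
convert_to_vllm_finish_reason("length")      # -> StopReason.out_of_tokens
convert_to_vllm_finish_reason("tool_calls")  # -> StopReason.end_of_message
convert_to_vllm_finish_reason("other")       # fallback -> StopReason.end_of_turn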
class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
self.register_helper = ModelRegistryHelper(build_model_aliases())
@@ -142,7 +210,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
) -> ChatCompletionResponse:
params = await self._get_params(request)
r = client.chat.completions.create(**params)
return process_chat_completion_response(r, self.formatter)
choice = r.choices[0]
result = ChatCompletionResponse(
completion_message=CompletionMessage(
content=choice.message.content or "",
stop_reason=convert_to_vllm_finish_reason(choice.finish_reason),
tool_calls=convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
),
logprobs=None,
)
return result
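Put together, a tool-call response from the server now maps roughly as follows, as a sketch with a hypothetical choice object mirroring the construction above:

from types import SimpleNamespace

# Hypothetical non-streaming choice in which the model requested tool calls
# and produced no text content (tool_calls left as None here for brevity).
choice = SimpleNamespace(
    finish_reason="tool_calls",
    message=SimpleNamespace(content=None, tool_calls=None),
)

message = CompletionMessage(
    content=choice.message.content or "",                              # ""
    stop_reason=convert_to_vllm_finish_reason(choice.finish_reason),   # StopReason.end_of_message
    tool_calls=convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),  # []
)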
async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
params = await self._get_params(request)
@@ -187,51 +264,12 @@
)
return model
def convert_to_vllm_tools(self, tools: List[ToolDefinition]) -> List[dict]:
if tools is None:
return tools
compat_tools = []
for tool in tools:
properties = {}
compat_required = []
if tool.parameters:
for tool_key, tool_param in tool.parameters.items():
properties[tool_key] = {"type": tool_param.param_type}
if tool_param.description:
properties[tool_key]["description"] = tool_param.description
if tool_param.default:
properties[tool_key]["default"] = tool_param.default
if tool_param.required:
compat_required.append(tool_key)
compat_tool = {
"type": "function",
"function": {
"name": tool.tool_name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": properties,
"required": compat_required,
},
},
}
compat_tools.append(compat_tool)
if len(compat_tools) > 0:
return compat_tools
return None
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
options = get_sampling_options(request.sampling_params)
if "max_tokens" not in options:
options["max_tokens"] = self.config.max_tokens
input_dict = {}
input_dict["tools"] = self.convert_to_vllm_tools(request.tools)
input_dict = {"tools": convert_to_vllm_tools_in_request(request.tools)}
if isinstance(request, ChatCompletionRequest):
input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]

llama_stack/providers/utils/inference/openai_compat.py

@@ -174,6 +174,7 @@ def process_chat_completion_response(
) -> ChatCompletionResponse:
choice = response.choices[0]
# TODO: This does not work well with tool calls (at least for vLLM remote)
raw_message = formatter.decode_assistant_message_from_content(
text_from_choice(choice), get_stop_reason(choice.finish_reason)
)