Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-07 02:58:21 +00:00)

commit b2a86532a2
parent cc3bb0938a

Handle response

Signed-off-by: Yuan Tang <terrytangyuan@gmail.com>

2 changed files with 82 additions and 43 deletions
@@ -3,10 +3,11 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
+import json
 import logging
 from typing import AsyncGenerator, List, Optional, Union
 
+from llama_models.llama3.api import StopReason, ToolCall
 from llama_models.llama3.api.chat_format import ChatFormat
 from llama_models.llama3.api.tokenizer import Tokenizer
 from llama_models.sku_list import all_registered_models
@@ -30,6 +31,7 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    CompletionMessage,
 )
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.providers.datatypes import ModelsProtocolPrivate
@@ -40,7 +42,6 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
-    process_chat_completion_response,
     process_chat_completion_stream_response,
     process_completion_response,
     process_completion_stream_response,
@@ -68,6 +69,73 @@ def build_model_aliases():
     ]
 
 
+def convert_to_vllm_tool_calls_in_response(
+    tool_calls,
+) -> List[ToolCall]:
+    if not tool_calls:
+        return []
+
+    call_function_arguments = None
+    for call in tool_calls:
+        call_function_arguments = json.loads(call.function.arguments)
+
+    return [
+        ToolCall(
+            call_id=call.id,
+            tool_name=call.function.name,
+            arguments=call_function_arguments,
+        )
+        for call in tool_calls
+    ]
+
+
+def convert_to_vllm_tools_in_request(tools: List[ToolDefinition]) -> List[dict]:
+    if tools is None:
+        return tools
+
+    compat_tools = []
+
+    for tool in tools:
+        properties = {}
+        compat_required = []
+        if tool.parameters:
+            for tool_key, tool_param in tool.parameters.items():
+                properties[tool_key] = {"type": tool_param.param_type}
+                if tool_param.description:
+                    properties[tool_key]["description"] = tool_param.description
+                if tool_param.default:
+                    properties[tool_key]["default"] = tool_param.default
+                if tool_param.required:
+                    compat_required.append(tool_key)
+
+        compat_tool = {
+            "type": "function",
+            "function": {
+                "name": tool.tool_name,
+                "description": tool.description,
+                "parameters": {
+                    "type": "object",
+                    "properties": properties,
+                    "required": compat_required,
+                },
+            },
+        }
+
+        compat_tools.append(compat_tool)
+
+    if len(compat_tools) > 0:
+        return compat_tools
+    return None
+
+
+def convert_to_vllm_finish_reason(finish_reason: str) -> StopReason:
+    return {
+        "stop": StopReason.end_of_turn,
+        "length": StopReason.out_of_tokens,
+        "tool_calls": StopReason.end_of_message,
+    }.get(finish_reason, StopReason.end_of_turn)
+
+
 class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     def __init__(self, config: VLLMInferenceAdapterConfig) -> None:
         self.register_helper = ModelRegistryHelper(build_model_aliases())
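A minimal usage sketch of the request-side converters added above. The tool definition is a SimpleNamespace stand-in rather than llama-stack's real ToolDefinition/ToolParamDefinition types; it models only the attributes the converter reads, and the "get_weather" tool and its "city" parameter are illustrative.

from types import SimpleNamespace

# Stand-in for a ToolDefinition with one required string parameter.
get_weather = SimpleNamespace(
    tool_name="get_weather",
    description="Look up the current weather for a city",
    parameters={
        "city": SimpleNamespace(
            param_type="string",
            description="City name",
            default=None,
            required=True,
        )
    },
)

vllm_tools = convert_to_vllm_tools_in_request([get_weather])
# -> [{"type": "function",
#      "function": {"name": "get_weather",
#                   "description": "Look up the current weather for a city",
#                   "parameters": {"type": "object",
#                                  "properties": {"city": {"type": "string", "description": "City name"}},
#                                  "required": ["city"]}}}]

# Any finish_reason outside the mapping falls back to end_of_turn.
assert convert_to_vllm_finish_reason("tool_calls") == StopReason.end_of_message
assert convert_to_vllm_finish_reason("unknown") == StopReason.end_of_turn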
@@ -142,7 +210,16 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
     ) -> ChatCompletionResponse:
         params = await self._get_params(request)
         r = client.chat.completions.create(**params)
-        return process_chat_completion_response(r, self.formatter)
+        choice = r.choices[0]
+        result = ChatCompletionResponse(
+            completion_message=CompletionMessage(
+                content=choice.message.content or "",
+                stop_reason=convert_to_vllm_finish_reason(choice.finish_reason),
+                tool_calls=convert_to_vllm_tool_calls_in_response(choice.message.tool_calls),
+            ),
+            logprobs=None,
+        )
+        return result
 
     async def _stream_chat_completion(self, request: ChatCompletionRequest, client: OpenAI) -> AsyncGenerator:
         params = await self._get_params(request)
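For reference, a sketch of the response-side path the rewritten non-streaming handler now takes, with a hand-built OpenAI-style choice standing in for what the vLLM server returns; the call id, function name, and arguments are made up.

from types import SimpleNamespace

# Hand-built stand-in for r.choices[0] of a tool-calling completion.
choice = SimpleNamespace(
    finish_reason="tool_calls",
    message=SimpleNamespace(
        content=None,
        tool_calls=[
            SimpleNamespace(
                id="call_0",
                function=SimpleNamespace(name="get_weather", arguments='{"city": "Tokyo"}'),
            )
        ],
    ),
)

stop_reason = convert_to_vllm_finish_reason(choice.finish_reason)
# -> StopReason.end_of_message

tool_calls = convert_to_vllm_tool_calls_in_response(choice.message.tool_calls)
# -> [ToolCall(call_id="call_0", tool_name="get_weather", arguments={"city": "Tokyo"})]

content = choice.message.content or ""
# -> "" (tool-calling responses usually carry no text content)

Note that call_function_arguments is taken from the last iteration of the loop in convert_to_vllm_tool_calls_in_response and reused for every ToolCall, which only matters once a response carries more than one tool call.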
@@ -187,51 +264,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
             )
         return model
 
-    def convert_to_vllm_tools(self, tools: List[ToolDefinition]) -> List[dict]:
-        if tools is None:
-            return tools
-
-        compat_tools = []
-
-        for tool in tools:
-            properties = {}
-            compat_required = []
-            if tool.parameters:
-                for tool_key, tool_param in tool.parameters.items():
-                    properties[tool_key] = {"type": tool_param.param_type}
-                    if tool_param.description:
-                        properties[tool_key]["description"] = tool_param.description
-                    if tool_param.default:
-                        properties[tool_key]["default"] = tool_param.default
-                    if tool_param.required:
-                        compat_required.append(tool_key)
-
-            compat_tool = {
-                "type": "function",
-                "function": {
-                    "name": tool.tool_name,
-                    "description": tool.description,
-                    "parameters": {
-                        "type": "object",
-                        "properties": properties,
-                        "required": compat_required,
-                    },
-                },
-            }
-
-            compat_tools.append(compat_tool)
-
-        if len(compat_tools) > 0:
-            return compat_tools
-        return None
-
     async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
         options = get_sampling_options(request.sampling_params)
         if "max_tokens" not in options:
             options["max_tokens"] = self.config.max_tokens
 
-        input_dict = {}
-        input_dict["tools"] = self.convert_to_vllm_tools(request.tools)
+        input_dict = {"tools": convert_to_vllm_tools_in_request(request.tools)}
 
         if isinstance(request, ChatCompletionRequest):
             input_dict["messages"] = [await convert_message_to_openai_dict(m, download=True) for m in request.messages]
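A small note on the simplified _get_params: the module-level converter returns None for both a missing and an empty tool list, so the "tools" key is now always present in input_dict, just set to None when there is nothing to send. A quick sketch of the edge cases:

# Edge cases of the request-side converter used by the one-liner above.
assert convert_to_vllm_tools_in_request(None) is None
assert convert_to_vllm_tools_in_request([]) is None

# So a request without tools yields {"tools": None}; whether the
# OpenAI-compatible vLLM endpoint is happy with tools=None may be worth
# double-checking.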
@@ -174,6 +174,7 @@ def process_chat_completion_response(
 ) -> ChatCompletionResponse:
     choice = response.choices[0]
 
+    # TODO: This does not work well with tool calls (at least for vLLM remote)
     raw_message = formatter.decode_assistant_message_from_content(
         text_from_choice(choice), get_stop_reason(choice.finish_reason)
     )