Formatting

This commit is contained in:
Aidan Do 2024-12-15 13:48:55 +11:00
parent 7076e661b5
commit d9db9a01bf

View file

@ -4,9 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import warnings
from typing import AsyncGenerator, Generator, Literal
import json
from groq import Stream
from groq.types.chat.chat_completion import ChatCompletion
from groq.types.chat.chat_completion_assistant_message_param import (
@ -14,19 +15,19 @@ from groq.types.chat.chat_completion_assistant_message_param import (
)
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_system_message_param import (
ChatCompletionSystemMessageParam,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.chat.chat_completion_user_message_param import (
ChatCompletionUserMessageParam,
)
from groq.types.chat.completion_create_params import CompletionCreateParams
from groq.types.chat.chat_completion_message_tool_call import (
ChatCompletionMessageToolCall,
)
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
from groq.types.shared.function_definition import FunctionDefinition
from groq.types.shared.function_parameters import FunctionParameters
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
@ -38,13 +39,14 @@ from llama_stack.apis.inference import (
Role,
StopReason,
ToolCall,
ToolCallDelta,
ToolCallParseStatus,
ToolDefinition,
ToolParamDefinition,
ToolCallParseStatus,
ToolCallDelta,
ToolPromptFormat,
)
def convert_chat_completion_request(
request: ChatCompletionRequest,
) -> CompletionCreateParams:
@ -85,6 +87,7 @@ def convert_chat_completion_request(
tool_choice=request.tool_choice.value if request.tool_choice else None,
)
def _convert_message(message: Message) -> ChatCompletionMessageParam:
if message.role == Role.system.value:
return ChatCompletionSystemMessageParam(role="system", content=message.content)
@ -98,7 +101,6 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
raise ValueError(f"Invalid message role: {message.role}")
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
# Groq requires a description for function tools
if tool_definition.description is None:
@ -114,13 +116,11 @@ def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
key: _convert_groq_tool_parameter(param)
for key, param in tool_parameters.items()
},
)
),
)
def _convert_groq_tool_parameter(
tool_parameter: ToolParamDefinition
) -> dict:
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
param = {
"type": tool_parameter.param_type,
}
@ -211,7 +211,9 @@ async def convert_chat_completion_response_stream(
elif choice.delta.tool_calls:
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
if len(choice.delta.tool_calls) > 1:
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
warnings.warn(
"Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
)
# We assume Groq produces fully formed tool calls for each chunk
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
@ -233,6 +235,7 @@ async def convert_chat_completion_response_stream(
)
)
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
return ToolCall(
call_id=tool_call.id,