mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
Formatting
This commit is contained in:
parent
7076e661b5
commit
d9db9a01bf
1 changed files with 17 additions and 14 deletions
|
@ -4,9 +4,10 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
import warnings
|
import warnings
|
||||||
from typing import AsyncGenerator, Generator, Literal
|
from typing import AsyncGenerator, Generator, Literal
|
||||||
import json
|
|
||||||
from groq import Stream
|
from groq import Stream
|
||||||
from groq.types.chat.chat_completion import ChatCompletion
|
from groq.types.chat.chat_completion import ChatCompletion
|
||||||
from groq.types.chat.chat_completion_assistant_message_param import (
|
from groq.types.chat.chat_completion_assistant_message_param import (
|
||||||
|
@ -14,19 +15,19 @@ from groq.types.chat.chat_completion_assistant_message_param import (
|
||||||
)
|
)
|
||||||
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
|
from groq.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||||
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
from groq.types.chat.chat_completion_message_param import ChatCompletionMessageParam
|
||||||
|
from groq.types.chat.chat_completion_message_tool_call import (
|
||||||
|
ChatCompletionMessageToolCall,
|
||||||
|
)
|
||||||
from groq.types.chat.chat_completion_system_message_param import (
|
from groq.types.chat.chat_completion_system_message_param import (
|
||||||
ChatCompletionSystemMessageParam,
|
ChatCompletionSystemMessageParam,
|
||||||
)
|
)
|
||||||
|
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
||||||
from groq.types.chat.chat_completion_user_message_param import (
|
from groq.types.chat.chat_completion_user_message_param import (
|
||||||
ChatCompletionUserMessageParam,
|
ChatCompletionUserMessageParam,
|
||||||
)
|
)
|
||||||
from groq.types.chat.completion_create_params import CompletionCreateParams
|
from groq.types.chat.completion_create_params import CompletionCreateParams
|
||||||
from groq.types.chat.chat_completion_message_tool_call import (
|
|
||||||
ChatCompletionMessageToolCall,
|
|
||||||
)
|
|
||||||
from groq.types.chat.chat_completion_tool_param import ChatCompletionToolParam
|
|
||||||
from groq.types.shared.function_definition import FunctionDefinition
|
from groq.types.shared.function_definition import FunctionDefinition
|
||||||
from groq.types.shared.function_parameters import FunctionParameters
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
ChatCompletionResponse,
|
ChatCompletionResponse,
|
||||||
|
@ -38,13 +39,14 @@ from llama_stack.apis.inference import (
|
||||||
Role,
|
Role,
|
||||||
StopReason,
|
StopReason,
|
||||||
ToolCall,
|
ToolCall,
|
||||||
|
ToolCallDelta,
|
||||||
|
ToolCallParseStatus,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolParamDefinition,
|
ToolParamDefinition,
|
||||||
ToolCallParseStatus,
|
|
||||||
ToolCallDelta,
|
|
||||||
ToolPromptFormat,
|
ToolPromptFormat,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def convert_chat_completion_request(
|
def convert_chat_completion_request(
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
) -> CompletionCreateParams:
|
) -> CompletionCreateParams:
|
||||||
|
@ -85,6 +87,7 @@ def convert_chat_completion_request(
|
||||||
tool_choice=request.tool_choice.value if request.tool_choice else None,
|
tool_choice=request.tool_choice.value if request.tool_choice else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
||||||
if message.role == Role.system.value:
|
if message.role == Role.system.value:
|
||||||
return ChatCompletionSystemMessageParam(role="system", content=message.content)
|
return ChatCompletionSystemMessageParam(role="system", content=message.content)
|
||||||
|
@ -98,7 +101,6 @@ def _convert_message(message: Message) -> ChatCompletionMessageParam:
|
||||||
raise ValueError(f"Invalid message role: {message.role}")
|
raise ValueError(f"Invalid message role: {message.role}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
|
def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
|
||||||
# Groq requires a description for function tools
|
# Groq requires a description for function tools
|
||||||
if tool_definition.description is None:
|
if tool_definition.description is None:
|
||||||
|
@ -114,13 +116,11 @@ def _convert_groq_tool_definition(tool_definition: ToolDefinition) -> dict:
|
||||||
key: _convert_groq_tool_parameter(param)
|
key: _convert_groq_tool_parameter(param)
|
||||||
for key, param in tool_parameters.items()
|
for key, param in tool_parameters.items()
|
||||||
},
|
},
|
||||||
)
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _convert_groq_tool_parameter(
|
def _convert_groq_tool_parameter(tool_parameter: ToolParamDefinition) -> dict:
|
||||||
tool_parameter: ToolParamDefinition
|
|
||||||
) -> dict:
|
|
||||||
param = {
|
param = {
|
||||||
"type": tool_parameter.param_type,
|
"type": tool_parameter.param_type,
|
||||||
}
|
}
|
||||||
|
@ -211,7 +211,9 @@ async def convert_chat_completion_response_stream(
|
||||||
elif choice.delta.tool_calls:
|
elif choice.delta.tool_calls:
|
||||||
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
|
# We assume there is only one tool call per chunk, but emit a warning in case we're wrong
|
||||||
if len(choice.delta.tool_calls) > 1:
|
if len(choice.delta.tool_calls) > 1:
|
||||||
warnings.warn("Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest.")
|
warnings.warn(
|
||||||
|
"Groq returned multiple tool calls in one chunk. Using the first one, ignoring the rest."
|
||||||
|
)
|
||||||
|
|
||||||
# We assume Groq produces fully formed tool calls for each chunk
|
# We assume Groq produces fully formed tool calls for each chunk
|
||||||
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
|
tool_call = _convert_groq_tool_call(choice.delta.tool_calls[0])
|
||||||
|
@ -233,6 +235,7 @@ async def convert_chat_completion_response_stream(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
|
def _convert_groq_tool_call(tool_call: ChatCompletionMessageToolCall) -> ToolCall:
|
||||||
return ToolCall(
|
return ToolCall(
|
||||||
call_id=tool_call.id,
|
call_id=tool_call.id,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue