Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-06-28 02:53:30 +00:00)
fix: Updating ToolCall.arguments to allow for json strings that can be decoded on client side (#1685)
### What does this PR do?

Currently, `ToolCall.arguments` is a `Dict[str, RecursiveType]`. However, on the client SDK side the `RecursiveType` gets deserialized into a single numeric type (int and float are collapsed), so `int` parameters arrive as `float` and can break client-side tools that do type checking.

Closes: https://github.com/meta-llama/llama-stack/issues/1683

### Test Plan

Stainless changes -- https://github.com/meta-llama/llama-stack-client-python/pull/204

```
pytest -s -v --stack-config=fireworks tests/integration/agents/test_agents.py --text-model meta-llama/Llama-3.1-8B-Instruct
```
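For illustration, a minimal sketch of the failure mode (not the client SDK's actual deserialization code) and why shipping the raw JSON string avoids it:

```python
import json

# Server-side tool call arguments with an int parameter.
args = {"limit": 5}

# If a generated client models every JSON number as one numeric type
# (say, float), the int is silently widened on deserialization:
widened = {k: float(v) if isinstance(v, (int, float)) else v for k, v in args.items()}
assert isinstance(widened["limit"], float)  # 5 became 5.0

# Shipping the raw JSON string instead lets the client decode losslessly:
arguments_json = json.dumps(args)
decoded = json.loads(arguments_json)
assert isinstance(decoded["limit"], int)  # int preserved
```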
This commit is contained in:

parent 113f3a259c
commit 65ca85ba6b

10 changed files with 137 additions and 110 deletions
docs/_static/llama-stack-spec.html (vendored): 10 changes
```diff
@@ -4159,6 +4159,11 @@
                ]
              },
              "arguments": {
+                "oneOf": [
+                  {
+                    "type": "string"
+                  },
+                  {
                "type": "object",
                "additionalProperties": {
                  "oneOf": [
@@ -4224,6 +4229,11 @@
                ]
              }
            }
+          ]
+        },
+        "arguments_json": {
+          "type": "string"
+        }
        },
        "additionalProperties": false,
        "required": [
```
docs/_static/llama-stack-spec.yaml (vendored): 6 changes
```diff
@@ -2864,7 +2864,9 @@ components:
            title: BuiltinTool
          - type: string
        arguments:
-          type: object
+          oneOf:
+            - type: string
+            - type: object
          additionalProperties:
            oneOf:
              - type: string
@@ -2888,6 +2890,8 @@ components:
              - type: number
              - type: boolean
              - type: 'null'
+        arguments_json:
+          type: string
        additionalProperties: false
        required:
          - call_id
```
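As a quick sanity check, a sketch (simplified from the spec excerpts above; `jsonschema` is a third-party package, not part of this repo) showing that both shapes validate under the new `oneOf`:

```python
import jsonschema  # third-party: pip install jsonschema

# Simplified from the updated spec: arguments is now a JSON string or an object.
arguments_schema = {
    "oneOf": [
        {"type": "string"},
        {"type": "object"},
    ]
}

jsonschema.validate('{"limit": 5}', arguments_schema)  # string form: valid
jsonschema.validate({"limit": 5}, arguments_schema)    # object form: valid
```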
```diff
@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 class ToolCall(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    # Plan is to deprecate the Dict in favor of a JSON string
+    # that is parsed on the client side instead of trying to manage
+    # the recursive type here.
+    # Making this a union so that client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and arguments_json field,
+    # and arguments will just be a str
+    arguments: Union[str, Dict[str, RecursiveType]]
+    arguments_json: Optional[str] = None

     @field_validator("tool_name", mode="before")
     @classmethod
```
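A hedged sketch of how client code might prepare for this transition (the helper name is illustrative, not part of the SDK): prefer `arguments_json` when present, and tolerate `arguments` being either form:

```python
import json
from typing import Any, Dict

def resolve_tool_arguments(tool_call) -> Dict[str, Any]:
    # Illustrative helper: prefer the lossless JSON string when provided.
    if getattr(tool_call, "arguments_json", None):
        return json.loads(tool_call.arguments_json)
    # Fall back to the legacy field, which is now a Union[str, Dict];
    # eventually it will always be a str.
    if isinstance(tool_call.arguments, str):
        return json.loads(tool_call.arguments)
    return tool_call.arguments
```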
```diff
@@ -12,6 +12,7 @@
 # the top-level of this source tree.

 import io
+import json
 import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -203,6 +204,7 @@ class ChatFormat:
             # This code tries to handle that case
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
-                tool_arguments = {
-                    "query": list(tool_arguments.values())[0],
-                }
+                if isinstance(tool_arguments, dict):
+                    tool_arguments = {
+                        "query": list(tool_arguments.values())[0],
+                    }
@@ -229,6 +231,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
         content = ""
```
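The decode path above now populates both fields at parse time. A minimal sketch of the resulting shape (field values illustrative):

```python
import json

# Illustrative values: at decode time the server now emits both
# representations, so existing clients keep the dict while new clients
# can decode the raw string themselves.
tool_arguments = {"query": "latest llama release"}
tool_call_fields = {
    "call_id": "call_0",
    "tool_name": "brave_search",
    "arguments": tool_arguments,
    "arguments_json": json.dumps(tool_arguments),
}
assert json.loads(tool_call_fields["arguments_json"]) == tool_arguments
```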
```diff
@@ -11,11 +11,8 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.

-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-)
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall

 from .prompt_templates import (
     BuiltinToolGenerator,
```
```diff
@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
                         tool_name=t.function.name,
                         # vLLM function args come back as a string. Llama Stack expects JSON.
                         arguments=json.loads(t.function.arguments),
+                        arguments_json=t.function.arguments,
                     )
                     for t in vllm_message.tool_calls
                 ],
```
```diff
@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
     TopKSamplingStrategy,
     TopPSamplingStrategy,
 )
-from llama_stack.providers.utils.inference.model_registry import (
-    ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
 from llama_stack.providers.utils.inference.openai_compat import (
     process_chat_completion_stream_response,
 )
@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
         if not tool_calls:
             return []

-        for call in tool_calls:
-            call_function_arguments = json.loads(call.function.arguments)
-
         compitable_tool_calls = [
             ToolCall(
                 call_id=call.id,
                 tool_name=call.function.name,
-                arguments=call_function_arguments,
+                arguments=json.loads(call.function.arguments),
+                arguments_json=call.function.arguments,
             )
             for call in tool_calls
         ]
```
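Note that beyond adding `arguments_json`, this hunk also fixes a latent bug: the old code parsed arguments in a separate loop, so every `ToolCall` built in the comprehension shared the last call's arguments. The same fix appears in the vLLM adapter below. A standalone sketch of the difference:

```python
import json

raw_arguments = ['{"a": 1}', '{"b": 2}']

# Old pattern: the loop leaves only the last parsed value bound, so every
# element built afterwards shares it.
for raw in raw_arguments:
    parsed = json.loads(raw)
broken = [parsed for _ in raw_arguments]
assert broken == [{"b": 2}, {"b": 2}]

# New pattern: parse per element, inside the comprehension.
fixed = [json.loads(raw) for raw in raw_arguments]
assert fixed == [{"a": 1}, {"b": 2}]
```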
```diff
@@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
     if not tool_calls:
         return []

-    call_function_arguments = None
-    for call in tool_calls:
-        call_function_arguments = json.loads(call.function.arguments)
-
     return [
         ToolCall(
             call_id=call.id,
             tool_name=call.function.name,
-            arguments=call_function_arguments,
+            arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
                             call_id=tool_call_buf.call_id,
                             tool_name=tool_call_buf.tool_name,
                             arguments=args,
+                            arguments_json=args_str,
                         ),
                         parse_status=ToolCallParseStatus.succeeded,
                     ),
```
```diff
@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
 ) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
     async def impl(
         content_: InterleavedContent,
-    ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+    ) -> Union[
+        str,
+        OpenAIChatCompletionContentPartParam,
+        List[OpenAIChatCompletionContentPartParam],
+    ]:
         # Llama Stack and OpenAI spec match for str and text input
         if isinstance(content_, str):
             return content_
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
                 OpenAIChatCompletionMessageToolCall(
                     id=tool.call_id,
                     function=OpenAIFunction(
-                        name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+                        name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
                         arguments=json.dumps(tool.arguments),
                     ),
                     type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
             call_id=tool_call.id,
             tool_name=tool_call.function.name,
             arguments=json.loads(tool_call.function.arguments),
+            arguments_json=tool_call.function.arguments,
         )
     except Exception:
         return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
             call_id=call.id,
             tool_name=call.function.name,
             arguments=json.loads(call.function.arguments),
+            arguments_json=call.function.arguments,
         )
         for call in tool_calls
     ]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
         # ChatCompletionResponseEvent only supports one per stream
         if len(choice.delta.tool_calls) > 1:
             warnings.warn(
-                "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+                "multiple tool calls found in a single delta, using the first, ignoring the rest",
+                stacklevel=2,
             )

         if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
                         call_id=buffer["call_id"],
                         tool_name=buffer["name"],
                         arguments=arguments,
+                        arguments_json=buffer["arguments"],
                     )
                     yield ChatCompletionResponseStreamChunk(
                         event=ChatCompletionResponseEvent(
```
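The streaming path buffers argument fragments across deltas and only parses once complete; a minimal sketch of that pattern (simulated deltas, simplified buffer):

```python
import json

# Simulated: fragments of the arguments string arrive across stream deltas.
buffer = {"call_id": "call_1", "name": "get_weather", "arguments": ""}
for fragment in ['{"city": ', '"Paris"', "}"]:
    buffer["arguments"] += fragment

# Once the stream signals completion, parse once and keep both views.
arguments = json.loads(buffer["arguments"])
tool_call = {
    "call_id": buffer["call_id"],
    "tool_name": buffer["name"],
    "arguments": arguments,                 # parsed dict (legacy field)
    "arguments_json": buffer["arguments"],  # raw string, decoded client-side
}
assert tool_call["arguments"] == {"city": "Paris"}
```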
```diff
@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         request.model = MODEL
         request.tool_config.tool_prompt_format = ToolPromptFormat.json
         prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+        self.assertIn(
+            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
+            prompt,
+        )

     async def test_user_provided_system_message(self):
         content = "Hello !"
```