diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html
index b32b7cfdf..eb626fc44 100644
--- a/docs/_static/llama-stack-spec.html
+++ b/docs/_static/llama-stack-spec.html
@@ -4159,70 +4159,80 @@
]
},
"arguments": {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "type": "integer"
- },
- {
- "type": "number"
- },
- {
- "type": "boolean"
- },
- {
- "type": "null"
- },
- {
- "type": "array",
- "items": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "type": "integer"
- },
- {
- "type": "number"
- },
- {
- "type": "boolean"
- },
- {
- "type": "null"
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "integer"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ },
+ {
+ "type": "array",
+ "items": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "integer"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ }
+ ]
}
- ]
- }
- },
- {
- "type": "object",
- "additionalProperties": {
- "oneOf": [
- {
- "type": "string"
- },
- {
- "type": "integer"
- },
- {
- "type": "number"
- },
- {
- "type": "boolean"
- },
- {
- "type": "null"
+ },
+ {
+ "type": "object",
+ "additionalProperties": {
+ "oneOf": [
+ {
+ "type": "string"
+ },
+ {
+ "type": "integer"
+ },
+ {
+ "type": "number"
+ },
+ {
+ "type": "boolean"
+ },
+ {
+ "type": "null"
+ }
+ ]
}
- ]
- }
+ }
+ ]
}
- ]
- }
+ }
+ ]
+ },
+ "arguments_json": {
+ "type": "string"
}
},
"additionalProperties": false,
diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml
index eb5d9722e..fa6920381 100644
--- a/docs/_static/llama-stack-spec.yaml
+++ b/docs/_static/llama-stack-spec.yaml
@@ -2864,30 +2864,34 @@ components:
title: BuiltinTool
- type: string
arguments:
- type: object
- additionalProperties:
- oneOf:
- - type: string
- - type: integer
- - type: number
- - type: boolean
- - type: 'null'
- - type: array
- items:
- oneOf:
- - type: string
- - type: integer
- - type: number
- - type: boolean
- - type: 'null'
- - type: object
- additionalProperties:
- oneOf:
- - type: string
- - type: integer
- - type: number
- - type: boolean
- - type: 'null'
+ oneOf:
+ - type: string
+ - type: object
+ additionalProperties:
+ oneOf:
+ - type: string
+ - type: integer
+ - type: number
+ - type: boolean
+ - type: 'null'
+ - type: array
+ items:
+ oneOf:
+ - type: string
+ - type: integer
+ - type: number
+ - type: boolean
+ - type: 'null'
+ - type: object
+ additionalProperties:
+ oneOf:
+ - type: string
+ - type: integer
+ - type: number
+ - type: boolean
+ - type: 'null'
+ arguments_json:
+ type: string
additionalProperties: false
required:
- call_id
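Both spec hunks above encode the same change: `arguments` may now be either a JSON-encoded string or the legacy object, and the new optional `arguments_json` string carries the raw JSON. A minimal sketch of how a client might normalize the two shapes during the transition; the payloads and the helper name are illustrative, not part of the spec.

```python
import json
from typing import Any


def normalize_arguments(tool_call: dict[str, Any]) -> dict[str, Any]:
    """Return tool-call arguments as a dict regardless of which shape the server sent."""
    # Prefer the raw JSON string (arguments_json, or a string-typed arguments) ...
    raw = tool_call.get("arguments_json")
    if raw is None and isinstance(tool_call.get("arguments"), str):
        raw = tool_call["arguments"]
    if raw is not None:
        return json.loads(raw)
    # ... and fall back to the legacy dict form.
    return tool_call.get("arguments", {})


legacy = {"call_id": "1", "tool_name": "get_weather", "arguments": {"city": "Paris"}}
new = {
    "call_id": "2",
    "tool_name": "get_weather",
    "arguments": '{"city": "Paris"}',
    "arguments_json": '{"city": "Paris"}',
}
assert normalize_arguments(legacy) == normalize_arguments(new) == {"city": "Paris"}
```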
diff --git a/llama_stack/models/llama/datatypes.py b/llama_stack/models/llama/datatypes.py
index b25bf0ea9..9842d7980 100644
--- a/llama_stack/models/llama/datatypes.py
+++ b/llama_stack/models/llama/datatypes.py
@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel):
call_id: str
tool_name: Union[BuiltinTool, str]
- arguments: Dict[str, RecursiveType]
+ # The plan is to deprecate the Dict in favor of a JSON string that is
+ # parsed on the client side, instead of trying to manage the recursive
+ # type here.
+ # Making this a union so that the client side can start prepping for this change.
+ # Eventually, we will remove both the Dict and the arguments_json field,
+ # and arguments will just be a str.
+ arguments: Union[str, Dict[str, RecursiveType]]
+ arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before")
@classmethod
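With the transitional union in `ToolCall`, providers can keep emitting the parsed dict while clients start reading the raw string. A minimal sketch of both construction styles and a tolerant reader, assuming `llama_stack` is installed; the field values are made up.

```python
import json

from llama_stack.models.llama.datatypes import ToolCall

args = {"city": "Paris", "days": 3}

# Current shape: parsed dict plus the new optional raw-JSON mirror.
legacy_call = ToolCall(
    call_id="call-1",
    tool_name="get_weather",
    arguments=args,
    arguments_json=json.dumps(args),
)

# Future shape: once the Dict form is removed, arguments itself is the JSON string.
future_call = ToolCall(
    call_id="call-2",
    tool_name="get_weather",
    arguments=json.dumps(args),
)


def parsed_arguments(call: ToolCall) -> dict:
    # A client written against the union handles either shape without caring which it got.
    if isinstance(call.arguments, str):
        return json.loads(call.arguments)
    return json.loads(call.arguments_json) if call.arguments_json else call.arguments
```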
diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py
index 011ccb02a..2862f8558 100644
--- a/llama_stack/models/llama/llama3/chat_format.py
+++ b/llama_stack/models/llama/llama3/chat_format.py
@@ -12,6 +12,7 @@
# the top-level of this source tree.
import io
+import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ class ChatFormat:
# This code tries to handle that case
if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
- tool_arguments = {
- "query": list(tool_arguments.values())[0],
- }
+ if isinstance(tool_arguments, dict):
+ tool_arguments = {
+ "query": list(tool_arguments.values())[0],
+ }
else:
builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ class ChatFormat:
call_id=call_id,
tool_name=tool_name,
arguments=tool_arguments,
+ arguments_json=json.dumps(tool_arguments),
)
)
content = ""
diff --git a/llama_stack/models/llama/llama3/template_data.py b/llama_stack/models/llama/llama3/template_data.py
index aa16aa009..076b4adb4 100644
--- a/llama_stack/models/llama/llama3/template_data.py
+++ b/llama_stack/models/llama/llama3/template_data.py
@@ -11,11 +11,8 @@
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
-from llama_stack.models.llama.datatypes import (
- BuiltinTool,
- StopReason,
- ToolCall,
-)
+
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
from .prompt_templates import (
BuiltinToolGenerator,
diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py
index b59df13d0..256e0f821 100644
--- a/llama_stack/providers/inline/inference/vllm/vllm.py
+++ b/llama_stack/providers/inline/inference/vllm/vllm.py
@@ -582,6 +582,7 @@ class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
tool_name=t.function.name,
# vLLM function args come back as a string. Llama Stack expects JSON.
arguments=json.loads(t.function.arguments),
+ arguments_json=t.function.arguments,
)
for t in vllm_message.tool_calls
],
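One plausible reason to pass `t.function.arguments` through unchanged, rather than re-serializing the parsed dict, is that it preserves the provider's original JSON exactly; `json.dumps(json.loads(s))` keeps the data but can change whitespace and number formatting. A small illustration of that drift (the argument string is made up):

```python
import json

raw = '{"city":"Paris","temp":20.0}'  # hypothetical string from the provider
parsed = json.loads(raw)

# Re-serializing changes the surface form even though the data is identical.
assert json.dumps(parsed) == '{"city": "Paris", "temp": 20.0}'
assert json.dumps(parsed) != raw

# Keeping the original string in arguments_json avoids the round-trip drift.
tool_call_fields = {"arguments": parsed, "arguments_json": raw}
```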
diff --git a/llama_stack/providers/remote/inference/sambanova/sambanova.py b/llama_stack/providers/remote/inference/sambanova/sambanova.py
index a5e17c2a3..635a42d38 100644
--- a/llama_stack/providers/remote/inference/sambanova/sambanova.py
+++ b/llama_stack/providers/remote/inference/sambanova/sambanova.py
@@ -42,9 +42,7 @@ from llama_stack.models.llama.datatypes import (
TopKSamplingStrategy,
TopPSamplingStrategy,
)
-from llama_stack.providers.utils.inference.model_registry import (
- ModelRegistryHelper,
-)
+from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
process_chat_completion_stream_response,
)
@@ -293,14 +291,12 @@ class SambaNovaInferenceAdapter(ModelRegistryHelper, Inference):
if not tool_calls:
return []
- for call in tool_calls:
- call_function_arguments = json.loads(call.function.arguments)
-
compitable_tool_calls = [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
- arguments=call_function_arguments,
+ arguments=json.loads(call.function.arguments),
+ arguments_json=call.function.arguments,
)
for call in tool_calls
]
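The removed loop parsed each call's arguments but kept only the last iteration's result, so every `ToolCall` in a multi-call response received the final call's arguments; parsing inside the comprehension keeps them per call. A tiny reproduction of the old failure mode, using stand-in objects rather than the SambaNova client types:

```python
import json
from types import SimpleNamespace

calls = [
    SimpleNamespace(id="a", function=SimpleNamespace(name="f", arguments='{"x": 1}')),
    SimpleNamespace(id="b", function=SimpleNamespace(name="g", arguments='{"x": 2}')),
]

# Old shape: the loop variable survives the loop, so both entries see {"x": 2}.
for call in calls:
    call_function_arguments = json.loads(call.function.arguments)
old = [(c.id, call_function_arguments) for c in calls]
assert old == [("a", {"x": 2}), ("b", {"x": 2})]

# New shape: parse per element, so each call keeps its own arguments.
new = [(c.id, json.loads(c.function.arguments)) for c in calls]
assert new == [("a", {"x": 1}), ("b", {"x": 2})]
```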
diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py
index f940de7ba..eda1a179c 100644
--- a/llama_stack/providers/remote/inference/vllm/vllm.py
+++ b/llama_stack/providers/remote/inference/vllm/vllm.py
@@ -90,15 +90,12 @@ def _convert_to_vllm_tool_calls_in_response(
if not tool_calls:
return []
- call_function_arguments = None
- for call in tool_calls:
- call_function_arguments = json.loads(call.function.arguments)
-
return [
ToolCall(
call_id=call.id,
tool_name=call.function.name,
- arguments=call_function_arguments,
+ arguments=json.loads(call.function.arguments),
+ arguments_json=call.function.arguments,
)
for call in tool_calls
]
@@ -183,6 +180,7 @@ async def _process_vllm_chat_completion_stream_response(
call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name,
arguments=args,
+ arguments_json=args_str,
),
parse_status=ToolCallParseStatus.succeeded,
),
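In the streaming path, the raw argument fragments are accumulated into a single string before being parsed, and that accumulated string is what lands in `arguments_json`. A simplified sketch of the buffering pattern; the `ToolCallBuffer` name and shape are illustrative, not the module's actual types.

```python
import json
from dataclasses import dataclass


@dataclass
class ToolCallBuffer:
    call_id: str = ""
    tool_name: str = ""
    arguments: str = ""  # raw JSON fragments concatenated as they stream in


def finish(buf: ToolCallBuffer) -> dict:
    args_str = buf.arguments
    return {
        "call_id": buf.call_id,
        "tool_name": buf.tool_name,
        "arguments": json.loads(args_str),  # parsed view for existing consumers
        "arguments_json": args_str,         # raw view for clients moving to the string form
    }


buf = ToolCallBuffer(call_id="c1", tool_name="get_weather")
for delta in ['{"cit', 'y": "Par', 'is"}']:
    buf.arguments += delta
assert finish(buf)["arguments"] == {"city": "Paris"}
```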
diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py
index 2a362f8cb..b264c7312 100644
--- a/llama_stack/providers/utils/inference/openai_compat.py
+++ b/llama_stack/providers/utils/inference/openai_compat.py
@@ -529,7 +529,11 @@ async def convert_message_to_openai_dict_new(
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
async def impl(
content_: InterleavedContent,
- ) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
+ ) -> Union[
+ str,
+ OpenAIChatCompletionContentPartParam,
+ List[OpenAIChatCompletionContentPartParam],
+ ]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content_, str):
return content_
@@ -570,7 +574,7 @@ async def convert_message_to_openai_dict_new(
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
- name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
+ name=(tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value),
arguments=json.dumps(tool.arguments),
),
type="function",
@@ -609,6 +613,7 @@ def convert_tool_call(
call_id=tool_call.id,
tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments),
+ arguments_json=tool_call.function.arguments,
)
except Exception:
return UnparseableToolCall(
@@ -759,6 +764,7 @@ def _convert_openai_tool_calls(
call_id=call.id,
tool_name=call.function.name,
arguments=json.loads(call.function.arguments),
+ arguments_json=call.function.arguments,
)
for call in tool_calls
]
@@ -890,7 +896,8 @@ async def convert_openai_chat_completion_stream(
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn(
- "multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
+ "multiple tool calls found in a single delta, using the first, ignoring the rest",
+ stacklevel=2,
)
if not enable_incremental_tool_calls:
@@ -971,6 +978,7 @@ async def convert_openai_chat_completion_stream(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
+ arguments_json=buffer["arguments"],
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
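`convert_tool_call` already wraps the conversion in a try/except so a malformed arguments string falls back to `UnparseableToolCall` instead of raising; the new field simply rides along on the success path. A condensed sketch of that shape, standing alone and returning plain dicts instead of the real types:

```python
import json


def convert_tool_call_sketch(tool_call) -> dict:
    try:
        return {
            "ok": True,
            "call_id": tool_call.id,
            "tool_name": tool_call.function.name,
            "arguments": json.loads(tool_call.function.arguments),
            "arguments_json": tool_call.function.arguments,
        }
    except Exception:
        # Mirrors the UnparseableToolCall fallback: keep whatever we got as strings.
        return {
            "ok": False,
            "call_id": str(getattr(tool_call, "id", "")),
            "tool_name": str(getattr(tool_call.function, "name", "")),
            "arguments": str(getattr(tool_call.function, "arguments", "")),
        }
```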
diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py
index c3755e2cb..0e2780e50 100644
--- a/tests/unit/models/test_prompt_adapter.py
+++ b/tests/unit/models/test_prompt_adapter.py
@@ -165,7 +165,10 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
prompt = await chat_completion_request_to_prompt(request, request.model)
- self.assertIn('{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', prompt)
+ self.assertIn(
+ '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
+ prompt,
+ )
async def test_user_provided_system_message(self):
content = "Hello !"