revert some unintentional changes by copying source of truth to llama-models

This commit is contained in:
Ashwin Bharambe 2025-04-07 11:00:48 -07:00
parent 53a8086e37
commit cfaf9e0e8b
9 changed files with 133 additions and 113 deletions

View file

@@ -4163,6 +4163,11 @@
] ]
}, },
"arguments": { "arguments": {
"oneOf": [
{
"type": "string"
},
{
"type": "object", "type": "object",
"additionalProperties": { "additionalProperties": {
"oneOf": [ "oneOf": [
@@ -4228,6 +4233,11 @@
] ]
} }
} }
]
},
"arguments_json": {
"type": "string"
}
}, },
"additionalProperties": false, "additionalProperties": false,
"required": [ "required": [

View file

@@ -2890,7 +2890,9 @@ components:
title: BuiltinTool title: BuiltinTool
- type: string - type: string
arguments: arguments:
type: object oneOf:
- type: string
- type: object
additionalProperties: additionalProperties:
oneOf: oneOf:
- type: string - type: string
@@ -2914,6 +2916,8 @@ components:
- type: number - type: number
- type: boolean - type: boolean
- type: 'null' - type: 'null'
arguments_json:
type: string
additionalProperties: false additionalProperties: false
required: required:
- call_id - call_id

View file

@@ -38,7 +38,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
class ToolCall(BaseModel): class ToolCall(BaseModel):
call_id: str call_id: str
tool_name: Union[BuiltinTool, str] tool_name: Union[BuiltinTool, str]
arguments: Dict[str, RecursiveType] # Plan is to deprecate the Dict in favor of a JSON string
# that is parsed on the client side instead of trying to manage
# the recursive type here.
# Making this a union so that client side can start prepping for this change.
# Eventually, we will remove both the Dict and arguments_json field,
# and arguments will just be a str
arguments: Union[str, Dict[str, RecursiveType]]
arguments_json: Optional[str] = None
@field_validator("tool_name", mode="before") @field_validator("tool_name", mode="before")
@classmethod @classmethod

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -210,9 +210,12 @@ class ChatFormat:
content = ToolUtils.encode_tool_call(t, tool_prompt_format) content = ToolUtils.encode_tool_call(t, tool_prompt_format)
_process_content(content) _process_content(content)
# Tool calls and Tool Response messages should be eom
eom = False eom = False
if message.role == "assistant": if message.role == "assistant":
eom = message.stop_reason == StopReason.end_of_message eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
elif message.role == "tool":
eom = True
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"]) tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
return tokens, images return tokens, images
@@ -247,6 +250,11 @@ class ChatFormat:
if content.startswith(header_str): if content.startswith(header_str):
content = content[len(header_str) :] content = content[len(header_str) :]
ipython = content.startswith("<|python_start|>")
if ipython:
content = content[len("<|python_start|>") :]
content = content.replace("<|python_end|>", "")
if content.endswith("<|eot|>"): if content.endswith("<|eot|>"):
content = content[: -len("<|eot|>")] content = content[: -len("<|eot|>")]
stop_reason = StopReason.end_of_turn stop_reason = StopReason.end_of_turn
@@ -277,6 +285,11 @@ class ChatFormat:
} }
if tool_name in BuiltinTool.__members__: if tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name] tool_name = BuiltinTool[tool_name]
elif ipython:
tool_name = BuiltinTool.code_interpreter
tool_arguments = {
"code": content,
}
tool_calls = [] tool_calls = []
if tool_name is not None and tool_arguments is not None: if tool_name is not None and tool_arguments is not None:

View file

@@ -4,22 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import codecs import codecs
import io import io
import json import json

View file

@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os import os
from logging import getLogger from logging import getLogger
from pathlib import Path from pathlib import Path

View file

@@ -96,6 +96,7 @@ def _convert_to_vllm_tool_calls_in_response(
call_id=call.id, call_id=call.id,
tool_name=call.function.name, tool_name=call.function.name,
arguments=json.loads(call.function.arguments), arguments=json.loads(call.function.arguments),
arguments_json=call.function.arguments,
) )
for call in tool_calls for call in tool_calls
] ]
@@ -175,6 +176,7 @@ async def _process_vllm_chat_completion_stream_response(
call_id=tool_call_buf.call_id, call_id=tool_call_buf.call_id,
tool_name=tool_call_buf.tool_name, tool_name=tool_call_buf.tool_name,
arguments=args, arguments=args,
arguments_json=args_str,
), ),
parse_status=ToolCallParseStatus.succeeded, parse_status=ToolCallParseStatus.succeeded,
), ),