Merge branch 'main' into add-watsonx-inference-adapter

Sajikumar JS 2025-03-20 10:18:05 +05:30
commit 28e6c8478b
308 changed files with 33749 additions and 5102 deletions

@@ -47,7 +47,14 @@ RecursiveType = Union[Primitive, List[Primitive], Dict[str, Primitive]]
 class ToolCall(BaseModel):
     call_id: str
     tool_name: Union[BuiltinTool, str]
-    arguments: Dict[str, RecursiveType]
+    # Plan is to deprecate the Dict in favor of a JSON string
+    # that is parsed on the client side instead of trying to manage
+    # the recursive type here.
+    # Making this a union so that client side can start prepping for this change.
+    # Eventually, we will remove both the Dict and arguments_json field,
+    # and arguments will just be a str
+    arguments: Union[str, Dict[str, RecursiveType]]
+    arguments_json: Optional[str] = None
 
     @field_validator("tool_name", mode="before")
     @classmethod
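
Note on the transition described in the comments above: a minimal sketch of how a caller could populate both fields while the Dict form is still supported (the tool name and argument values below are made up for illustration, not part of this commit):

import json

from llama_stack.models.llama.datatypes import ToolCall

# Keep the structured dict in `arguments` and mirror it as a JSON string in
# `arguments_json`, so clients can start reading the string form early.
args = {"query": "current weather in Tokyo"}
call = ToolCall(
    call_id="call-0",
    tool_name="get_weather",  # hypothetical custom tool name
    arguments=args,
    arguments_json=json.dumps(args),
)

# A client preparing for the deprecation can prefer the JSON string when present.
parsed = json.loads(call.arguments_json) if call.arguments_json else call.arguments
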
@@ -179,13 +186,11 @@ class TopKSamplingStrategy(BaseModel):
     top_k: int = Field(..., ge=1)
 
 
-SamplingStrategy = register_schema(
-    Annotated[
-        Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
-        Field(discriminator="type"),
-    ],
-    name="SamplingStrategy",
-)
+SamplingStrategy = Annotated[
+    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    Field(discriminator="type"),
+]
+register_schema(SamplingStrategy, name="SamplingStrategy")
 
 
 @json_schema_type
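
For context, this refactor keeps SamplingStrategy usable directly as a type annotation while schema registration becomes a separate call. A rough, self-contained sketch of the discriminated-union behaviour, using simplified stand-ins for the real strategy models rather than the actual classes:

from typing import Annotated, Literal, Union

from pydantic import BaseModel, Field, TypeAdapter

# Simplified stand-ins; the real models live alongside SamplingStrategy above.
class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"

class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)

SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]

# Validation against the discriminated union behaves the same as before;
# only the schema registration moved out of the type alias definition.
strategy = TypeAdapter(SamplingStrategy).validate_python({"type": "top_k", "top_k": 5})
assert isinstance(strategy, TopKSamplingStrategy)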

@@ -12,6 +12,7 @@
 # the top-level of this source tree.
 
 import io
 import json
+import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
@@ -203,9 +204,10 @@ class ChatFormat:
             # This code tries to handle that case
             if tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
-                tool_arguments = {
-                    "query": list(tool_arguments.values())[0],
-                }
+                if isinstance(tool_arguments, dict):
+                    tool_arguments = {
+                        "query": list(tool_arguments.values())[0],
+                    }
         else:
             builtin_tool_info = ToolUtils.maybe_extract_builtin_tool_call(content)
             if builtin_tool_info is not None:
@@ -229,6 +231,7 @@ class ChatFormat:
                     call_id=call_id,
                     tool_name=tool_name,
                     arguments=tool_arguments,
+                    arguments_json=json.dumps(tool_arguments),
                 )
             )
         content = ""
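
A rough sketch of the normalization these two hunks perform; the helper name is invented for illustration, the real logic is inline in ChatFormat:

import json

def normalize_builtin_arguments(tool_arguments):
    # Builtin tools expect a single "query" argument, so dict-style arguments
    # are collapsed to their first value; string arguments pass through as-is.
    if isinstance(tool_arguments, dict):
        tool_arguments = {"query": list(tool_arguments.values())[0]}
    # Whatever we end up with is also mirrored as a JSON string for `arguments_json`.
    return tool_arguments, json.dumps(tool_arguments)

args, args_json = normalize_builtin_arguments({"q": "llama 3 release notes"})
# args == {"query": "llama 3 release notes"}
# args_json == '{"query": "llama 3 release notes"}'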

@@ -34,7 +34,9 @@ class SystemDefaultGenerator(PromptTemplateGeneratorBase):
         )
         return PromptTemplate(
             template_str.lstrip("\n"),
-            {"today": datetime.now().strftime("%d %B %Y")},
+            {
+                "today": datetime.now().strftime("%d %B %Y")  # noqa: DTZ005 - we don't care about timezones here since we are displaying the date
+            },
         )
 
     def data_examples(self) -> List[Any]:
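
For reference, the "today" value handed to the template is a plain local-time date string; timezone handling is deliberately ignored, per the noqa note above:

from datetime import datetime

# Produces e.g. "20 March 2025", depending on the local date.
print(datetime.now().strftime("%d %B %Y"))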

@@ -11,11 +11,8 @@
 # top-level folder for each specific model found within the models/ directory at
 # the top-level of this source tree.
 
-from llama_stack.models.llama.datatypes import (
-    BuiltinTool,
-    StopReason,
-    ToolCall,
-)
+from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
 
 from .prompt_templates import (
     BuiltinToolGenerator,

@@ -15,8 +15,11 @@ import json
 import re
 from typing import Optional, Tuple
 
+from llama_stack.log import get_logger
 from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
 
+logger = get_logger(name=__name__, category="inference")
+
 BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")
@@ -92,7 +95,15 @@ def parse_python_list_for_function_calls(input_string):
             # Extract keyword arguments
             for keyword in node.keywords:
-                function_args[keyword.arg] = ast.literal_eval(keyword.value)
+                try:
+                    function_args[keyword.arg] = ast.literal_eval(keyword.value)
+                except ValueError as e:
+                    logger.error(
+                        f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
+                    )
+                    raise ValueError(
+                        f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
+                    ) from e
 
             result.append((function_name, function_args))
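
A standalone illustration (not code from this commit) of the failure mode the new try/except guards against: ast.literal_eval raises ValueError when a keyword value is not a literal, such as a bare name, and previously that error carried no context about the offending argument:

import ast

# Parse a call whose keyword value is a name rather than a literal.
node = ast.parse('my_tool(limit=some_variable)', mode="eval").body
keyword = node.keywords[0]

try:
    ast.literal_eval(keyword.value)
except ValueError as e:
    # With this change, the parser logs and re-raises with the argument name
    # and the full input string, which makes this easier to debug.
    print(f"not a literal: {e}")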