feat(agents)!: changing agents API signatures to use OpenAI types

Replace legacy Message/SamplingParams usage with OpenAI chat message structures across agents: schemas, meta-reference implementation, and tests now rely on OpenAI message/tool payloads and generation knobs.
This commit is contained in:
Ashwin Bharambe 2025-10-10 13:04:41 -07:00
parent 548ccff368
commit c56b2deb7d
6 changed files with 392 additions and 305 deletions

View file

@ -11,19 +11,18 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.content_types import URL, ContentDelta
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
OpenAIAssistantMessageParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionToolCall,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIToolMessageParam,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@ -63,7 +62,7 @@ class Attachment(BaseModel):
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
content: OpenAIChatCompletionMessageContent | URL
mime_type: str
@ -74,7 +73,7 @@ class Document(BaseModel):
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
content: OpenAIChatCompletionMessageContent | URL
mime_type: str
@ -108,6 +107,7 @@ class StepType(StrEnum):
memory_retrieval = "memory_retrieval"
@json_schema_type
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
@ -118,7 +118,8 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
model_response: OpenAIAssistantMessageParam
finish_reason: str | None = None
@json_schema_type
@ -130,8 +131,8 @@ class ToolExecutionStep(StepCommon):
"""
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
tool_calls: list[OpenAIChatCompletionToolCall]
tool_responses: list[OpenAIToolMessageParam]
@json_schema_type
@ -156,7 +157,7 @@ class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
inserted_context: OpenAIChatCompletionMessageContent
Step = Annotated[
@ -181,9 +182,10 @@ class Turn(BaseModel):
turn_id: str
session_id: str
input_messages: list[UserMessage | ToolResponseMessage]
input_messages: list[OpenAIMessageParam]
steps: list[Step]
output_message: CompletionMessage
output_message: OpenAIAssistantMessageParam
finish_reason: str | None = None
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
@ -216,31 +218,22 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stop: list[str] | None = None
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[OpenAIResponseInputTool | ToolDef] | None = Field(default_factory=list)
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
if self.tool_config is None:
self.tool_config = ToolConfig()
@json_schema_type
@ -258,7 +251,7 @@ class AgentConfig(AgentConfigCommon):
instructions: str
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
response_format: OpenAIResponseFormatParam | None = None
@json_schema_type
@ -434,10 +427,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
messages: list[OpenAIMessageParam]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
@ -460,7 +450,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: list[ToolResponse]
tool_responses: list[OpenAIToolMessageParam]
stream: bool | None = False
@ -531,7 +521,7 @@ class Agents(Protocol):
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
messages: list[OpenAIMessageParam],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
@ -569,7 +559,7 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
tool_responses: list[OpenAIToolMessageParam],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.

View file

@ -10,12 +10,14 @@ import re
import uuid
import warnings
from collections.abc import AsyncGenerator
from typing import Any
from datetime import UTC, datetime
import httpx
from llama_stack.apis.agents import (
AgentConfig,
OpenAIResponseInputTool,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,
@ -32,16 +34,12 @@ from llama_stack.apis.agents import (
Document,
InferenceStep,
ShieldCallStep,
Step,
StepType,
ToolExecutionStep,
Turn,
)
from llama_stack.apis.common.content_types import (
URL,
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.common.content_types import URL, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,
@ -50,20 +48,24 @@ from llama_stack.apis.inference import (
Message,
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIImageURL,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SamplingParams,
StopReason,
SystemMessage,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
@ -90,6 +92,167 @@ RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents::meta_reference")
def _map_finish_reason_to_stop_reason(finish_reason: str | None) -> StopReason:
if finish_reason == "length":
return StopReason.out_of_tokens
if finish_reason == "tool_calls":
return StopReason.end_of_message
# Default to end_of_turn for unknown or "stop"
return StopReason.end_of_turn
def _map_stop_reason_to_finish_reason(stop_reason: StopReason | None) -> str | None:
if stop_reason == StopReason.out_of_tokens:
return "length"
if stop_reason == StopReason.end_of_message:
return "tool_calls"
if stop_reason == StopReason.end_of_turn:
return "stop"
return None
def _openai_tool_call_to_legacy(tool_call: OpenAIChatCompletionToolCall) -> ToolCall:
name = None
if tool_call.function and tool_call.function.name:
name = tool_call.function.name
return ToolCall(
call_id=tool_call.id or f"call_{uuid.uuid4()}",
tool_name=name or "",
arguments=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "{}",
)
def _legacy_tool_call_to_openai(tool_call: ToolCall, index: int | None = None) -> OpenAIChatCompletionToolCall:
function_name = tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
return OpenAIChatCompletionToolCall(
index=index,
id=tool_call.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=function_name,
arguments=tool_call.arguments,
),
)
def _tool_response_message_to_openai(response: ToolResponseMessage) -> OpenAIToolMessageParam:
content = _coerce_to_text(response.content)
return OpenAIToolMessageParam(
tool_call_id=response.call_id,
content=content,
)
def _openai_message_content_to_text(
content: OpenAIChatCompletionMessageContent,
) -> str:
if isinstance(content, str):
return content
parts = []
for item in content:
if isinstance(item, OpenAIChatCompletionContentPartTextParam):
parts.append(item.text)
elif isinstance(item, OpenAIChatCompletionContentPartImageParam) and item.image_url:
if item.image_url.url:
parts.append(item.image_url.url)
return "\n".join(parts)
def _append_text_to_openai_message(message: OpenAIMessageParam, text: str) -> None:
if not text:
return
if isinstance(message, OpenAIUserMessageParam):
content = message.content
if content is None or content == "":
message.content = text
elif isinstance(content, str):
message.content = f"{content}\n{text}"
else:
content.append(OpenAIChatCompletionContentPartTextParam(text=text))
def _coerce_to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
return "\n".join(_coerce_to_text(item) for item in content)
if hasattr(content, "text"):
return getattr(content, "text")
if hasattr(content, "image"):
image = getattr(content, "image")
if hasattr(image, "url") and image.url:
return getattr(image.url, "uri", "")
return str(content)
def _openai_message_param_to_legacy(message: OpenAIMessageParam) -> Message:
if isinstance(message, OpenAIUserMessageParam):
return UserMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAISystemMessageParam):
return SystemMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAIToolMessageParam):
return ToolResponseMessage(
call_id=message.tool_call_id,
content=_openai_message_content_to_text(message.content),
)
if isinstance(message, OpenAIDeveloperMessageParam):
# Map developer messages to user role for legacy compatibility
return UserMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAIAssistantMessageParam):
tool_calls = [
_openai_tool_call_to_legacy(tool_call)
for tool_call in message.tool_calls or []
]
return CompletionMessage(
content=_openai_message_content_to_text(message.content) if message.content is not None else "",
stop_reason=StopReason.end_of_turn,
tool_calls=tool_calls,
)
raise ValueError(f"Unsupported OpenAI message type: {type(message)}")
async def _legacy_message_to_openai(message: Message) -> OpenAIMessageParam:
openai_dict = await convert_message_to_openai_dict_new(message)
role = openai_dict.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_dict)
if role == "system":
return OpenAISystemMessageParam(**openai_dict)
if role == "assistant":
return OpenAIAssistantMessageParam(**openai_dict)
if role == "tool":
return OpenAIToolMessageParam(**openai_dict)
if role == "developer":
return OpenAIDeveloperMessageParam(**openai_dict)
raise ValueError(f"Unsupported OpenAI message role: {role}")
async def _completion_to_openai_assistant(
completion: CompletionMessage,
) -> tuple[OpenAIAssistantMessageParam, str | None]:
assistant_param = await _legacy_message_to_openai(completion)
assert isinstance(assistant_param, OpenAIAssistantMessageParam)
finish_reason = _map_stop_reason_to_finish_reason(completion.stop_reason)
return assistant_param, finish_reason
def _client_tool_to_tool_definition(tool: OpenAIResponseInputTool | ToolDef) -> ToolDefinition:
if isinstance(tool, ToolDef):
return ToolDefinition(
tool_name=tool.name,
description=tool.description,
input_schema=tool.input_schema,
)
if getattr(tool, "type", None) == "function":
return ToolDefinition(
tool_name=tool.name,
description=getattr(tool, "description", None),
input_schema=getattr(tool, "parameters", None),
)
raise ValueError(f"Unsupported client tool type '{getattr(tool, 'type', None)}' for agent configuration")
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,
@ -123,59 +286,70 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
def _resolve_generation_options(
self,
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
) -> dict[str, Any]:
def _pick(attr: str) -> Any:
value = getattr(request, attr, None)
if value is not None:
return value
return getattr(self.agent_config, attr)
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set()
for step in turn.steps:
if step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
temperature = _pick("temperature")
top_p = _pick("top_p")
max_output_tokens = _pick("max_output_tokens")
stop = _pick("stop")
for m in turn.input_messages:
msg = m.model_copy()
# We do not want to keep adding RAG context to the input messages
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way
if isinstance(msg, UserMessage):
msg.context = None
if isinstance(msg, ToolResponseMessage):
if msg.call_id in tool_call_ids:
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
continue
# Ensure we don't share mutable defaults
if isinstance(stop, list):
stop = list(stop)
messages.append(msg)
return {
"temperature": temperature,
"top_p": top_p,
"max_output_tokens": max_output_tokens,
"stop": stop,
}
def turn_to_messages(self, turn: Turn) -> list[OpenAIMessageParam]:
messages: list[OpenAIMessageParam] = []
tool_response_ids = {
response.tool_call_id
for step in turn.steps
if step.step_type == StepType.tool_execution.value
for response in step.tool_responses
}
for message in turn.input_messages:
copied = message.model_copy(deep=True)
if isinstance(copied, OpenAIToolMessageParam) and copied.tool_call_id in tool_response_ids:
# Skip tool responses; they will be reintroduced from the execution step
continue
messages.append(copied)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
messages.append(step.model_response.model_copy(deep=True))
elif step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
call_id=response.call_id,
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=step.violation.user_message,
stop_reason=StopReason.end_of_turn,
)
)
messages.append(response.model_copy(deep=True))
elif step.step_type == StepType.shield_call.value and step.violation:
assistant_msg = OpenAIAssistantMessageParam(
content=str(step.violation.user_message),
)
messages.append(assistant_msg)
return messages
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
async def get_messages_from_turns(self, turns: list[Turn]) -> list[OpenAIMessageParam]:
messages: list[OpenAIMessageParam] = []
if self.agent_config.instructions:
messages.append(OpenAISystemMessageParam(content=self.agent_config.instructions))
for turn in turns:
messages.extend(self.turn_to_messages(turn))
@ -228,26 +402,19 @@ class ChatAgent(ShieldRunnerMixin):
if is_resume and len(turns) == 0:
raise ValueError("No turns found for session")
steps = []
messages = await self.get_messages_from_turns(turns)
steps: list[Step] = []
history_openai = await self.get_messages_from_turns(turns)
if turn_id is None:
turn_id = request.turn_id
if is_resume:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
messages.extend(tool_response_messages)
tool_response_messages = [resp.model_copy(deep=True) for resp in request.tool_responses]
history_openai.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
last_turn_messages.extend(tool_response_messages)
steps = list(last_turn.steps)
# get steps from the turn
steps = last_turn.steps
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)
@ -256,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses,
tool_responses=tool_response_messages,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)
@ -270,26 +437,34 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
input_messages = last_turn.input_messages
turn_id = request.turn_id
input_messages_openai = [msg.model_copy(deep=True) for msg in last_turn.input_messages]
start_time = last_turn.started_at
else:
messages.extend(request.messages)
new_messages = [msg.model_copy(deep=True) for msg in request.messages]
history_openai.extend(new_messages)
input_messages_openai = new_messages
start_time = datetime.now(UTC).isoformat()
input_messages = request.messages
output_message = None
generation_options = self._resolve_generation_options(request)
output_completion: CompletionMessage | None = None
output_finish_reason: str | None = None
output_assistant_message: OpenAIAssistantMessageParam | None = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
input_messages=history_openai,
stream=request.stream,
documents=request.documents if not is_resume else None,
temperature=generation_options["temperature"],
top_p=generation_options["top_p"],
max_output_tokens=generation_options["max_output_tokens"],
stop=generation_options["stop"],
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
output_completion = chunk
output_assistant_message, output_finish_reason = await _completion_to_openai_assistant(chunk)
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
@ -299,19 +474,21 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk
assert output_message is not None
assert output_completion is not None
assert output_assistant_message is not None
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=input_messages,
output_message=output_message,
input_messages=input_messages_openai,
output_message=output_assistant_message,
finish_reason=output_finish_reason,
started_at=start_time,
completed_at=datetime.now(UTC).isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
if output_assistant_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(
@ -334,10 +511,13 @@ class ChatAgent(ShieldRunnerMixin):
self,
session_id: str,
turn_id: str,
input_messages: list[Message],
sampling_params: SamplingParams,
input_messages: list[OpenAIMessageParam],
stream: bool = False,
documents: list[Document] | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_output_tokens: int | None = None,
stop: list[str] | None = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -357,9 +537,12 @@ class ChatAgent(ShieldRunnerMixin):
session_id,
turn_id,
input_messages,
sampling_params,
stream,
documents,
temperature,
top_p,
max_output_tokens,
stop,
):
if isinstance(res, bool):
return
@ -370,8 +553,9 @@ class ChatAgent(ShieldRunnerMixin):
yield res
assert final_response is not None
final_assistant, final_finish_reason = await _completion_to_openai_assistant(copy.deepcopy(final_response))
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
messages = input_messages + [final_assistant.model_copy(deep=True)]
if len(self.output_shields) > 0:
async for res in self.run_multiple_shields_wrapper(
@ -387,7 +571,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:
@ -412,7 +596,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
await self.run_multiple_shields(messages, shields)
legacy_messages = [_openai_message_param_to_legacy(m) for m in messages]
await self.run_multiple_shields(legacy_messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(
@ -461,29 +646,25 @@ class ChatAgent(ShieldRunnerMixin):
self,
session_id: str,
turn_id: str,
input_messages: list[Message],
sampling_params: SamplingParams,
input_messages: list[OpenAIMessageParam],
stream: bool = False,
documents: list[Document] | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_output_tokens: int | None = None,
stop: list[str] | None = None,
) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message
if documents:
contexts = []
conversation = [msg.model_copy(deep=True) for msg in input_messages]
# if document is passed in a turn, hydrate the last user message with the context
if documents and conversation:
appended_texts = []
for document in documents:
raw_document_text = await get_raw_document_text(document)
contexts.append(raw_document_text)
attached_context = "\n".join(contexts)
if isinstance(input_messages[-1].content, str):
input_messages[-1].content += attached_context
elif isinstance(input_messages[-1].content, list):
input_messages[-1].content.append(TextContentItem(text=attached_context))
else:
input_messages[-1].content = [
input_messages[-1].content,
TextContentItem(text=attached_context),
]
if raw_document_text:
appended_texts.append(raw_document_text)
if appended_texts:
_append_text_to_openai_message(conversation[-1], "\n".join(appended_texts))
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it
@ -500,9 +681,13 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
client_tools: dict[str, OpenAIResponseInputTool | ToolDef] = {}
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
if isinstance(tool, ToolDef) and tool.name:
client_tools[tool.name] = tool
elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat()
@ -520,81 +705,33 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason: StopReason | None = None
async with tracing.span("inference") as span:
if self.telemetry_enabled and span is not None:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled and span is not None and self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
from pydantic import BaseModel
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
elif isinstance(value, dict):
return {k: _serialize_nested(v) for k, v in value.items()}
elif isinstance(value, list):
return [_serialize_nested(item) for item in value]
else:
return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
role = openai_msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_msg)
elif role == "system":
return OpenAISystemMessageParam(**openai_msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**openai_msg)
elif role == "tool":
return OpenAIToolMessageParam(**openai_msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**openai_msg)
else:
raise ValueError(f"Unknown message role: {role}")
# Convert messages to OpenAI format
openai_messages: list[OpenAIMessageParam] = [
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
]
# Convert tool definitions to OpenAI format
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
# Extract tool_choice from tool_config for OpenAI compatibility
# Note: tool_choice can only be provided when tools are also provided
tool_choice = None
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
tc = self.agent_config.tool_config.tool_choice
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
# Convert tool_choice to OpenAI format
if tool_choice_str in ("auto", "none", "required"):
tool_choice = tool_choice_str
else:
# It's a specific tool name, wrap it in the proper format
tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
model=self.agent_config.model,
messages=openai_messages,
messages=[msg.model_copy(deep=True) for msg in conversation],
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
response_format=self.agent_config.response_format,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
max_tokens=max_output_tokens,
stop=stop,
stream=True,
)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
)
@ -644,7 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
json.dumps([json.loads(m.model_copy(deep=True).model_dump_json()) for m in conversation]),
)
output_attr = json.dumps(
{
@ -671,6 +808,8 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls=tool_calls,
)
assistant_param, finish_reason = await _completion_to_openai_assistant(copy.deepcopy(message))
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
@ -682,7 +821,8 @@ class ChatAgent(ShieldRunnerMixin):
# `deepcopy` for now, but this is symptomatic of a deeper issue.
step_id=step_id,
turn_id=turn_id,
model_response=copy.deepcopy(message),
model_response=assistant_param,
finish_reason=finish_reason,
started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(),
),
@ -703,9 +843,10 @@ class ChatAgent(ShieldRunnerMixin):
yield message
break
assistant_param = assistant_param.model_copy(deep=True)
if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += output_attachments
@ -714,18 +855,20 @@ class ChatAgent(ShieldRunnerMixin):
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
conversation.append(assistant_param)
else:
input_messages = input_messages + [message]
conversation.append(assistant_param)
# Process tool calls in the message
client_tool_calls = []
non_client_tool_calls = []
client_tool_calls_openai = []
# Separate client and non-client tool calls
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
client_tool_calls_openai.append(_legacy_tool_call_to_openai(tool_call))
else:
non_client_tool_calls.append(tool_call)
@ -781,18 +924,14 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
openai_tool_call = _legacy_tool_call_to_openai(tool_call)
openai_tool_response = _tool_response_message_to_openai(result_message)
tool_execution_step = ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
metadata=tool_result.metadata,
)
],
tool_calls=[openai_tool_call],
tool_responses=[openai_tool_response],
started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(),
)
@ -808,8 +947,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# Add the result message to input_messages for the next iteration
input_messages.append(result_message)
# Add the result message to conversation for the next iteration
conversation.append(openai_tool_response)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
@ -829,7 +968,7 @@ class ChatAgent(ShieldRunnerMixin):
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_calls=client_tool_calls_openai,
tool_responses=[],
started_at=datetime.now(UTC).isoformat(),
),
@ -866,19 +1005,15 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_name_to_args = {}
tool_name_to_def: dict[str | BuiltinTool, ToolDefinition] = {}
tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
# Use input_schema from ToolDef directly
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
tool_definition = _client_tool_to_tool_definition(tool)
if tool_name_to_def.get(tool_definition.tool_name):
raise ValueError(f"Tool {tool_definition.tool_name} already exists")
tool_name_to_def[tool_definition.tool_name] = tool_definition
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@ -999,12 +1134,7 @@ async def get_raw_document_text(document: Document) -> str:
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
raise ValueError(f"Unexpected document content type: {type(document.content)}")
return _openai_message_content_to_text(document.content)
def _interpret_content_as_attachment(
@ -1015,7 +1145,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)

View file

@ -33,9 +33,8 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
OpenAIMessageParam,
OpenAIToolMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
@ -156,7 +155,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
messages: list[OpenAIMessageParam],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,
@ -189,7 +188,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
tool_responses: list[OpenAIToolMessageParam],
stream: bool | None = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(

View file

@ -62,14 +62,9 @@ def agent_config(llama_stack_client, text_model_id):
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 0.0001,
"top_p": 0.9,
},
"max_tokens": 512,
},
temperature=0.0001,
top_p=0.9,
max_output_tokens=512,
tools=[],
input_shields=available_shields,
output_shields=available_shields,
@ -83,14 +78,9 @@ def agent_config_without_safety(text_model_id):
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 0.0001,
"top_p": 0.9,
},
"max_tokens": 512,
},
temperature=0.0001,
top_p=0.9,
max_output_tokens=512,
tools=[],
enable_session_persistence=False,
)
@ -194,14 +184,9 @@ def test_tool_config(agent_config):
common_params = dict(
model="meta-llama/Llama-3.2-3B-Instruct",
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 1.0,
"top_p": 0.9,
},
"max_tokens": 512,
},
temperature=1.0,
top_p=0.9,
max_output_tokens=512,
toolgroups=[],
enable_session_persistence=False,
)
@ -212,40 +197,25 @@ def test_tool_config(agent_config):
agent_config = AgentConfig(
**common_params,
tool_choice="auto",
tool_config=ToolConfig(tool_choice="auto"),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig(
**common_params,
tool_choice="auto",
tool_config=ToolConfig(
tool_choice="auto",
),
tool_config=ToolConfig(tool_choice="auto"),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig(
**common_params,
tool_config=ToolConfig(
tool_choice="required",
),
tool_config=ToolConfig(tool_choice="required"),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.required
agent_config = AgentConfig(
**common_params,
tool_choice="required",
tool_config=ToolConfig(
tool_choice="auto",
),
)
with pytest.raises(ValueError, match="tool_choice is deprecated"):
Server__AgentConfig(**agent_config)
def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = {

View file

@ -7,7 +7,7 @@
import pytest
from llama_stack.apis.agents import AgentConfig, Turn
from llama_stack.apis.inference import SamplingParams, UserMessage
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
def sample_messages():
return [
UserMessage(content="What's the weather like today?"),
OpenAIUserMessageParam(content="What's the weather like today?"),
]
@ -36,7 +36,9 @@ def common_params(inference_model):
model=inference_model,
instructions="You are a helpful assistant.",
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
temperature=0.7,
top_p=0.95,
max_output_tokens=256,
input_shields=[],
output_shields=[],
tools=[],

View file

@ -69,30 +69,26 @@ async def agents_impl(config, mock_apis):
@pytest.fixture
def sample_agent_config():
return AgentConfig(
sampling_params={
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 1.0,
},
temperature=0.0,
top_p=1.0,
max_output_tokens=0,
input_shields=["string"],
output_shields=["string"],
toolgroups=["mcp::my_mcp_server"],
client_tools=[
{
"type": "function",
"name": "client_tool",
"description": "Client Tool",
"parameters": [
{
"name": "string",
"parameter_type": "string",
"description": "string",
"required": True,
"default": None,
}
],
"metadata": {
"property1": None,
"property2": None,
"parameters": {
"type": "object",
"properties": {
"string": {
"type": "string",
"description": "string",
}
},
"required": ["string"],
},
}
],