From c56b2deb7d4fa10b5fa5b95279f663374eefef8b Mon Sep 17 00:00:00 2001
From: Ashwin Bharambe
Date: Fri, 10 Oct 2025 13:04:41 -0700
Subject: [PATCH] feat(agents)!: change agents API signatures to use OpenAI
 types

Replace legacy Message/SamplingParams usage with OpenAI chat message
structures across agents: the API schemas, the meta-reference
implementation, and the tests now rely on OpenAI message/tool payloads
and on flat generation parameters (max_output_tokens, temperature,
top_p, stop) instead of SamplingParams.
---
 llama_stack/apis/agents/agents.py             |  76 ++-
 .../agents/meta_reference/agent_instance.py   | 520 +++++++++++-------
 .../inline/agents/meta_reference/agents.py    |   9 +-
 tests/integration/agents/test_agents.py       |  54 +-
 tests/integration/agents/test_persistence.py  |   8 +-
 .../agent/test_meta_reference_agent.py        |  30 +-
 6 files changed, 392 insertions(+), 305 deletions(-)

diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py
index 5983b5c45..17b166ba2 100644
--- a/llama_stack/apis/agents/agents.py
+++ b/llama_stack/apis/agents/agents.py
@@ -11,19 +11,18 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
 
 from pydantic import BaseModel, ConfigDict, Field
 
-from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
+from llama_stack.apis.common.content_types import URL, ContentDelta
 from llama_stack.apis.common.responses import Order, PaginatedResponse
 from llama_stack.apis.inference import (
-    CompletionMessage,
-    ResponseFormat,
-    SamplingParams,
-    ToolCall,
+    OpenAIAssistantMessageParam,
+    OpenAIChatCompletionMessageContent,
+    OpenAIChatCompletionToolCall,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+    OpenAIToolMessageParam,
     ToolChoice,
     ToolConfig,
     ToolPromptFormat,
-    ToolResponse,
-    ToolResponseMessage,
-    UserMessage,
 )
 from llama_stack.apis.safety import SafetyViolation
 from llama_stack.apis.tools import ToolDef
@@ -63,7 +62,7 @@ class Attachment(BaseModel):
     :param mime_type: The MIME type of the attachment.
     """
 
-    content: InterleavedContent | URL
+    content: OpenAIChatCompletionMessageContent | URL
     mime_type: str
 
 
@@ -74,7 +73,7 @@ class Document(BaseModel):
     :param mime_type: The MIME type of the document.
     """
 
-    content: InterleavedContent | URL
+    content: OpenAIChatCompletionMessageContent | URL
     mime_type: str
 
 
@@ -118,7 +117,8 @@ class InferenceStep(StepCommon):
     model_config = ConfigDict(protected_namespaces=())
 
     step_type: Literal[StepType.inference] = StepType.inference
-    model_response: CompletionMessage
+    model_response: OpenAIAssistantMessageParam
+    finish_reason: str | None = None
 
 
 @json_schema_type
@@ -130,8 +130,8 @@ class ToolExecutionStep(StepCommon):
     """
 
     step_type: Literal[StepType.tool_execution] = StepType.tool_execution
-    tool_calls: list[ToolCall]
-    tool_responses: list[ToolResponse]
+    tool_calls: list[OpenAIChatCompletionToolCall]
+    tool_responses: list[OpenAIToolMessageParam]
 
 
 @json_schema_type
@@ -156,7 +156,7 @@ class MemoryRetrievalStep(StepCommon):
 
     step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
     # TODO: should this be List[str]?
vector_db_ids: str - inserted_context: InterleavedContent + inserted_context: OpenAIChatCompletionMessageContent Step = Annotated[ @@ -181,9 +182,10 @@ class Turn(BaseModel): turn_id: str session_id: str - input_messages: list[UserMessage | ToolResponseMessage] + input_messages: list[OpenAIMessageParam] steps: list[Step] - output_message: CompletionMessage + output_message: OpenAIAssistantMessageParam + finish_reason: str | None = None output_attachments: list[Attachment] | None = Field(default_factory=lambda: []) started_at: datetime @@ -216,31 +218,22 @@ register_schema(AgentToolGroup, name="AgentTool") class AgentConfigCommon(BaseModel): - sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) + max_output_tokens: int | None = None + temperature: float | None = None + top_p: float | None = None + stop: list[str] | None = None - input_shields: list[str] | None = Field(default_factory=lambda: []) - output_shields: list[str] | None = Field(default_factory=lambda: []) - toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: []) - client_tools: list[ToolDef] | None = Field(default_factory=lambda: []) - tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead") - tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead") + input_shields: list[str] | None = Field(default_factory=list) + output_shields: list[str] | None = Field(default_factory=list) + toolgroups: list[AgentToolGroup] | None = Field(default_factory=list) + client_tools: list[OpenAIResponseInputTool | ToolDef] | None = Field(default_factory=list) tool_config: ToolConfig | None = Field(default=None) max_infer_iters: int | None = 10 def model_post_init(self, __context): - if self.tool_config: - if self.tool_choice and self.tool_config.tool_choice != self.tool_choice: - raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.") - if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format: - raise ValueError("tool_prompt_format is deprecated. 
Use tool_prompt_format in tool_config instead.") - else: - params = {} - if self.tool_choice: - params["tool_choice"] = self.tool_choice - if self.tool_prompt_format: - params["tool_prompt_format"] = self.tool_prompt_format - self.tool_config = ToolConfig(**params) + if self.tool_config is None: + self.tool_config = ToolConfig() @json_schema_type @@ -258,7 +251,7 @@ class AgentConfig(AgentConfigCommon): instructions: str name: str | None = None enable_session_persistence: bool | None = False - response_format: ResponseFormat | None = None + response_format: OpenAIResponseFormatParam | None = None @json_schema_type @@ -434,10 +427,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn): agent_id: str session_id: str - # TODO: figure out how we can simplify this and make why - # ToolResponseMessage needs to be here (it is function call - # execution from outside the system) - messages: list[UserMessage | ToolResponseMessage] + messages: list[OpenAIMessageParam] documents: list[Document] | None = None toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: []) @@ -460,7 +450,7 @@ class AgentTurnResumeRequest(BaseModel): agent_id: str session_id: str turn_id: str - tool_responses: list[ToolResponse] + tool_responses: list[OpenAIToolMessageParam] stream: bool | None = False @@ -531,7 +521,7 @@ class Agents(Protocol): self, agent_id: str, session_id: str, - messages: list[UserMessage | ToolResponseMessage], + messages: list[OpenAIMessageParam], stream: bool | None = False, documents: list[Document] | None = None, toolgroups: list[AgentToolGroup] | None = None, @@ -569,7 +559,7 @@ class Agents(Protocol): agent_id: str, session_id: str, turn_id: str, - tool_responses: list[ToolResponse], + tool_responses: list[OpenAIToolMessageParam], stream: bool | None = False, ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]: """Resume an agent turn with executed tool call responses. 
diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index b17c720e9..1e7f1a9f6 100644
--- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -10,12 +10,13 @@ import re
 import uuid
 import warnings
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
+from typing import Any
 
 import httpx
 
 from llama_stack.apis.agents import (
     AgentConfig,
     AgentToolGroup,
     AgentToolGroupWithArgs,
     AgentTurnCreateRequest,
@@ -32,16 +33,13 @@ from llama_stack.apis.agents import (
     Document,
     InferenceStep,
+    OpenAIResponseInputTool,
     ShieldCallStep,
+    Step,
     StepType,
     ToolExecutionStep,
     Turn,
 )
-from llama_stack.apis.common.content_types import (
-    URL,
-    TextContentItem,
-    ToolCallDelta,
-    ToolCallParseStatus,
-)
+from llama_stack.apis.common.content_types import URL, ToolCallDelta, ToolCallParseStatus
 from llama_stack.apis.common.errors import SessionNotFoundError
 from llama_stack.apis.inference import (
     ChatCompletionResponseEventType,
@@ -50,20 +48,24 @@ from llama_stack.apis.inference import (
     Message,
     OpenAIAssistantMessageParam,
+    OpenAIChatCompletionContentPartImageParam,
+    OpenAIChatCompletionContentPartTextParam,
+    OpenAIChatCompletionMessageContent,
+    OpenAIChatCompletionToolCall,
+    OpenAIChatCompletionToolCallFunction,
     OpenAIDeveloperMessageParam,
+    OpenAIImageURL,
     OpenAIMessageParam,
     OpenAISystemMessageParam,
     OpenAIToolMessageParam,
     OpenAIUserMessageParam,
-    SamplingParams,
     StopReason,
     SystemMessage,
     ToolDefinition,
-    ToolResponse,
     ToolResponseMessage,
     UserMessage,
 )
 from llama_stack.apis.safety import Safety
-from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
+from llama_stack.apis.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.core.datatypes import AccessRule
 from llama_stack.log import get_logger
@@ -90,6 +92,167 @@ RAG_TOOL_GROUP = "builtin::rag"
 
 logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
+def _map_finish_reason_to_stop_reason(finish_reason: str | None) -> StopReason:
+    if finish_reason == "length":
+        return StopReason.out_of_tokens
+    if finish_reason == "tool_calls":
+        return StopReason.end_of_message
+    # Default to end_of_turn for unknown or "stop"
+    return StopReason.end_of_turn
+
+
+def _map_stop_reason_to_finish_reason(stop_reason: StopReason | None) -> str | None:
+    if stop_reason == StopReason.out_of_tokens:
+        return "length"
+    if stop_reason == StopReason.end_of_message:
+        return "tool_calls"
+    if stop_reason == StopReason.end_of_turn:
+        return "stop"
+    return None
+
+
+def _openai_tool_call_to_legacy(tool_call: OpenAIChatCompletionToolCall) -> ToolCall:
+    name = None
+    if tool_call.function and tool_call.function.name:
+        name = tool_call.function.name
+    return ToolCall(
+        call_id=tool_call.id or f"call_{uuid.uuid4()}",
+        tool_name=name or "",
+        arguments=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "{}",
+    )
+
+
+def _legacy_tool_call_to_openai(tool_call: ToolCall, index: int | None = None) -> OpenAIChatCompletionToolCall:
+    function_name = tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
+    return OpenAIChatCompletionToolCall(
+        index=index,
+        id=tool_call.call_id,
+        function=OpenAIChatCompletionToolCallFunction(
+            name=function_name,
+
arguments=tool_call.arguments, + ), + ) + + +def _tool_response_message_to_openai(response: ToolResponseMessage) -> OpenAIToolMessageParam: + content = _coerce_to_text(response.content) + return OpenAIToolMessageParam( + tool_call_id=response.call_id, + content=content, + ) + + +def _openai_message_content_to_text( + content: OpenAIChatCompletionMessageContent, +) -> str: + if isinstance(content, str): + return content + parts = [] + for item in content: + if isinstance(item, OpenAIChatCompletionContentPartTextParam): + parts.append(item.text) + elif isinstance(item, OpenAIChatCompletionContentPartImageParam) and item.image_url: + if item.image_url.url: + parts.append(item.image_url.url) + return "\n".join(parts) + + +def _append_text_to_openai_message(message: OpenAIMessageParam, text: str) -> None: + if not text: + return + if isinstance(message, OpenAIUserMessageParam): + content = message.content + if content is None or content == "": + message.content = text + elif isinstance(content, str): + message.content = f"{content}\n{text}" + else: + content.append(OpenAIChatCompletionContentPartTextParam(text=text)) + + +def _coerce_to_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + return "\n".join(_coerce_to_text(item) for item in content) + if hasattr(content, "text"): + return getattr(content, "text") + if hasattr(content, "image"): + image = getattr(content, "image") + if hasattr(image, "url") and image.url: + return getattr(image.url, "uri", "") + return str(content) + + +def _openai_message_param_to_legacy(message: OpenAIMessageParam) -> Message: + if isinstance(message, OpenAIUserMessageParam): + return UserMessage(content=_openai_message_content_to_text(message.content)) + if isinstance(message, OpenAISystemMessageParam): + return SystemMessage(content=_openai_message_content_to_text(message.content)) + if isinstance(message, OpenAIToolMessageParam): + return ToolResponseMessage( + call_id=message.tool_call_id, + content=_openai_message_content_to_text(message.content), + ) + if isinstance(message, OpenAIDeveloperMessageParam): + # Map developer messages to user role for legacy compatibility + return UserMessage(content=_openai_message_content_to_text(message.content)) + if isinstance(message, OpenAIAssistantMessageParam): + tool_calls = [ + _openai_tool_call_to_legacy(tool_call) + for tool_call in message.tool_calls or [] + ] + return CompletionMessage( + content=_openai_message_content_to_text(message.content) if message.content is not None else "", + stop_reason=StopReason.end_of_turn, + tool_calls=tool_calls, + ) + raise ValueError(f"Unsupported OpenAI message type: {type(message)}") + + +async def _legacy_message_to_openai(message: Message) -> OpenAIMessageParam: + openai_dict = await convert_message_to_openai_dict_new(message) + role = openai_dict.get("role") + if role == "user": + return OpenAIUserMessageParam(**openai_dict) + if role == "system": + return OpenAISystemMessageParam(**openai_dict) + if role == "assistant": + return OpenAIAssistantMessageParam(**openai_dict) + if role == "tool": + return OpenAIToolMessageParam(**openai_dict) + if role == "developer": + return OpenAIDeveloperMessageParam(**openai_dict) + raise ValueError(f"Unsupported OpenAI message role: {role}") + + +async def _completion_to_openai_assistant( + completion: CompletionMessage, +) -> tuple[OpenAIAssistantMessageParam, str | None]: + assistant_param = await _legacy_message_to_openai(completion) + assert 
isinstance(assistant_param, OpenAIAssistantMessageParam) + finish_reason = _map_stop_reason_to_finish_reason(completion.stop_reason) + return assistant_param, finish_reason + + +def _client_tool_to_tool_definition(tool: OpenAIResponseInputTool | ToolDef) -> ToolDefinition: + if isinstance(tool, ToolDef): + return ToolDefinition( + tool_name=tool.name, + description=tool.description, + input_schema=tool.input_schema, + ) + if getattr(tool, "type", None) == "function": + return ToolDefinition( + tool_name=tool.name, + description=getattr(tool, "description", None), + input_schema=getattr(tool, "parameters", None), + ) + raise ValueError(f"Unsupported client tool type '{getattr(tool, 'type', None)}' for agent configuration") + + class ChatAgent(ShieldRunnerMixin): def __init__( self, @@ -123,59 +286,70 @@ class ChatAgent(ShieldRunnerMixin): output_shields=agent_config.output_shields, ) - def turn_to_messages(self, turn: Turn) -> list[Message]: - messages = [] + def _resolve_generation_options( + self, + request: AgentTurnCreateRequest | AgentTurnResumeRequest, + ) -> dict[str, Any]: + def _pick(attr: str) -> Any: + value = getattr(request, attr, None) + if value is not None: + return value + return getattr(self.agent_config, attr) - # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages - tool_call_ids = set() - for step in turn.steps: - if step.step_type == StepType.tool_execution.value: - for response in step.tool_responses: - tool_call_ids.add(response.call_id) + temperature = _pick("temperature") + top_p = _pick("top_p") + max_output_tokens = _pick("max_output_tokens") + stop = _pick("stop") - for m in turn.input_messages: - msg = m.model_copy() - # We do not want to keep adding RAG context to the input messages - # May be this should be a parameter of the agentic instance - # that can define its behavior in a custom way - if isinstance(msg, UserMessage): - msg.context = None - if isinstance(msg, ToolResponseMessage): - if msg.call_id in tool_call_ids: - # NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps - continue + # Ensure we don't share mutable defaults + if isinstance(stop, list): + stop = list(stop) - messages.append(msg) + return { + "temperature": temperature, + "top_p": top_p, + "max_output_tokens": max_output_tokens, + "stop": stop, + } + + def turn_to_messages(self, turn: Turn) -> list[OpenAIMessageParam]: + messages: list[OpenAIMessageParam] = [] + + tool_response_ids = { + response.tool_call_id + for step in turn.steps + if step.step_type == StepType.tool_execution.value + for response in step.tool_responses + } + + for message in turn.input_messages: + copied = message.model_copy(deep=True) + if isinstance(copied, OpenAIToolMessageParam) and copied.tool_call_id in tool_response_ids: + # Skip tool responses; they will be reintroduced from the execution step + continue + messages.append(copied) for step in turn.steps: if step.step_type == StepType.inference.value: - messages.append(step.model_response) + messages.append(step.model_response.model_copy(deep=True)) elif step.step_type == StepType.tool_execution.value: for response in step.tool_responses: - messages.append( - ToolResponseMessage( - call_id=response.call_id, - content=response.content, - ) - ) - elif step.step_type == StepType.shield_call.value: - if step.violation: - # CompletionMessage itself in the ShieldResponse - messages.append( - CompletionMessage( - content=step.violation.user_message, - stop_reason=StopReason.end_of_turn, - ) - ) + 
messages.append(response.model_copy(deep=True)) + elif step.step_type == StepType.shield_call.value and step.violation: + assistant_msg = OpenAIAssistantMessageParam( + content=str(step.violation.user_message), + ) + messages.append(assistant_msg) + return messages async def create_session(self, name: str) -> str: return await self.storage.create_session(name) - async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: - messages = [] - if self.agent_config.instructions != "": - messages.append(SystemMessage(content=self.agent_config.instructions)) + async def get_messages_from_turns(self, turns: list[Turn]) -> list[OpenAIMessageParam]: + messages: list[OpenAIMessageParam] = [] + if self.agent_config.instructions: + messages.append(OpenAISystemMessageParam(content=self.agent_config.instructions)) for turn in turns: messages.extend(self.turn_to_messages(turn)) @@ -228,26 +402,19 @@ class ChatAgent(ShieldRunnerMixin): if is_resume and len(turns) == 0: raise ValueError("No turns found for session") - steps = [] - messages = await self.get_messages_from_turns(turns) + steps: list[Step] = [] + history_openai = await self.get_messages_from_turns(turns) + + if turn_id is None: + turn_id = request.turn_id + if is_resume: - tool_response_messages = [ - ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses - ] - messages.extend(tool_response_messages) + tool_response_messages = [resp.model_copy(deep=True) for resp in request.tool_responses] + history_openai.extend(tool_response_messages) + last_turn = turns[-1] - last_turn_messages = self.turn_to_messages(last_turn) - last_turn_messages = [ - x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage) - ] - last_turn_messages.extend(tool_response_messages) + steps = list(last_turn.steps) - # get steps from the turn - steps = last_turn.steps - - # mark tool execution step as complete - # if there's no tool execution in progress step (due to storage, or tool call parsing on client), - # we'll create a new tool execution step with current time in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( request.session_id, request.turn_id ) @@ -256,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin): step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), turn_id=request.turn_id, tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), - tool_responses=request.tool_responses, + tool_responses=tool_response_messages, completed_at=now, started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), ) @@ -270,26 +437,34 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - input_messages = last_turn.input_messages - turn_id = request.turn_id + input_messages_openai = [msg.model_copy(deep=True) for msg in last_turn.input_messages] start_time = last_turn.started_at else: - messages.extend(request.messages) + new_messages = [msg.model_copy(deep=True) for msg in request.messages] + history_openai.extend(new_messages) + input_messages_openai = new_messages start_time = datetime.now(UTC).isoformat() - input_messages = request.messages - output_message = None + generation_options = self._resolve_generation_options(request) + + output_completion: CompletionMessage | None = None + output_finish_reason: str | None = None + output_assistant_message: OpenAIAssistantMessageParam | None = None async for chunk in self.run( session_id=request.session_id, turn_id=turn_id, 
- input_messages=messages, - sampling_params=self.agent_config.sampling_params, + input_messages=history_openai, stream=request.stream, documents=request.documents if not is_resume else None, + temperature=generation_options["temperature"], + top_p=generation_options["top_p"], + max_output_tokens=generation_options["max_output_tokens"], + stop=generation_options["stop"], ): if isinstance(chunk, CompletionMessage): - output_message = chunk + output_completion = chunk + output_assistant_message, output_finish_reason = await _completion_to_openai_assistant(chunk) continue assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" @@ -299,19 +474,21 @@ class ChatAgent(ShieldRunnerMixin): yield chunk - assert output_message is not None + assert output_completion is not None + assert output_assistant_message is not None turn = Turn( turn_id=turn_id, session_id=request.session_id, - input_messages=input_messages, - output_message=output_message, + input_messages=input_messages_openai, + output_message=output_assistant_message, + finish_reason=output_finish_reason, started_at=start_time, completed_at=datetime.now(UTC).isoformat(), steps=steps, ) await self.storage.add_turn_to_session(request.session_id, turn) - if output_message.tool_calls: + if output_assistant_message.tool_calls: chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseTurnAwaitingInputPayload( @@ -334,10 +511,13 @@ class ChatAgent(ShieldRunnerMixin): self, session_id: str, turn_id: str, - input_messages: list[Message], - sampling_params: SamplingParams, + input_messages: list[OpenAIMessageParam], stream: bool = False, documents: list[Document] | None = None, + temperature: float | None = None, + top_p: float | None = None, + max_output_tokens: int | None = None, + stop: list[str] | None = None, ) -> AsyncGenerator: # Doing async generators makes downstream code much simpler and everything amenable to # streaming. 
However, it also makes things complicated here because AsyncGenerators cannot @@ -357,9 +537,12 @@ class ChatAgent(ShieldRunnerMixin): session_id, turn_id, input_messages, - sampling_params, stream, documents, + temperature, + top_p, + max_output_tokens, + stop, ): if isinstance(res, bool): return @@ -370,8 +553,9 @@ class ChatAgent(ShieldRunnerMixin): yield res assert final_response is not None + final_assistant, final_finish_reason = await _completion_to_openai_assistant(copy.deepcopy(final_response)) # for output shields run on the full input and output combination - messages = input_messages + [final_response] + messages = input_messages + [final_assistant.model_copy(deep=True)] if len(self.output_shields) > 0: async for res in self.run_multiple_shields_wrapper( @@ -387,7 +571,7 @@ class ChatAgent(ShieldRunnerMixin): async def run_multiple_shields_wrapper( self, turn_id: str, - messages: list[Message], + messages: list[OpenAIMessageParam], shields: list[str], touchpoint: str, ) -> AsyncGenerator: @@ -412,7 +596,8 @@ class ChatAgent(ShieldRunnerMixin): ) ) ) - await self.run_multiple_shields(messages, shields) + legacy_messages = [_openai_message_param_to_legacy(m) for m in messages] + await self.run_multiple_shields(legacy_messages, shields) except SafetyException as e: yield AgentTurnResponseStreamChunk( @@ -461,29 +646,25 @@ class ChatAgent(ShieldRunnerMixin): self, session_id: str, turn_id: str, - input_messages: list[Message], - sampling_params: SamplingParams, + input_messages: list[OpenAIMessageParam], stream: bool = False, documents: list[Document] | None = None, + temperature: float | None = None, + top_p: float | None = None, + max_output_tokens: int | None = None, + stop: list[str] | None = None, ) -> AsyncGenerator: - # if document is passed in a turn, we parse the raw text of the document - # and sent it as a user message - if documents: - contexts = [] + conversation = [msg.model_copy(deep=True) for msg in input_messages] + + # if document is passed in a turn, hydrate the last user message with the context + if documents and conversation: + appended_texts = [] for document in documents: raw_document_text = await get_raw_document_text(document) - contexts.append(raw_document_text) - - attached_context = "\n".join(contexts) - if isinstance(input_messages[-1].content, str): - input_messages[-1].content += attached_context - elif isinstance(input_messages[-1].content, list): - input_messages[-1].content.append(TextContentItem(text=attached_context)) - else: - input_messages[-1].content = [ - input_messages[-1].content, - TextContentItem(text=attached_context), - ] + if raw_document_text: + appended_texts.append(raw_document_text) + if appended_texts: + _append_text_to_openai_message(conversation[-1], "\n".join(appended_texts)) session_info = await self.storage.get_session_info(session_id) # if the session has a memory bank id, let the memory tool use it @@ -500,9 +681,13 @@ class ChatAgent(ShieldRunnerMixin): n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 # Build a map of custom tools to their definitions for faster lookup - client_tools = {} - for tool in self.agent_config.client_tools: - client_tools[tool.name] = tool + client_tools: dict[str, OpenAIResponseInputTool | ToolDef] = {} + if self.agent_config.client_tools: + for tool in self.agent_config.client_tools: + if isinstance(tool, ToolDef) and tool.name: + client_tools[tool.name] = tool + elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None): + 
client_tools[tool.name] = tool while True: step_id = str(uuid.uuid4()) inference_start_time = datetime.now(UTC).isoformat() @@ -520,81 +705,33 @@ class ChatAgent(ShieldRunnerMixin): stop_reason: StopReason | None = None async with tracing.span("inference") as span: - if self.telemetry_enabled and span is not None: - if self.agent_config.name: - span.set_attribute("agent_name", self.agent_config.name) + if self.telemetry_enabled and span is not None and self.agent_config.name: + span.set_attribute("agent_name", self.agent_config.name) - def _serialize_nested(value): - """Recursively serialize nested Pydantic models to dicts.""" - from pydantic import BaseModel - - if isinstance(value, BaseModel): - return value.model_dump(mode="json") - elif isinstance(value, dict): - return {k: _serialize_nested(v) for k, v in value.items()} - elif isinstance(value, list): - return [_serialize_nested(item) for item in value] - else: - return value - - def _add_type(openai_msg: dict) -> OpenAIMessageParam: - # Serialize any nested Pydantic models to plain dicts - openai_msg = _serialize_nested(openai_msg) - - role = openai_msg.get("role") - if role == "user": - return OpenAIUserMessageParam(**openai_msg) - elif role == "system": - return OpenAISystemMessageParam(**openai_msg) - elif role == "assistant": - return OpenAIAssistantMessageParam(**openai_msg) - elif role == "tool": - return OpenAIToolMessageParam(**openai_msg) - elif role == "developer": - return OpenAIDeveloperMessageParam(**openai_msg) - else: - raise ValueError(f"Unknown message role: {role}") - - # Convert messages to OpenAI format - openai_messages: list[OpenAIMessageParam] = [ - _add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages - ] - - # Convert tool definitions to OpenAI format openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])] - # Extract tool_choice from tool_config for OpenAI compatibility - # Note: tool_choice can only be provided when tools are also provided tool_choice = None if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice: tc = self.agent_config.tool_config.tool_choice tool_choice_str = tc.value if hasattr(tc, "value") else str(tc) - # Convert tool_choice to OpenAI format if tool_choice_str in ("auto", "none", "required"): tool_choice = tool_choice_str else: - # It's a specific tool name, wrap it in the proper format tool_choice = {"type": "function", "function": {"name": tool_choice_str}} - # Convert sampling params to OpenAI format (temperature, top_p, max_tokens) - temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None) - top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None) - max_tokens = getattr(sampling_params, "max_tokens", None) - - # Use OpenAI chat completion openai_stream = await self.inference_api.openai_chat_completion( model=self.agent_config.model, - messages=openai_messages, + messages=[msg.model_copy(deep=True) for msg in conversation], tools=openai_tools if openai_tools else None, tool_choice=tool_choice, response_format=self.agent_config.response_format, temperature=temperature, top_p=top_p, - max_tokens=max_tokens, + max_tokens=max_output_tokens, + stop=stop, stream=True, ) - # Convert OpenAI stream back to Llama Stack format response_stream = convert_openai_chat_completion_stream( openai_stream, enable_incremental_tool_calls=True ) @@ -644,7 +781,7 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("stop_reason", stop_reason or 
StopReason.end_of_turn) span.set_attribute( "input", - json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), + json.dumps([json.loads(m.model_copy(deep=True).model_dump_json()) for m in conversation]), ) output_attr = json.dumps( { @@ -671,6 +808,8 @@ class ChatAgent(ShieldRunnerMixin): tool_calls=tool_calls, ) + assistant_param, finish_reason = await _completion_to_openai_assistant(copy.deepcopy(message)) + yield AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( payload=AgentTurnResponseStepCompletePayload( @@ -682,7 +821,8 @@ class ChatAgent(ShieldRunnerMixin): # `deepcopy` for now, but this is symptomatic of a deeper issue. step_id=step_id, turn_id=turn_id, - model_response=copy.deepcopy(message), + model_response=assistant_param, + finish_reason=finish_reason, started_at=inference_start_time, completed_at=datetime.now(UTC).isoformat(), ), @@ -703,9 +843,10 @@ class ChatAgent(ShieldRunnerMixin): yield message break + assistant_param = assistant_param.model_copy(deep=True) + if len(message.tool_calls) == 0: if stop_reason == StopReason.end_of_turn: - # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS) if len(output_attachments) > 0: if isinstance(message.content, list): message.content += output_attachments @@ -714,18 +855,20 @@ class ChatAgent(ShieldRunnerMixin): yield message else: logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") - input_messages = input_messages + [message] + conversation.append(assistant_param) else: - input_messages = input_messages + [message] + conversation.append(assistant_param) # Process tool calls in the message client_tool_calls = [] non_client_tool_calls = [] + client_tool_calls_openai = [] # Separate client and non-client tool calls for tool_call in message.tool_calls: if tool_call.tool_name in client_tools: client_tool_calls.append(tool_call) + client_tool_calls_openai.append(_legacy_tool_call_to_openai(tool_call)) else: non_client_tool_calls.append(tool_call) @@ -781,18 +924,14 @@ class ChatAgent(ShieldRunnerMixin): span.set_attribute("output", result_message.model_dump_json()) # Store tool execution step + openai_tool_call = _legacy_tool_call_to_openai(tool_call) + openai_tool_response = _tool_response_message_to_openai(result_message) + tool_execution_step = ToolExecutionStep( step_id=step_id, turn_id=turn_id, - tool_calls=[tool_call], - tool_responses=[ - ToolResponse( - call_id=tool_call.call_id, - tool_name=tool_call.tool_name, - content=tool_result.content, - metadata=tool_result.metadata, - ) - ], + tool_calls=[openai_tool_call], + tool_responses=[openai_tool_response], started_at=tool_execution_start_time, completed_at=datetime.now(UTC).isoformat(), ) @@ -808,8 +947,8 @@ class ChatAgent(ShieldRunnerMixin): ) ) - # Add the result message to input_messages for the next iteration - input_messages.append(result_message) + # Add the result message to conversation for the next iteration + conversation.append(openai_tool_response) # TODO: add tool-input touchpoint and a "start" event for this step also # but that needs a lot more refactoring of Tool code potentially @@ -829,7 +968,7 @@ class ChatAgent(ShieldRunnerMixin): ToolExecutionStep( step_id=step_id, turn_id=turn_id, - tool_calls=client_tool_calls, + tool_calls=client_tool_calls_openai, tool_responses=[], started_at=datetime.now(UTC).isoformat(), ), @@ -866,19 +1005,15 @@ class ChatAgent(ShieldRunnerMixin): toolgroup_to_args = toolgroup_to_args or {} - tool_name_to_def = {} - tool_name_to_args = {} + tool_name_to_def: dict[str 
| BuiltinTool, ToolDefinition] = {} + tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {} - for tool_def in self.agent_config.client_tools: - if tool_name_to_def.get(tool_def.name, None): - raise ValueError(f"Tool {tool_def.name} already exists") - - # Use input_schema from ToolDef directly - tool_name_to_def[tool_def.name] = ToolDefinition( - tool_name=tool_def.name, - description=tool_def.description, - input_schema=tool_def.input_schema, - ) + if self.agent_config.client_tools: + for tool in self.agent_config.client_tools: + tool_definition = _client_tool_to_tool_definition(tool) + if tool_name_to_def.get(tool_definition.tool_name): + raise ValueError(f"Tool {tool_definition.tool_name} already exists") + tool_name_to_def[tool_definition.tool_name] = tool_definition for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) @@ -999,12 +1134,7 @@ async def get_raw_document_text(document: Document) -> str: if isinstance(document.content, URL): return await load_data_from_url(document.content.uri) - elif isinstance(document.content, str): - return document.content - elif isinstance(document.content, TextContentItem): - return document.content.text - else: - raise ValueError(f"Unexpected document content type: {type(document.content)}") + return _openai_message_content_to_text(document.content) def _interpret_content_as_attachment( @@ -1015,7 +1145,7 @@ def _interpret_content_as_attachment( snippet = match.group(1) data = json.loads(snippet) return Attachment( - url=URL(uri="file://" + data["filepath"]), + content=URL(uri="file://" + data["filepath"]), mime_type=data["mimetype"], ) diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index cfaf56a34..bfa04600e 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -33,9 +33,8 @@ from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.inference import ( Inference, ToolConfig, - ToolResponse, - ToolResponseMessage, - UserMessage, + OpenAIMessageParam, + OpenAIToolMessageParam, ) from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime @@ -156,7 +155,7 @@ class MetaReferenceAgentsImpl(Agents): self, agent_id: str, session_id: str, - messages: list[UserMessage | ToolResponseMessage], + messages: list[OpenAIMessageParam], toolgroups: list[AgentToolGroup] | None = None, documents: list[Document] | None = None, stream: bool | None = False, @@ -189,7 +188,7 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_id: str, turn_id: str, - tool_responses: list[ToolResponse], + tool_responses: list[OpenAIToolMessageParam], stream: bool | None = False, ) -> AsyncGenerator: request = AgentTurnResumeRequest( diff --git a/tests/integration/agents/test_agents.py b/tests/integration/agents/test_agents.py index 07ba7bb01..51cc0d764 100644 --- a/tests/integration/agents/test_agents.py +++ b/tests/integration/agents/test_agents.py @@ -62,14 +62,9 @@ def agent_config(llama_stack_client, text_model_id): agent_config = dict( model=text_model_id, instructions="You are a helpful assistant", - sampling_params={ - "strategy": { - "type": "top_p", - "temperature": 0.0001, - "top_p": 0.9, - }, - "max_tokens": 512, - }, + 
temperature=0.0001,
+        top_p=0.9,
+        max_output_tokens=512,
         tools=[],
         input_shields=available_shields,
         output_shields=available_shields,
@@ -83,14 +78,9 @@ def agent_config_without_safety(text_model_id):
     agent_config = dict(
         model=text_model_id,
         instructions="You are a helpful assistant",
-        sampling_params={
-            "strategy": {
-                "type": "top_p",
-                "temperature": 0.0001,
-                "top_p": 0.9,
-            },
-            "max_tokens": 512,
-        },
+        temperature=0.0001,
+        top_p=0.9,
+        max_output_tokens=512,
         tools=[],
         enable_session_persistence=False,
     )
@@ -194,14 +184,9 @@ def test_tool_config(agent_config):
     common_params = dict(
         model="meta-llama/Llama-3.2-3B-Instruct",
         instructions="You are a helpful assistant",
-        sampling_params={
-            "strategy": {
-                "type": "top_p",
-                "temperature": 1.0,
-                "top_p": 0.9,
-            },
-            "max_tokens": 512,
-        },
+        temperature=1.0,
+        top_p=0.9,
+        max_output_tokens=512,
         toolgroups=[],
         enable_session_persistence=False,
     )
 
     agent_config = AgentConfig(
         **common_params,
-        tool_choice="auto",
+        tool_config=ToolConfig(tool_choice="auto"),
     )
     server_config = Server__AgentConfig(**agent_config)
     assert server_config.tool_config.tool_choice == ToolChoice.auto
 
-    agent_config = AgentConfig(
-        **common_params,
-        tool_choice="auto",
-        tool_config=ToolConfig(
-            tool_choice="auto",
-        ),
-    )
-    server_config = Server__AgentConfig(**agent_config)
-    assert server_config.tool_config.tool_choice == ToolChoice.auto
-
     agent_config = AgentConfig(
         **common_params,
-        tool_config=ToolConfig(
-            tool_choice="required",
-        ),
+        tool_config=ToolConfig(tool_choice="required"),
     )
     server_config = Server__AgentConfig(**agent_config)
     assert server_config.tool_config.tool_choice == ToolChoice.required
 
-    agent_config = AgentConfig(
-        **common_params,
-        tool_choice="required",
-        tool_config=ToolConfig(
-            tool_choice="auto",
-        ),
-    )
-    with pytest.raises(ValueError, match="tool_choice is deprecated"):
-        Server__AgentConfig(**agent_config)
-
 
 
 def test_builtin_tool_web_search(llama_stack_client, agent_config):
     agent_config = {
diff --git a/tests/integration/agents/test_persistence.py b/tests/integration/agents/test_persistence.py
index 49d9d42d0..006dd24c8 100644
--- a/tests/integration/agents/test_persistence.py
+++ b/tests/integration/agents/test_persistence.py
@@ -7,7 +7,7 @@
 import pytest
 
 from llama_stack.apis.agents import AgentConfig, Turn
-from llama_stack.apis.inference import SamplingParams, UserMessage
+from llama_stack.apis.inference import OpenAIUserMessageParam
 from llama_stack.providers.datatypes import Api
 from llama_stack.providers.utils.kvstore import kvstore_impl
 from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
 @pytest.fixture
 def sample_messages():
     return [
-        UserMessage(content="What's the weather like today?"),
+        OpenAIUserMessageParam(content="What's the weather like today?"),
     ]
 
 
@@ -36,7 +36,9 @@ def common_params(inference_model):
         model=inference_model,
         instructions="You are a helpful assistant.",
         enable_session_persistence=True,
-        sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
+        temperature=0.7,
+        top_p=0.95,
+        max_output_tokens=256,
         input_shields=[],
         output_shields=[],
         tools=[],
diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py
index fdbb2b8e9..b11c52b84 100644
--- a/tests/unit/providers/agent/test_meta_reference_agent.py
+++
b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -69,30 +69,26 @@ async def agents_impl(config, mock_apis): @pytest.fixture def sample_agent_config(): return AgentConfig( - sampling_params={ - "strategy": {"type": "greedy"}, - "max_tokens": 0, - "repetition_penalty": 1.0, - }, + temperature=0.0, + top_p=1.0, + max_output_tokens=0, input_shields=["string"], output_shields=["string"], toolgroups=["mcp::my_mcp_server"], client_tools=[ { + "type": "function", "name": "client_tool", "description": "Client Tool", - "parameters": [ - { - "name": "string", - "parameter_type": "string", - "description": "string", - "required": True, - "default": None, - } - ], - "metadata": { - "property1": None, - "property2": None, + "parameters": { + "type": "object", + "properties": { + "string": { + "type": "string", + "description": "string", + } + }, + "required": ["string"], }, } ],
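
Note on the resume flow: client tool results are now returned to the agent as
OpenAIToolMessageParam values instead of ToolResponse. A minimal sketch,
assuming `turn` is a Turn that ended awaiting client tool input, with
`agents_api`, `agent_id`, and `session_id` the same placeholders as above:

    from llama_stack.apis.inference import OpenAIToolMessageParam

    # Answer each client tool call by id, then resume the turn.
    tool_responses = [
        OpenAIToolMessageParam(tool_call_id=call.id, content='{"result": 42}')
        for call in turn.output_message.tool_calls or []
    ]
    resumed = await agents_api.resume_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn.turn_id,
        tool_responses=tool_responses,
        stream=False,
    )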