feat(agents)!: changing agents API signatures to use OpenAI types

Replace legacy Message/SamplingParams usage with OpenAI chat message structures across agents: schemas, meta-reference implementation, and tests now rely on OpenAI message/tool payloads and generation knobs.
This commit is contained in:
Ashwin Bharambe 2025-10-10 13:04:41 -07:00
parent 548ccff368
commit c56b2deb7d
6 changed files with 392 additions and 305 deletions

View file

@ -11,19 +11,18 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.content_types import URL, ContentDelta
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
OpenAIAssistantMessageParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionToolCall,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIToolMessageParam,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
@ -63,7 +62,7 @@ class Attachment(BaseModel):
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
content: OpenAIChatCompletionMessageContent | URL
mime_type: str
@ -74,7 +73,7 @@ class Document(BaseModel):
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
content: OpenAIChatCompletionMessageContent | URL
mime_type: str
@ -108,6 +107,7 @@ class StepType(StrEnum):
memory_retrieval = "memory_retrieval"
@json_schema_type
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
@ -118,7 +118,8 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
model_response: OpenAIAssistantMessageParam
finish_reason: str | None = None
@json_schema_type
@ -130,8 +131,8 @@ class ToolExecutionStep(StepCommon):
"""
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
tool_calls: list[OpenAIChatCompletionToolCall]
tool_responses: list[OpenAIToolMessageParam]
@json_schema_type
@ -156,7 +157,7 @@ class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
inserted_context: OpenAIChatCompletionMessageContent
Step = Annotated[
@ -181,9 +182,10 @@ class Turn(BaseModel):
turn_id: str
session_id: str
input_messages: list[UserMessage | ToolResponseMessage]
input_messages: list[OpenAIMessageParam]
steps: list[Step]
output_message: CompletionMessage
output_message: OpenAIAssistantMessageParam
finish_reason: str | None = None
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
@ -216,31 +218,22 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stop: list[str] | None = None
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[OpenAIResponseInputTool | ToolDef] | None = Field(default_factory=list)
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
if self.tool_config is None:
self.tool_config = ToolConfig()
@json_schema_type
@ -258,7 +251,7 @@ class AgentConfig(AgentConfigCommon):
instructions: str
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
response_format: OpenAIResponseFormatParam | None = None
@json_schema_type
@ -434,10 +427,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
messages: list[OpenAIMessageParam]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
@ -460,7 +450,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: list[ToolResponse]
tool_responses: list[OpenAIToolMessageParam]
stream: bool | None = False
@ -531,7 +521,7 @@ class Agents(Protocol):
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
messages: list[OpenAIMessageParam],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
@ -569,7 +559,7 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
tool_responses: list[OpenAIToolMessageParam],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.