mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-23 00:27:26 +00:00
feat(agents)!: changing agents API signatures to use OpenAI types
Replace legacy Message/SamplingParams usage with OpenAI chat message structures across agents: schemas, meta-reference implementation, and tests now rely on OpenAI message/tool payloads and generation knobs.
This commit is contained in:
parent
548ccff368
commit
c56b2deb7d
6 changed files with 392 additions and 305 deletions
|
@ -11,19 +11,18 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
|||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta
|
||||
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolCall,
|
||||
OpenAIAssistantMessageParam,
|
||||
OpenAIChatCompletionMessageContent,
|
||||
OpenAIChatCompletionToolCall,
|
||||
OpenAIMessageParam,
|
||||
OpenAIResponseFormatParam,
|
||||
OpenAIToolMessageParam,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolPromptFormat,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
|
@ -63,7 +62,7 @@ class Attachment(BaseModel):
|
|||
:param mime_type: The MIME type of the attachment.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
content: OpenAIChatCompletionMessageContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
|
@ -74,7 +73,7 @@ class Document(BaseModel):
|
|||
:param mime_type: The MIME type of the document.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
content: OpenAIChatCompletionMessageContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
|
@ -108,6 +107,7 @@ class StepType(StrEnum):
|
|||
memory_retrieval = "memory_retrieval"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
"""An inference step in an agent turn.
|
||||
|
@ -118,7 +118,8 @@ class InferenceStep(StepCommon):
|
|||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference] = StepType.inference
|
||||
model_response: CompletionMessage
|
||||
model_response: OpenAIAssistantMessageParam
|
||||
finish_reason: str | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -130,8 +131,8 @@ class ToolExecutionStep(StepCommon):
|
|||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
|
||||
tool_calls: list[ToolCall]
|
||||
tool_responses: list[ToolResponse]
|
||||
tool_calls: list[OpenAIChatCompletionToolCall]
|
||||
tool_responses: list[OpenAIToolMessageParam]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -156,7 +157,7 @@ class MemoryRetrievalStep(StepCommon):
|
|||
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
|
||||
# TODO: should this be List[str]?
|
||||
vector_db_ids: str
|
||||
inserted_context: InterleavedContent
|
||||
inserted_context: OpenAIChatCompletionMessageContent
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
|
@ -181,9 +182,10 @@ class Turn(BaseModel):
|
|||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: list[UserMessage | ToolResponseMessage]
|
||||
input_messages: list[OpenAIMessageParam]
|
||||
steps: list[Step]
|
||||
output_message: CompletionMessage
|
||||
output_message: OpenAIAssistantMessageParam
|
||||
finish_reason: str | None = None
|
||||
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
|
||||
|
||||
started_at: datetime
|
||||
|
@ -216,31 +218,22 @@ register_schema(AgentToolGroup, name="AgentTool")
|
|||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
max_output_tokens: int | None = None
|
||||
temperature: float | None = None
|
||||
top_p: float | None = None
|
||||
stop: list[str] | None = None
|
||||
|
||||
input_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
output_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
|
||||
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
|
||||
input_shields: list[str] | None = Field(default_factory=list)
|
||||
output_shields: list[str] | None = Field(default_factory=list)
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
|
||||
client_tools: list[OpenAIResponseInputTool | ToolDef] | None = Field(default_factory=list)
|
||||
tool_config: ToolConfig | None = Field(default=None)
|
||||
|
||||
max_infer_iters: int | None = 10
|
||||
|
||||
def model_post_init(self, __context):
|
||||
if self.tool_config:
|
||||
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
|
||||
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
||||
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||
else:
|
||||
params = {}
|
||||
if self.tool_choice:
|
||||
params["tool_choice"] = self.tool_choice
|
||||
if self.tool_prompt_format:
|
||||
params["tool_prompt_format"] = self.tool_prompt_format
|
||||
self.tool_config = ToolConfig(**params)
|
||||
if self.tool_config is None:
|
||||
self.tool_config = ToolConfig()
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -258,7 +251,7 @@ class AgentConfig(AgentConfigCommon):
|
|||
instructions: str
|
||||
name: str | None = None
|
||||
enable_session_persistence: bool | None = False
|
||||
response_format: ResponseFormat | None = None
|
||||
response_format: OpenAIResponseFormatParam | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -434,10 +427,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
|||
agent_id: str
|
||||
session_id: str
|
||||
|
||||
# TODO: figure out how we can simplify this and make clear why
|
||||
# ToolResponseMessage needs to be here (it is function call
|
||||
# execution from outside the system)
|
||||
messages: list[UserMessage | ToolResponseMessage]
|
||||
messages: list[OpenAIMessageParam]
|
||||
|
||||
documents: list[Document] | None = None
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
|
@ -460,7 +450,7 @@ class AgentTurnResumeRequest(BaseModel):
|
|||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: list[ToolResponse]
|
||||
tool_responses: list[OpenAIToolMessageParam]
|
||||
stream: bool | None = False
|
||||
|
||||
|
||||
|
@ -531,7 +521,7 @@ class Agents(Protocol):
|
|||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
messages: list[OpenAIMessageParam],
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
|
@ -569,7 +559,7 @@ class Agents(Protocol):
|
|||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: list[ToolResponse],
|
||||
tool_responses: list[OpenAIToolMessageParam],
|
||||
stream: bool | None = False,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue