feat(agents)!: change agents API signatures to use OpenAI types

Replace legacy Message/SamplingParams usage with OpenAI chat completion types across the agents API: the schemas, the meta-reference implementation, and the tests now use OpenAI message and tool payloads, and the nested sampling_params block gives way to flat generation knobs (temperature, top_p, max_output_tokens, stop).
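For callers, the breaking change looks roughly like the sketch below. This is a hedged illustration, not part of the commit: the agents_api handle, the IDs, the tool result payload, and the non-streaming call shape are assumptions, while the message classes, parameter names, and method signatures follow the Agents protocol changes in the diff.

from llama_stack.apis.inference import OpenAIToolMessageParam, OpenAIUserMessageParam


async def weather_turn(agents_api, agent_id: str, session_id: str):
    # Turns are now created from OpenAI-style messages
    # (previously list[UserMessage | ToolResponseMessage]).
    turn = await agents_api.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[OpenAIUserMessageParam(content="What's the weather like today?")],
        stream=False,
    )

    # Client-executed tool results now come back as OpenAI tool messages
    # (previously list[ToolResponse]); "call_123" and the content are hypothetical.
    return await agents_api.resume_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        turn_id=turn.turn_id,
        tool_responses=[
            OpenAIToolMessageParam(tool_call_id="call_123", content='{"temp_c": 21}')
        ],
    )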
Ashwin Bharambe 2025-10-10 13:04:41 -07:00
parent 548ccff368
commit c56b2deb7d
6 changed files with 392 additions and 305 deletions


@ -11,19 +11,18 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent from llama_stack.apis.common.content_types import URL, ContentDelta
from llama_stack.apis.common.responses import Order, PaginatedResponse from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
CompletionMessage, OpenAIAssistantMessageParam,
ResponseFormat, OpenAIChatCompletionMessageContent,
SamplingParams, OpenAIChatCompletionToolCall,
ToolCall, OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIToolMessageParam,
ToolChoice, ToolChoice,
ToolConfig, ToolConfig,
ToolPromptFormat, ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
) )
from llama_stack.apis.safety import SafetyViolation from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef from llama_stack.apis.tools import ToolDef
@ -63,7 +62,7 @@ class Attachment(BaseModel):
:param mime_type: The MIME type of the attachment. :param mime_type: The MIME type of the attachment.
""" """
content: InterleavedContent | URL content: OpenAIChatCompletionMessageContent | URL
mime_type: str mime_type: str
@ -74,7 +73,7 @@ class Document(BaseModel):
:param mime_type: The MIME type of the document. :param mime_type: The MIME type of the document.
""" """
content: InterleavedContent | URL content: OpenAIChatCompletionMessageContent | URL
mime_type: str mime_type: str
@ -108,6 +107,7 @@ class StepType(StrEnum):
memory_retrieval = "memory_retrieval" memory_retrieval = "memory_retrieval"
@json_schema_type
@json_schema_type @json_schema_type
class InferenceStep(StepCommon): class InferenceStep(StepCommon):
"""An inference step in an agent turn. """An inference step in an agent turn.
@ -118,7 +118,8 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage model_response: OpenAIAssistantMessageParam
finish_reason: str | None = None
@json_schema_type @json_schema_type
@ -130,8 +131,8 @@ class ToolExecutionStep(StepCommon):
""" """
step_type: Literal[StepType.tool_execution] = StepType.tool_execution step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall] tool_calls: list[OpenAIChatCompletionToolCall]
tool_responses: list[ToolResponse] tool_responses: list[OpenAIToolMessageParam]
@json_schema_type @json_schema_type
@ -156,7 +157,7 @@ class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]? # TODO: should this be List[str]?
vector_db_ids: str vector_db_ids: str
inserted_context: InterleavedContent inserted_context: OpenAIChatCompletionMessageContent
Step = Annotated[ Step = Annotated[
@ -181,9 +182,10 @@ class Turn(BaseModel):
turn_id: str turn_id: str
session_id: str session_id: str
input_messages: list[UserMessage | ToolResponseMessage] input_messages: list[OpenAIMessageParam]
steps: list[Step] steps: list[Step]
output_message: CompletionMessage output_message: OpenAIAssistantMessageParam
finish_reason: str | None = None
output_attachments: list[Attachment] | None = Field(default_factory=lambda: []) output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime started_at: datetime
@ -216,31 +218,22 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel): class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams) max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stop: list[str] | None = None
input_shields: list[str] | None = Field(default_factory=lambda: []) input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=lambda: []) output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: []) toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[ToolDef] | None = Field(default_factory=lambda: []) client_tools: list[OpenAIResponseInputTool | ToolDef] | None = Field(default_factory=list)
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None) tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10 max_infer_iters: int | None = 10
def model_post_init(self, __context): def model_post_init(self, __context):
if self.tool_config: if self.tool_config is None:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice: self.tool_config = ToolConfig()
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type @json_schema_type
@ -258,7 +251,7 @@ class AgentConfig(AgentConfigCommon):
instructions: str instructions: str
name: str | None = None name: str | None = None
enable_session_persistence: bool | None = False enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None response_format: OpenAIResponseFormatParam | None = None
@json_schema_type @json_schema_type
@ -434,10 +427,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str agent_id: str
session_id: str session_id: str
# TODO: figure out how we can simplify this and make why messages: list[OpenAIMessageParam]
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: []) toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
@ -460,7 +450,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str agent_id: str
session_id: str session_id: str
turn_id: str turn_id: str
tool_responses: list[ToolResponse] tool_responses: list[OpenAIToolMessageParam]
stream: bool | None = False stream: bool | None = False
@ -531,7 +521,7 @@ class Agents(Protocol):
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,
messages: list[UserMessage | ToolResponseMessage], messages: list[OpenAIMessageParam],
stream: bool | None = False, stream: bool | None = False,
documents: list[Document] | None = None, documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None, toolgroups: list[AgentToolGroup] | None = None,
@ -569,7 +559,7 @@ class Agents(Protocol):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: list[ToolResponse], tool_responses: list[OpenAIToolMessageParam],
stream: bool | None = False, stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]: ) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses. """Resume an agent turn with executed tool call responses.


@ -10,12 +10,14 @@ import re
import uuid import uuid
import warnings import warnings
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any
from datetime import UTC, datetime from datetime import UTC, datetime
import httpx import httpx
from llama_stack.apis.agents import ( from llama_stack.apis.agents import (
AgentConfig, AgentConfig,
OpenAIResponseInputTool,
AgentToolGroup, AgentToolGroup,
AgentToolGroupWithArgs, AgentToolGroupWithArgs,
AgentTurnCreateRequest, AgentTurnCreateRequest,
@ -32,16 +34,12 @@ from llama_stack.apis.agents import (
Document, Document,
InferenceStep, InferenceStep,
ShieldCallStep, ShieldCallStep,
Step,
StepType, StepType,
ToolExecutionStep, ToolExecutionStep,
Turn, Turn,
) )
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import URL, ToolCallDelta, ToolCallParseStatus
URL,
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.common.errors import SessionNotFoundError from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
ChatCompletionResponseEventType, ChatCompletionResponseEventType,
@ -50,20 +48,24 @@ from llama_stack.apis.inference import (
Message, Message,
OpenAIAssistantMessageParam, OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam, OpenAIDeveloperMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIImageURL,
OpenAIMessageParam, OpenAIMessageParam,
OpenAISystemMessageParam, OpenAISystemMessageParam,
OpenAIToolMessageParam, OpenAIToolMessageParam,
OpenAIUserMessageParam, OpenAIUserMessageParam,
SamplingParams,
StopReason, StopReason,
SystemMessage, SystemMessage,
ToolDefinition, ToolDefinition,
ToolResponse,
ToolResponseMessage, ToolResponseMessage,
UserMessage, UserMessage,
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime from llama_stack.apis.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger from llama_stack.log import get_logger
@ -90,6 +92,167 @@ RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents::meta_reference") logger = get_logger(name=__name__, category="agents::meta_reference")
def _map_finish_reason_to_stop_reason(finish_reason: str | None) -> StopReason:
if finish_reason == "length":
return StopReason.out_of_tokens
if finish_reason == "tool_calls":
return StopReason.end_of_message
# Default to end_of_turn for unknown or "stop"
return StopReason.end_of_turn
def _map_stop_reason_to_finish_reason(stop_reason: StopReason | None) -> str | None:
if stop_reason == StopReason.out_of_tokens:
return "length"
if stop_reason == StopReason.end_of_message:
return "tool_calls"
if stop_reason == StopReason.end_of_turn:
return "stop"
return None
def _openai_tool_call_to_legacy(tool_call: OpenAIChatCompletionToolCall) -> ToolCall:
name = None
if tool_call.function and tool_call.function.name:
name = tool_call.function.name
return ToolCall(
call_id=tool_call.id or f"call_{uuid.uuid4()}",
tool_name=name or "",
arguments=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "{}",
)
def _legacy_tool_call_to_openai(tool_call: ToolCall, index: int | None = None) -> OpenAIChatCompletionToolCall:
function_name = tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
return OpenAIChatCompletionToolCall(
index=index,
id=tool_call.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=function_name,
arguments=tool_call.arguments,
),
)
def _tool_response_message_to_openai(response: ToolResponseMessage) -> OpenAIToolMessageParam:
content = _coerce_to_text(response.content)
return OpenAIToolMessageParam(
tool_call_id=response.call_id,
content=content,
)
def _openai_message_content_to_text(
content: OpenAIChatCompletionMessageContent,
) -> str:
if isinstance(content, str):
return content
parts = []
for item in content:
if isinstance(item, OpenAIChatCompletionContentPartTextParam):
parts.append(item.text)
elif isinstance(item, OpenAIChatCompletionContentPartImageParam) and item.image_url:
if item.image_url.url:
parts.append(item.image_url.url)
return "\n".join(parts)
def _append_text_to_openai_message(message: OpenAIMessageParam, text: str) -> None:
if not text:
return
if isinstance(message, OpenAIUserMessageParam):
content = message.content
if content is None or content == "":
message.content = text
elif isinstance(content, str):
message.content = f"{content}\n{text}"
else:
content.append(OpenAIChatCompletionContentPartTextParam(text=text))
def _coerce_to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
return "\n".join(_coerce_to_text(item) for item in content)
if hasattr(content, "text"):
return getattr(content, "text")
if hasattr(content, "image"):
image = getattr(content, "image")
if hasattr(image, "url") and image.url:
return getattr(image.url, "uri", "")
return str(content)
def _openai_message_param_to_legacy(message: OpenAIMessageParam) -> Message:
if isinstance(message, OpenAIUserMessageParam):
return UserMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAISystemMessageParam):
return SystemMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAIToolMessageParam):
return ToolResponseMessage(
call_id=message.tool_call_id,
content=_openai_message_content_to_text(message.content),
)
if isinstance(message, OpenAIDeveloperMessageParam):
# Map developer messages to user role for legacy compatibility
return UserMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAIAssistantMessageParam):
tool_calls = [
_openai_tool_call_to_legacy(tool_call)
for tool_call in message.tool_calls or []
]
return CompletionMessage(
content=_openai_message_content_to_text(message.content) if message.content is not None else "",
stop_reason=StopReason.end_of_turn,
tool_calls=tool_calls,
)
raise ValueError(f"Unsupported OpenAI message type: {type(message)}")
async def _legacy_message_to_openai(message: Message) -> OpenAIMessageParam:
openai_dict = await convert_message_to_openai_dict_new(message)
role = openai_dict.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_dict)
if role == "system":
return OpenAISystemMessageParam(**openai_dict)
if role == "assistant":
return OpenAIAssistantMessageParam(**openai_dict)
if role == "tool":
return OpenAIToolMessageParam(**openai_dict)
if role == "developer":
return OpenAIDeveloperMessageParam(**openai_dict)
raise ValueError(f"Unsupported OpenAI message role: {role}")
async def _completion_to_openai_assistant(
completion: CompletionMessage,
) -> tuple[OpenAIAssistantMessageParam, str | None]:
assistant_param = await _legacy_message_to_openai(completion)
assert isinstance(assistant_param, OpenAIAssistantMessageParam)
finish_reason = _map_stop_reason_to_finish_reason(completion.stop_reason)
return assistant_param, finish_reason
def _client_tool_to_tool_definition(tool: OpenAIResponseInputTool | ToolDef) -> ToolDefinition:
if isinstance(tool, ToolDef):
return ToolDefinition(
tool_name=tool.name,
description=tool.description,
input_schema=tool.input_schema,
)
if getattr(tool, "type", None) == "function":
return ToolDefinition(
tool_name=tool.name,
description=getattr(tool, "description", None),
input_schema=getattr(tool, "parameters", None),
)
raise ValueError(f"Unsupported client tool type '{getattr(tool, 'type', None)}' for agent configuration")
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):
def __init__( def __init__(
self, self,
@ -123,59 +286,70 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields, output_shields=agent_config.output_shields,
) )
def turn_to_messages(self, turn: Turn) -> list[Message]: def _resolve_generation_options(
messages = [] self,
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
) -> dict[str, Any]:
def _pick(attr: str) -> Any:
value = getattr(request, attr, None)
if value is not None:
return value
return getattr(self.agent_config, attr)
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages temperature = _pick("temperature")
tool_call_ids = set() top_p = _pick("top_p")
for step in turn.steps: max_output_tokens = _pick("max_output_tokens")
if step.step_type == StepType.tool_execution.value: stop = _pick("stop")
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
for m in turn.input_messages: # Ensure we don't share mutable defaults
msg = m.model_copy() if isinstance(stop, list):
# We do not want to keep adding RAG context to the input messages stop = list(stop)
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way return {
if isinstance(msg, UserMessage): "temperature": temperature,
msg.context = None "top_p": top_p,
if isinstance(msg, ToolResponseMessage): "max_output_tokens": max_output_tokens,
if msg.call_id in tool_call_ids: "stop": stop,
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps }
def turn_to_messages(self, turn: Turn) -> list[OpenAIMessageParam]:
messages: list[OpenAIMessageParam] = []
tool_response_ids = {
response.tool_call_id
for step in turn.steps
if step.step_type == StepType.tool_execution.value
for response in step.tool_responses
}
for message in turn.input_messages:
copied = message.model_copy(deep=True)
if isinstance(copied, OpenAIToolMessageParam) and copied.tool_call_id in tool_response_ids:
# Skip tool responses; they will be reintroduced from the execution step
continue continue
messages.append(copied)
messages.append(msg)
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.inference.value: if step.step_type == StepType.inference.value:
messages.append(step.model_response) messages.append(step.model_response.model_copy(deep=True))
elif step.step_type == StepType.tool_execution.value: elif step.step_type == StepType.tool_execution.value:
for response in step.tool_responses: for response in step.tool_responses:
messages.append( messages.append(response.model_copy(deep=True))
ToolResponseMessage( elif step.step_type == StepType.shield_call.value and step.violation:
call_id=response.call_id, assistant_msg = OpenAIAssistantMessageParam(
content=response.content, content=str(step.violation.user_message),
)
)
elif step.step_type == StepType.shield_call.value:
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=step.violation.user_message,
stop_reason=StopReason.end_of_turn,
)
) )
messages.append(assistant_msg)
return messages return messages
async def create_session(self, name: str) -> str: async def create_session(self, name: str) -> str:
return await self.storage.create_session(name) return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]: async def get_messages_from_turns(self, turns: list[Turn]) -> list[OpenAIMessageParam]:
messages = [] messages: list[OpenAIMessageParam] = []
if self.agent_config.instructions != "": if self.agent_config.instructions:
messages.append(SystemMessage(content=self.agent_config.instructions)) messages.append(OpenAISystemMessageParam(content=self.agent_config.instructions))
for turn in turns: for turn in turns:
messages.extend(self.turn_to_messages(turn)) messages.extend(self.turn_to_messages(turn))
@ -228,26 +402,19 @@ class ChatAgent(ShieldRunnerMixin):
if is_resume and len(turns) == 0: if is_resume and len(turns) == 0:
raise ValueError("No turns found for session") raise ValueError("No turns found for session")
steps = [] steps: list[Step] = []
messages = await self.get_messages_from_turns(turns) history_openai = await self.get_messages_from_turns(turns)
if turn_id is None:
turn_id = request.turn_id
if is_resume: if is_resume:
tool_response_messages = [ tool_response_messages = [resp.model_copy(deep=True) for resp in request.tool_responses]
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses history_openai.extend(tool_response_messages)
]
messages.extend(tool_response_messages)
last_turn = turns[-1] last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn) steps = list(last_turn.steps)
last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
last_turn_messages.extend(tool_response_messages)
# get steps from the turn
steps = last_turn.steps
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step( in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id request.session_id, request.turn_id
) )
@ -256,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())), step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id, turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []), tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses, tool_responses=tool_response_messages,
completed_at=now, completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now), started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
) )
@ -270,26 +437,34 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
input_messages = last_turn.input_messages
turn_id = request.turn_id input_messages_openai = [msg.model_copy(deep=True) for msg in last_turn.input_messages]
start_time = last_turn.started_at start_time = last_turn.started_at
else: else:
messages.extend(request.messages) new_messages = [msg.model_copy(deep=True) for msg in request.messages]
history_openai.extend(new_messages)
input_messages_openai = new_messages
start_time = datetime.now(UTC).isoformat() start_time = datetime.now(UTC).isoformat()
input_messages = request.messages
output_message = None generation_options = self._resolve_generation_options(request)
output_completion: CompletionMessage | None = None
output_finish_reason: str | None = None
output_assistant_message: OpenAIAssistantMessageParam | None = None
async for chunk in self.run( async for chunk in self.run(
session_id=request.session_id, session_id=request.session_id,
turn_id=turn_id, turn_id=turn_id,
input_messages=messages, input_messages=history_openai,
sampling_params=self.agent_config.sampling_params,
stream=request.stream, stream=request.stream,
documents=request.documents if not is_resume else None, documents=request.documents if not is_resume else None,
temperature=generation_options["temperature"],
top_p=generation_options["top_p"],
max_output_tokens=generation_options["max_output_tokens"],
stop=generation_options["stop"],
): ):
if isinstance(chunk, CompletionMessage): if isinstance(chunk, CompletionMessage):
output_message = chunk output_completion = chunk
output_assistant_message, output_finish_reason = await _completion_to_openai_assistant(chunk)
continue continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}" assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
@ -299,19 +474,21 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk yield chunk
assert output_message is not None assert output_completion is not None
assert output_assistant_message is not None
turn = Turn( turn = Turn(
turn_id=turn_id, turn_id=turn_id,
session_id=request.session_id, session_id=request.session_id,
input_messages=input_messages, input_messages=input_messages_openai,
output_message=output_message, output_message=output_assistant_message,
finish_reason=output_finish_reason,
started_at=start_time, started_at=start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC).isoformat(),
steps=steps, steps=steps,
) )
await self.storage.add_turn_to_session(request.session_id, turn) await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls: if output_assistant_message.tool_calls:
chunk = AgentTurnResponseStreamChunk( chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload( payload=AgentTurnResponseTurnAwaitingInputPayload(
@ -334,10 +511,13 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
input_messages: list[Message], input_messages: list[OpenAIMessageParam],
sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: list[Document] | None = None, documents: list[Document] | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_output_tokens: int | None = None,
stop: list[str] | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to # Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@ -357,9 +537,12 @@ class ChatAgent(ShieldRunnerMixin):
session_id, session_id,
turn_id, turn_id,
input_messages, input_messages,
sampling_params,
stream, stream,
documents, documents,
temperature,
top_p,
max_output_tokens,
stop,
): ):
if isinstance(res, bool): if isinstance(res, bool):
return return
@ -370,8 +553,9 @@ class ChatAgent(ShieldRunnerMixin):
yield res yield res
assert final_response is not None assert final_response is not None
final_assistant, final_finish_reason = await _completion_to_openai_assistant(copy.deepcopy(final_response))
# for output shields run on the full input and output combination # for output shields run on the full input and output combination
messages = input_messages + [final_response] messages = input_messages + [final_assistant.model_copy(deep=True)]
if len(self.output_shields) > 0: if len(self.output_shields) > 0:
async for res in self.run_multiple_shields_wrapper( async for res in self.run_multiple_shields_wrapper(
@ -387,7 +571,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper( async def run_multiple_shields_wrapper(
self, self,
turn_id: str, turn_id: str,
messages: list[Message], messages: list[OpenAIMessageParam],
shields: list[str], shields: list[str],
touchpoint: str, touchpoint: str,
) -> AsyncGenerator: ) -> AsyncGenerator:
@ -412,7 +596,8 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
) )
await self.run_multiple_shields(messages, shields) legacy_messages = [_openai_message_param_to_legacy(m) for m in messages]
await self.run_multiple_shields(legacy_messages, shields)
except SafetyException as e: except SafetyException as e:
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
@ -461,29 +646,25 @@ class ChatAgent(ShieldRunnerMixin):
self, self,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
input_messages: list[Message], input_messages: list[OpenAIMessageParam],
sampling_params: SamplingParams,
stream: bool = False, stream: bool = False,
documents: list[Document] | None = None, documents: list[Document] | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_output_tokens: int | None = None,
stop: list[str] | None = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document conversation = [msg.model_copy(deep=True) for msg in input_messages]
# and sent it as a user message
if documents: # if document is passed in a turn, hydrate the last user message with the context
contexts = [] if documents and conversation:
appended_texts = []
for document in documents: for document in documents:
raw_document_text = await get_raw_document_text(document) raw_document_text = await get_raw_document_text(document)
contexts.append(raw_document_text) if raw_document_text:
appended_texts.append(raw_document_text)
attached_context = "\n".join(contexts) if appended_texts:
if isinstance(input_messages[-1].content, str): _append_text_to_openai_message(conversation[-1], "\n".join(appended_texts))
input_messages[-1].content += attached_context
elif isinstance(input_messages[-1].content, list):
input_messages[-1].content.append(TextContentItem(text=attached_context))
else:
input_messages[-1].content = [
input_messages[-1].content,
TextContentItem(text=attached_context),
]
session_info = await self.storage.get_session_info(session_id) session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it # if the session has a memory bank id, let the memory tool use it
@ -500,8 +681,12 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0 n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup # Build a map of custom tools to their definitions for faster lookup
client_tools = {} client_tools: dict[str, OpenAIResponseInputTool | ToolDef] = {}
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools: for tool in self.agent_config.client_tools:
if isinstance(tool, ToolDef) and tool.name:
client_tools[tool.name] = tool
elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
client_tools[tool.name] = tool client_tools[tool.name] = tool
while True: while True:
step_id = str(uuid.uuid4()) step_id = str(uuid.uuid4())
@ -520,81 +705,33 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason: StopReason | None = None stop_reason: StopReason | None = None
async with tracing.span("inference") as span: async with tracing.span("inference") as span:
if self.telemetry_enabled and span is not None: if self.telemetry_enabled and span is not None and self.agent_config.name:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name) span.set_attribute("agent_name", self.agent_config.name)
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
from pydantic import BaseModel
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
elif isinstance(value, dict):
return {k: _serialize_nested(v) for k, v in value.items()}
elif isinstance(value, list):
return [_serialize_nested(item) for item in value]
else:
return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
role = openai_msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_msg)
elif role == "system":
return OpenAISystemMessageParam(**openai_msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**openai_msg)
elif role == "tool":
return OpenAIToolMessageParam(**openai_msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**openai_msg)
else:
raise ValueError(f"Unknown message role: {role}")
# Convert messages to OpenAI format
openai_messages: list[OpenAIMessageParam] = [
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
]
# Convert tool definitions to OpenAI format
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])] openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
# Extract tool_choice from tool_config for OpenAI compatibility
# Note: tool_choice can only be provided when tools are also provided
tool_choice = None tool_choice = None
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice: if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
tc = self.agent_config.tool_config.tool_choice tc = self.agent_config.tool_config.tool_choice
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc) tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
# Convert tool_choice to OpenAI format
if tool_choice_str in ("auto", "none", "required"): if tool_choice_str in ("auto", "none", "required"):
tool_choice = tool_choice_str tool_choice = tool_choice_str
else: else:
# It's a specific tool name, wrap it in the proper format
tool_choice = {"type": "function", "function": {"name": tool_choice_str}} tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion( openai_stream = await self.inference_api.openai_chat_completion(
model=self.agent_config.model, model=self.agent_config.model,
messages=openai_messages, messages=[msg.model_copy(deep=True) for msg in conversation],
tools=openai_tools if openai_tools else None, tools=openai_tools if openai_tools else None,
tool_choice=tool_choice, tool_choice=tool_choice,
response_format=self.agent_config.response_format, response_format=self.agent_config.response_format,
temperature=temperature, temperature=temperature,
top_p=top_p, top_p=top_p,
max_tokens=max_tokens, max_tokens=max_output_tokens,
stop=stop,
stream=True, stream=True,
) )
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream( response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True openai_stream, enable_incremental_tool_calls=True
) )
@ -644,7 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn) span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute( span.set_attribute(
"input", "input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]), json.dumps([json.loads(m.model_copy(deep=True).model_dump_json()) for m in conversation]),
) )
output_attr = json.dumps( output_attr = json.dumps(
{ {
@ -671,6 +808,8 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls=tool_calls, tool_calls=tool_calls,
) )
assistant_param, finish_reason = await _completion_to_openai_assistant(copy.deepcopy(message))
yield AgentTurnResponseStreamChunk( yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload( payload=AgentTurnResponseStepCompletePayload(
@ -682,7 +821,8 @@ class ChatAgent(ShieldRunnerMixin):
# `deepcopy` for now, but this is symptomatic of a deeper issue. # `deepcopy` for now, but this is symptomatic of a deeper issue.
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
model_response=copy.deepcopy(message), model_response=assistant_param,
finish_reason=finish_reason,
started_at=inference_start_time, started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC).isoformat(),
), ),
@ -703,9 +843,10 @@ class ChatAgent(ShieldRunnerMixin):
yield message yield message
break break
assistant_param = assistant_param.model_copy(deep=True)
if len(message.tool_calls) == 0: if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn: if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0: if len(output_attachments) > 0:
if isinstance(message.content, list): if isinstance(message.content, list):
message.content += output_attachments message.content += output_attachments
@ -714,18 +855,20 @@ class ChatAgent(ShieldRunnerMixin):
yield message yield message
else: else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}") logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message] conversation.append(assistant_param)
else: else:
input_messages = input_messages + [message] conversation.append(assistant_param)
# Process tool calls in the message # Process tool calls in the message
client_tool_calls = [] client_tool_calls = []
non_client_tool_calls = [] non_client_tool_calls = []
client_tool_calls_openai = []
# Separate client and non-client tool calls # Separate client and non-client tool calls
for tool_call in message.tool_calls: for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools: if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call) client_tool_calls.append(tool_call)
client_tool_calls_openai.append(_legacy_tool_call_to_openai(tool_call))
else: else:
non_client_tool_calls.append(tool_call) non_client_tool_calls.append(tool_call)
@ -781,18 +924,14 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("output", result_message.model_dump_json()) span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step # Store tool execution step
openai_tool_call = _legacy_tool_call_to_openai(tool_call)
openai_tool_response = _tool_response_message_to_openai(result_message)
tool_execution_step = ToolExecutionStep( tool_execution_step = ToolExecutionStep(
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
tool_calls=[tool_call], tool_calls=[openai_tool_call],
tool_responses=[ tool_responses=[openai_tool_response],
ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time, started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(), completed_at=datetime.now(UTC).isoformat(),
) )
@ -808,8 +947,8 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
# Add the result message to input_messages for the next iteration # Add the result message to conversation for the next iteration
input_messages.append(result_message) conversation.append(openai_tool_response)
# TODO: add tool-input touchpoint and a "start" event for this step also # TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially # but that needs a lot more refactoring of Tool code potentially
@ -829,7 +968,7 @@ class ChatAgent(ShieldRunnerMixin):
ToolExecutionStep( ToolExecutionStep(
step_id=step_id, step_id=step_id,
turn_id=turn_id, turn_id=turn_id,
tool_calls=client_tool_calls, tool_calls=client_tool_calls_openai,
tool_responses=[], tool_responses=[],
started_at=datetime.now(UTC).isoformat(), started_at=datetime.now(UTC).isoformat(),
), ),
@ -866,19 +1005,15 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {} toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {} tool_name_to_def: dict[str | BuiltinTool, ToolDefinition] = {}
tool_name_to_args = {} tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}
for tool_def in self.agent_config.client_tools: if self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None): for tool in self.agent_config.client_tools:
raise ValueError(f"Tool {tool_def.name} already exists") tool_definition = _client_tool_to_tool_definition(tool)
if tool_name_to_def.get(tool_definition.tool_name):
# Use input_schema from ToolDef directly raise ValueError(f"Tool {tool_definition.tool_name} already exists")
tool_name_to_def[tool_def.name] = ToolDefinition( tool_name_to_def[tool_definition.tool_name] = tool_definition
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups: for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name) toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name) tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@ -999,12 +1134,7 @@ async def get_raw_document_text(document: Document) -> str:
if isinstance(document.content, URL): if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri) return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str): return _openai_message_content_to_text(document.content)
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
raise ValueError(f"Unexpected document content type: {type(document.content)}")
def _interpret_content_as_attachment( def _interpret_content_as_attachment(
@ -1015,7 +1145,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1) snippet = match.group(1)
data = json.loads(snippet) data = json.loads(snippet)
return Attachment( return Attachment(
url=URL(uri="file://" + data["filepath"]), content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"], mime_type=data["mimetype"],
) )
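A brief usage note on the module-private helpers introduced in this file. Names and behavior come from the diff above; the snippet is illustrative only and assumes it runs inside (or imports the private names from) the meta-reference agent module.

from llama_stack.apis.inference import OpenAIUserMessageParam, StopReason

# Bridging finish reasons between the OpenAI and legacy vocabularies.
assert _map_finish_reason_to_stop_reason("tool_calls") == StopReason.end_of_message
assert _map_stop_reason_to_finish_reason(StopReason.out_of_tokens) == "length"

# OpenAI messages are downgraded to legacy Message types only where still required,
# e.g. before shields run via run_multiple_shields().
legacy = _openai_message_param_to_legacy(OpenAIUserMessageParam(content="Is this request safe?"))
print(type(legacy).__name__)  # UserMessage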


@ -33,9 +33,8 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import ( from llama_stack.apis.inference import (
Inference, Inference,
ToolConfig, ToolConfig,
ToolResponse, OpenAIMessageParam,
ToolResponseMessage, OpenAIToolMessageParam,
UserMessage,
) )
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.tools import ToolGroups, ToolRuntime
@ -156,7 +155,7 @@ class MetaReferenceAgentsImpl(Agents):
self, self,
agent_id: str, agent_id: str,
session_id: str, session_id: str,
messages: list[UserMessage | ToolResponseMessage], messages: list[OpenAIMessageParam],
toolgroups: list[AgentToolGroup] | None = None, toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None, documents: list[Document] | None = None,
stream: bool | None = False, stream: bool | None = False,
@ -189,7 +188,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_id: str, session_id: str,
turn_id: str, turn_id: str,
tool_responses: list[ToolResponse], tool_responses: list[OpenAIToolMessageParam],
stream: bool | None = False, stream: bool | None = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
request = AgentTurnResumeRequest( request = AgentTurnResumeRequest(


@ -62,14 +62,9 @@ def agent_config(llama_stack_client, text_model_id):
agent_config = dict( agent_config = dict(
model=text_model_id, model=text_model_id,
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
sampling_params={ temperature=0.0001,
"strategy": { top_p=0.9,
"type": "top_p", max_output_tokens=512,
"temperature": 0.0001,
"top_p": 0.9,
},
"max_tokens": 512,
},
tools=[], tools=[],
input_shields=available_shields, input_shields=available_shields,
output_shields=available_shields, output_shields=available_shields,
@ -83,14 +78,9 @@ def agent_config_without_safety(text_model_id):
agent_config = dict( agent_config = dict(
model=text_model_id, model=text_model_id,
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
sampling_params={ temperature=0.0001,
"strategy": { top_p=0.9,
"type": "top_p", max_output_tokens=512,
"temperature": 0.0001,
"top_p": 0.9,
},
"max_tokens": 512,
},
tools=[], tools=[],
enable_session_persistence=False, enable_session_persistence=False,
) )
@ -194,14 +184,9 @@ def test_tool_config(agent_config):
common_params = dict( common_params = dict(
model="meta-llama/Llama-3.2-3B-Instruct", model="meta-llama/Llama-3.2-3B-Instruct",
instructions="You are a helpful assistant", instructions="You are a helpful assistant",
sampling_params={ temperature=1.0,
"strategy": { top_p=0.9,
"type": "top_p", max_output_tokens=512,
"temperature": 1.0,
"top_p": 0.9,
},
"max_tokens": 512,
},
toolgroups=[], toolgroups=[],
enable_session_persistence=False, enable_session_persistence=False,
) )
@ -212,40 +197,25 @@ def test_tool_config(agent_config):
agent_config = AgentConfig( agent_config = AgentConfig(
**common_params, **common_params,
tool_choice="auto", tool_config=ToolConfig(tool_choice="auto"),
) )
server_config = Server__AgentConfig(**agent_config) server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig( agent_config = AgentConfig(
**common_params, **common_params,
tool_choice="auto", tool_config=ToolConfig(tool_choice="auto"),
tool_config=ToolConfig(
tool_choice="auto",
),
) )
server_config = Server__AgentConfig(**agent_config) server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig( agent_config = AgentConfig(
**common_params, **common_params,
tool_config=ToolConfig( tool_config=ToolConfig(tool_choice="required"),
tool_choice="required",
),
) )
server_config = Server__AgentConfig(**agent_config) server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.required assert server_config.tool_config.tool_choice == ToolChoice.required
agent_config = AgentConfig(
**common_params,
tool_choice="required",
tool_config=ToolConfig(
tool_choice="auto",
),
)
with pytest.raises(ValueError, match="tool_choice is deprecated"):
Server__AgentConfig(**agent_config)
def test_builtin_tool_web_search(llama_stack_client, agent_config): def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = { agent_config = {


@ -7,7 +7,7 @@
import pytest import pytest
from llama_stack.apis.agents import AgentConfig, Turn from llama_stack.apis.agents import AgentConfig, Turn
from llama_stack.apis.inference import SamplingParams, UserMessage from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture @pytest.fixture
def sample_messages(): def sample_messages():
return [ return [
UserMessage(content="What's the weather like today?"), OpenAIUserMessageParam(content="What's the weather like today?"),
] ]
@ -36,7 +36,9 @@ def common_params(inference_model):
model=inference_model, model=inference_model,
instructions="You are a helpful assistant.", instructions="You are a helpful assistant.",
enable_session_persistence=True, enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95), temperature=0.7,
top_p=0.95,
max_output_tokens=256,
input_shields=[], input_shields=[],
output_shields=[], output_shields=[],
tools=[], tools=[],


@ -69,30 +69,26 @@ async def agents_impl(config, mock_apis):
@pytest.fixture @pytest.fixture
def sample_agent_config(): def sample_agent_config():
return AgentConfig( return AgentConfig(
sampling_params={ temperature=0.0,
"strategy": {"type": "greedy"}, top_p=1.0,
"max_tokens": 0, max_output_tokens=0,
"repetition_penalty": 1.0,
},
input_shields=["string"], input_shields=["string"],
output_shields=["string"], output_shields=["string"],
toolgroups=["mcp::my_mcp_server"], toolgroups=["mcp::my_mcp_server"],
client_tools=[ client_tools=[
{ {
"type": "function",
"name": "client_tool", "name": "client_tool",
"description": "Client Tool", "description": "Client Tool",
"parameters": [ "parameters": {
{ "type": "object",
"name": "string", "properties": {
"parameter_type": "string", "string": {
"type": "string",
"description": "string", "description": "string",
"required": True,
"default": None,
} }
], },
"metadata": { "required": ["string"],
"property1": None,
"property2": None,
}, },
} }
], ],