Mirror of https://github.com/meta-llama/llama-stack.git, synced 2025-10-21 16:07:16 +00:00
feat(agents)!: changing agents API signatures to use OpenAI types
Replace legacy Message/SamplingParams usage with OpenAI chat message structures across the agents API: the schemas, the meta-reference implementation, and the tests now rely on OpenAI message/tool payloads and on per-request generation parameters (temperature, top_p, max_output_tokens, stop).
This commit is contained in: parent 548ccff368, commit c56b2deb7d
6 changed files with 392 additions and 305 deletions
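
Sketch of the new call shape after this change (illustrative only: the agents_api handle, agent/session ids, and the model name are placeholders, and create_agent_turn is assumed to be the existing turn-creation entry point; generation parameters now live directly on the agent config instead of SamplingParams):

from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.inference import OpenAIUserMessageParam

# Generation parameters replace the old SamplingParams block on the agent config.
config = AgentConfig(
    model="<model-id>",  # placeholder, not taken from this diff
    instructions="You are a helpful assistant",
    temperature=0.7,
    top_p=0.9,
    max_output_tokens=512,
)

async def example(agents_api, agent_id: str, session_id: str):
    # Turns now take OpenAI-typed messages instead of UserMessage / ToolResponseMessage.
    return await agents_api.create_agent_turn(
        agent_id=agent_id,
        session_id=session_id,
        messages=[OpenAIUserMessageParam(content="What's the weather like today?")],
        stream=False,
    )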
@@ -11,19 +11,18 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.content_types import URL, ContentDelta
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
OpenAIAssistantMessageParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionToolCall,
OpenAIMessageParam,
OpenAIResponseFormatParam,
OpenAIToolMessageParam,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef

@@ -63,7 +62,7 @@ class Attachment(BaseModel):
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
content: OpenAIChatCompletionMessageContent | URL
mime_type: str

@@ -74,7 +73,7 @@ class Document(BaseModel):
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
content: OpenAIChatCompletionMessageContent | URL
mime_type: str

@@ -108,6 +107,7 @@ class StepType(StrEnum):
memory_retrieval = "memory_retrieval"
@json_schema_type
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.

@@ -118,7 +118,8 @@ class InferenceStep(StepCommon):
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
model_response: OpenAIAssistantMessageParam
finish_reason: str | None = None
@json_schema_type

@@ -130,8 +131,8 @@ class ToolExecutionStep(StepCommon):
"""
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
tool_calls: list[OpenAIChatCompletionToolCall]
tool_responses: list[OpenAIToolMessageParam]
@json_schema_type

@@ -156,7 +157,7 @@ class MemoryRetrievalStep(StepCommon):
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_db_ids: str
inserted_context: InterleavedContent
inserted_context: OpenAIChatCompletionMessageContent
Step = Annotated[

@@ -181,9 +182,10 @@ class Turn(BaseModel):
turn_id: str
session_id: str
input_messages: list[UserMessage | ToolResponseMessage]
input_messages: list[OpenAIMessageParam]
steps: list[Step]
output_message: CompletionMessage
output_message: OpenAIAssistantMessageParam
finish_reason: str | None = None
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime

@@ -216,31 +218,22 @@ register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
max_output_tokens: int | None = None
temperature: float | None = None
top_p: float | None = None
stop: list[str] | None = None
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
input_shields: list[str] | None = Field(default_factory=list)
output_shields: list[str] | None = Field(default_factory=list)
toolgroups: list[AgentToolGroup] | None = Field(default_factory=list)
client_tools: list[OpenAIResponseInputTool | ToolDef] | None = Field(default_factory=list)
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
if self.tool_config is None:
self.tool_config = ToolConfig()
@json_schema_type

@@ -258,7 +251,7 @@ class AgentConfig(AgentConfigCommon):
instructions: str
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
response_format: OpenAIResponseFormatParam | None = None
@json_schema_type

@@ -434,10 +427,7 @@ class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and make why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
messages: list[OpenAIMessageParam]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])

@@ -460,7 +450,7 @@ class AgentTurnResumeRequest(BaseModel):
agent_id: str
session_id: str
turn_id: str
tool_responses: list[ToolResponse]
tool_responses: list[OpenAIToolMessageParam]
stream: bool | None = False

@@ -531,7 +521,7 @@ class Agents(Protocol):
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
messages: list[OpenAIMessageParam],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,

@@ -569,7 +559,7 @@ class Agents(Protocol):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
tool_responses: list[OpenAIToolMessageParam],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
@@ -10,12 +10,14 @@ import re
import uuid
import warnings
from collections.abc import AsyncGenerator
from typing import Any
from datetime import UTC, datetime
import httpx
from llama_stack.apis.agents import (
AgentConfig,
OpenAIResponseInputTool,
AgentToolGroup,
AgentToolGroupWithArgs,
AgentTurnCreateRequest,

@@ -32,16 +34,12 @@ from llama_stack.apis.agents import (
Document,
InferenceStep,
ShieldCallStep,
Step,
StepType,
ToolExecutionStep,
Turn,
)
from llama_stack.apis.common.content_types import (
URL,
TextContentItem,
ToolCallDelta,
ToolCallParseStatus,
)
from llama_stack.apis.common.content_types import URL, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.apis.inference import (
ChatCompletionResponseEventType,

@@ -50,20 +48,24 @@ from llama_stack.apis.inference import (
Message,
OpenAIAssistantMessageParam,
OpenAIDeveloperMessageParam,
OpenAIChatCompletionContentPartImageParam,
OpenAIChatCompletionContentPartTextParam,
OpenAIChatCompletionMessageContent,
OpenAIChatCompletionToolCall,
OpenAIChatCompletionToolCallFunction,
OpenAIImageURL,
OpenAIMessageParam,
OpenAISystemMessageParam,
OpenAIToolMessageParam,
OpenAIUserMessageParam,
SamplingParams,
StopReason,
SystemMessage,
ToolDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger

@@ -90,6 +92,167 @@ RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents::meta_reference")
def _map_finish_reason_to_stop_reason(finish_reason: str | None) -> StopReason:
if finish_reason == "length":
return StopReason.out_of_tokens
if finish_reason == "tool_calls":
return StopReason.end_of_message
# Default to end_of_turn for unknown or "stop"
return StopReason.end_of_turn
def _map_stop_reason_to_finish_reason(stop_reason: StopReason | None) -> str | None:
if stop_reason == StopReason.out_of_tokens:
return "length"
if stop_reason == StopReason.end_of_message:
return "tool_calls"
if stop_reason == StopReason.end_of_turn:
return "stop"
return None
def _openai_tool_call_to_legacy(tool_call: OpenAIChatCompletionToolCall) -> ToolCall:
name = None
if tool_call.function and tool_call.function.name:
name = tool_call.function.name
return ToolCall(
call_id=tool_call.id or f"call_{uuid.uuid4()}",
tool_name=name or "",
arguments=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "{}",
)
def _legacy_tool_call_to_openai(tool_call: ToolCall, index: int | None = None) -> OpenAIChatCompletionToolCall:
function_name = tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
return OpenAIChatCompletionToolCall(
index=index,
id=tool_call.call_id,
function=OpenAIChatCompletionToolCallFunction(
name=function_name,
arguments=tool_call.arguments,
),
)
def _tool_response_message_to_openai(response: ToolResponseMessage) -> OpenAIToolMessageParam:
content = _coerce_to_text(response.content)
return OpenAIToolMessageParam(
tool_call_id=response.call_id,
content=content,
)
def _openai_message_content_to_text(
content: OpenAIChatCompletionMessageContent,
) -> str:
if isinstance(content, str):
return content
parts = []
for item in content:
if isinstance(item, OpenAIChatCompletionContentPartTextParam):
parts.append(item.text)
elif isinstance(item, OpenAIChatCompletionContentPartImageParam) and item.image_url:
if item.image_url.url:
parts.append(item.image_url.url)
return "\n".join(parts)
def _append_text_to_openai_message(message: OpenAIMessageParam, text: str) -> None:
if not text:
return
if isinstance(message, OpenAIUserMessageParam):
content = message.content
if content is None or content == "":
message.content = text
elif isinstance(content, str):
message.content = f"{content}\n{text}"
else:
content.append(OpenAIChatCompletionContentPartTextParam(text=text))
def _coerce_to_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
return "\n".join(_coerce_to_text(item) for item in content)
if hasattr(content, "text"):
return getattr(content, "text")
if hasattr(content, "image"):
image = getattr(content, "image")
if hasattr(image, "url") and image.url:
return getattr(image.url, "uri", "")
return str(content)
def _openai_message_param_to_legacy(message: OpenAIMessageParam) -> Message:
if isinstance(message, OpenAIUserMessageParam):
return UserMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAISystemMessageParam):
return SystemMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAIToolMessageParam):
return ToolResponseMessage(
call_id=message.tool_call_id,
content=_openai_message_content_to_text(message.content),
)
if isinstance(message, OpenAIDeveloperMessageParam):
# Map developer messages to user role for legacy compatibility
return UserMessage(content=_openai_message_content_to_text(message.content))
if isinstance(message, OpenAIAssistantMessageParam):
tool_calls = [
_openai_tool_call_to_legacy(tool_call)
for tool_call in message.tool_calls or []
]
return CompletionMessage(
content=_openai_message_content_to_text(message.content) if message.content is not None else "",
stop_reason=StopReason.end_of_turn,
tool_calls=tool_calls,
)
raise ValueError(f"Unsupported OpenAI message type: {type(message)}")
async def _legacy_message_to_openai(message: Message) -> OpenAIMessageParam:
openai_dict = await convert_message_to_openai_dict_new(message)
role = openai_dict.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_dict)
if role == "system":
return OpenAISystemMessageParam(**openai_dict)
if role == "assistant":
return OpenAIAssistantMessageParam(**openai_dict)
if role == "tool":
return OpenAIToolMessageParam(**openai_dict)
if role == "developer":
return OpenAIDeveloperMessageParam(**openai_dict)
raise ValueError(f"Unsupported OpenAI message role: {role}")
async def _completion_to_openai_assistant(
completion: CompletionMessage,
) -> tuple[OpenAIAssistantMessageParam, str | None]:
assistant_param = await _legacy_message_to_openai(completion)
assert isinstance(assistant_param, OpenAIAssistantMessageParam)
finish_reason = _map_stop_reason_to_finish_reason(completion.stop_reason)
return assistant_param, finish_reason
def _client_tool_to_tool_definition(tool: OpenAIResponseInputTool | ToolDef) -> ToolDefinition:
if isinstance(tool, ToolDef):
return ToolDefinition(
tool_name=tool.name,
description=tool.description,
input_schema=tool.input_schema,
)
if getattr(tool, "type", None) == "function":
return ToolDefinition(
tool_name=tool.name,
description=getattr(tool, "description", None),
input_schema=getattr(tool, "parameters", None),
)
raise ValueError(f"Unsupported client tool type '{getattr(tool, 'type', None)}' for agent configuration")
class ChatAgent(ShieldRunnerMixin):
def __init__(
self,

@@ -123,59 +286,70 @@ class ChatAgent(ShieldRunnerMixin):
output_shields=agent_config.output_shields,
)
def turn_to_messages(self, turn: Turn) -> list[Message]:
messages = []
def _resolve_generation_options(
self,
request: AgentTurnCreateRequest | AgentTurnResumeRequest,
) -> dict[str, Any]:
def _pick(attr: str) -> Any:
value = getattr(request, attr, None)
if value is not None:
return value
return getattr(self.agent_config, attr)
# NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
tool_call_ids = set()
for step in turn.steps:
if step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
tool_call_ids.add(response.call_id)
temperature = _pick("temperature")
top_p = _pick("top_p")
max_output_tokens = _pick("max_output_tokens")
stop = _pick("stop")
for m in turn.input_messages:
msg = m.model_copy()
# We do not want to keep adding RAG context to the input messages
# May be this should be a parameter of the agentic instance
# that can define its behavior in a custom way
if isinstance(msg, UserMessage):
msg.context = None
if isinstance(msg, ToolResponseMessage):
if msg.call_id in tool_call_ids:
# NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
continue
# Ensure we don't share mutable defaults
if isinstance(stop, list):
stop = list(stop)
messages.append(msg)
return {
"temperature": temperature,
"top_p": top_p,
"max_output_tokens": max_output_tokens,
"stop": stop,
}
def turn_to_messages(self, turn: Turn) -> list[OpenAIMessageParam]:
messages: list[OpenAIMessageParam] = []
tool_response_ids = {
response.tool_call_id
for step in turn.steps
if step.step_type == StepType.tool_execution.value
for response in step.tool_responses
}
for message in turn.input_messages:
copied = message.model_copy(deep=True)
if isinstance(copied, OpenAIToolMessageParam) and copied.tool_call_id in tool_response_ids:
# Skip tool responses; they will be reintroduced from the execution step
continue
messages.append(copied)
for step in turn.steps:
if step.step_type == StepType.inference.value:
messages.append(step.model_response)
messages.append(step.model_response.model_copy(deep=True))
elif step.step_type == StepType.tool_execution.value:
for response in step.tool_responses:
messages.append(
ToolResponseMessage(
call_id=response.call_id,
content=response.content,
)
)
elif step.step_type == StepType.shield_call.value:
if step.violation:
# CompletionMessage itself in the ShieldResponse
messages.append(
CompletionMessage(
content=step.violation.user_message,
stop_reason=StopReason.end_of_turn,
)
)
messages.append(response.model_copy(deep=True))
elif step.step_type == StepType.shield_call.value and step.violation:
assistant_msg = OpenAIAssistantMessageParam(
content=str(step.violation.user_message),
)
messages.append(assistant_msg)
return messages
async def create_session(self, name: str) -> str:
return await self.storage.create_session(name)
async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
messages = []
if self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions))
async def get_messages_from_turns(self, turns: list[Turn]) -> list[OpenAIMessageParam]:
messages: list[OpenAIMessageParam] = []
if self.agent_config.instructions:
messages.append(OpenAISystemMessageParam(content=self.agent_config.instructions))
for turn in turns:
messages.extend(self.turn_to_messages(turn))

@@ -228,26 +402,19 @@ class ChatAgent(ShieldRunnerMixin):
if is_resume and len(turns) == 0:
raise ValueError("No turns found for session")
steps = []
messages = await self.get_messages_from_turns(turns)
steps: list[Step] = []
history_openai = await self.get_messages_from_turns(turns)
if turn_id is None:
turn_id = request.turn_id
if is_resume:
tool_response_messages = [
ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
]
messages.extend(tool_response_messages)
tool_response_messages = [resp.model_copy(deep=True) for resp in request.tool_responses]
history_openai.extend(tool_response_messages)
last_turn = turns[-1]
last_turn_messages = self.turn_to_messages(last_turn)
last_turn_messages = [
x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
]
last_turn_messages.extend(tool_response_messages)
steps = list(last_turn.steps)
# get steps from the turn
steps = last_turn.steps
# mark tool execution step as complete
# if there's no tool execution in progress step (due to storage, or tool call parsing on client),
# we'll create a new tool execution step with current time
in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
request.session_id, request.turn_id
)

@@ -256,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin):
step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
turn_id=request.turn_id,
tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
tool_responses=request.tool_responses,
tool_responses=tool_response_messages,
completed_at=now,
started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
)

@@ -270,26 +437,34 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
input_messages = last_turn.input_messages
turn_id = request.turn_id
input_messages_openai = [msg.model_copy(deep=True) for msg in last_turn.input_messages]
start_time = last_turn.started_at
else:
messages.extend(request.messages)
new_messages = [msg.model_copy(deep=True) for msg in request.messages]
history_openai.extend(new_messages)
input_messages_openai = new_messages
start_time = datetime.now(UTC).isoformat()
input_messages = request.messages
output_message = None
generation_options = self._resolve_generation_options(request)
output_completion: CompletionMessage | None = None
output_finish_reason: str | None = None
output_assistant_message: OpenAIAssistantMessageParam | None = None
async for chunk in self.run(
session_id=request.session_id,
turn_id=turn_id,
input_messages=messages,
sampling_params=self.agent_config.sampling_params,
input_messages=history_openai,
stream=request.stream,
documents=request.documents if not is_resume else None,
temperature=generation_options["temperature"],
top_p=generation_options["top_p"],
max_output_tokens=generation_options["max_output_tokens"],
stop=generation_options["stop"],
):
if isinstance(chunk, CompletionMessage):
output_message = chunk
output_completion = chunk
output_assistant_message, output_finish_reason = await _completion_to_openai_assistant(chunk)
continue
assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"

@@ -299,19 +474,21 @@ class ChatAgent(ShieldRunnerMixin):
yield chunk
assert output_message is not None
assert output_completion is not None
assert output_assistant_message is not None
turn = Turn(
turn_id=turn_id,
session_id=request.session_id,
input_messages=input_messages,
output_message=output_message,
input_messages=input_messages_openai,
output_message=output_assistant_message,
finish_reason=output_finish_reason,
started_at=start_time,
completed_at=datetime.now(UTC).isoformat(),
steps=steps,
)
await self.storage.add_turn_to_session(request.session_id, turn)
if output_message.tool_calls:
if output_assistant_message.tool_calls:
chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseTurnAwaitingInputPayload(

@@ -334,10 +511,13 @@ class ChatAgent(ShieldRunnerMixin):
self,
session_id: str,
turn_id: str,
input_messages: list[Message],
sampling_params: SamplingParams,
input_messages: list[OpenAIMessageParam],
stream: bool = False,
documents: list[Document] | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_output_tokens: int | None = None,
stop: list[str] | None = None,
) -> AsyncGenerator:
# Doing async generators makes downstream code much simpler and everything amenable to
# streaming. However, it also makes things complicated here because AsyncGenerators cannot

@@ -357,9 +537,12 @@ class ChatAgent(ShieldRunnerMixin):
session_id,
turn_id,
input_messages,
sampling_params,
stream,
documents,
temperature,
top_p,
max_output_tokens,
stop,
):
if isinstance(res, bool):
return

@@ -370,8 +553,9 @@ class ChatAgent(ShieldRunnerMixin):
yield res
assert final_response is not None
final_assistant, final_finish_reason = await _completion_to_openai_assistant(copy.deepcopy(final_response))
# for output shields run on the full input and output combination
messages = input_messages + [final_response]
messages = input_messages + [final_assistant.model_copy(deep=True)]
if len(self.output_shields) > 0:
async for res in self.run_multiple_shields_wrapper(

@@ -387,7 +571,7 @@ class ChatAgent(ShieldRunnerMixin):
async def run_multiple_shields_wrapper(
self,
turn_id: str,
messages: list[Message],
messages: list[OpenAIMessageParam],
shields: list[str],
touchpoint: str,
) -> AsyncGenerator:

@@ -412,7 +596,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
)
await self.run_multiple_shields(messages, shields)
legacy_messages = [_openai_message_param_to_legacy(m) for m in messages]
await self.run_multiple_shields(legacy_messages, shields)
except SafetyException as e:
yield AgentTurnResponseStreamChunk(

@@ -461,29 +646,25 @@ class ChatAgent(ShieldRunnerMixin):
self,
session_id: str,
turn_id: str,
input_messages: list[Message],
sampling_params: SamplingParams,
input_messages: list[OpenAIMessageParam],
stream: bool = False,
documents: list[Document] | None = None,
temperature: float | None = None,
top_p: float | None = None,
max_output_tokens: int | None = None,
stop: list[str] | None = None,
) -> AsyncGenerator:
# if document is passed in a turn, we parse the raw text of the document
# and sent it as a user message
if documents:
contexts = []
conversation = [msg.model_copy(deep=True) for msg in input_messages]
# if document is passed in a turn, hydrate the last user message with the context
if documents and conversation:
appended_texts = []
for document in documents:
raw_document_text = await get_raw_document_text(document)
contexts.append(raw_document_text)
attached_context = "\n".join(contexts)
if isinstance(input_messages[-1].content, str):
input_messages[-1].content += attached_context
elif isinstance(input_messages[-1].content, list):
input_messages[-1].content.append(TextContentItem(text=attached_context))
else:
input_messages[-1].content = [
input_messages[-1].content,
TextContentItem(text=attached_context),
]
if raw_document_text:
appended_texts.append(raw_document_text)
if appended_texts:
_append_text_to_openai_message(conversation[-1], "\n".join(appended_texts))
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it

@@ -500,9 +681,13 @@ class ChatAgent(ShieldRunnerMixin):
n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0
# Build a map of custom tools to their definitions for faster lookup
client_tools = {}
for tool in self.agent_config.client_tools:
client_tools[tool.name] = tool
client_tools: dict[str, OpenAIResponseInputTool | ToolDef] = {}
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
if isinstance(tool, ToolDef) and tool.name:
client_tools[tool.name] = tool
elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
client_tools[tool.name] = tool
while True:
step_id = str(uuid.uuid4())
inference_start_time = datetime.now(UTC).isoformat()

@@ -520,81 +705,33 @@ class ChatAgent(ShieldRunnerMixin):
stop_reason: StopReason | None = None
async with tracing.span("inference") as span:
if self.telemetry_enabled and span is not None:
if self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
if self.telemetry_enabled and span is not None and self.agent_config.name:
span.set_attribute("agent_name", self.agent_config.name)
def _serialize_nested(value):
"""Recursively serialize nested Pydantic models to dicts."""
from pydantic import BaseModel
if isinstance(value, BaseModel):
return value.model_dump(mode="json")
elif isinstance(value, dict):
return {k: _serialize_nested(v) for k, v in value.items()}
elif isinstance(value, list):
return [_serialize_nested(item) for item in value]
else:
return value
def _add_type(openai_msg: dict) -> OpenAIMessageParam:
# Serialize any nested Pydantic models to plain dicts
openai_msg = _serialize_nested(openai_msg)
role = openai_msg.get("role")
if role == "user":
return OpenAIUserMessageParam(**openai_msg)
elif role == "system":
return OpenAISystemMessageParam(**openai_msg)
elif role == "assistant":
return OpenAIAssistantMessageParam(**openai_msg)
elif role == "tool":
return OpenAIToolMessageParam(**openai_msg)
elif role == "developer":
return OpenAIDeveloperMessageParam(**openai_msg)
else:
raise ValueError(f"Unknown message role: {role}")
# Convert messages to OpenAI format
openai_messages: list[OpenAIMessageParam] = [
_add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
]
# Convert tool definitions to OpenAI format
openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]
# Extract tool_choice from tool_config for OpenAI compatibility
# Note: tool_choice can only be provided when tools are also provided
tool_choice = None
if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
tc = self.agent_config.tool_config.tool_choice
tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
# Convert tool_choice to OpenAI format
if tool_choice_str in ("auto", "none", "required"):
tool_choice = tool_choice_str
else:
# It's a specific tool name, wrap it in the proper format
tool_choice = {"type": "function", "function": {"name": tool_choice_str}}
# Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
max_tokens = getattr(sampling_params, "max_tokens", None)
# Use OpenAI chat completion
openai_stream = await self.inference_api.openai_chat_completion(
model=self.agent_config.model,
messages=openai_messages,
messages=[msg.model_copy(deep=True) for msg in conversation],
tools=openai_tools if openai_tools else None,
tool_choice=tool_choice,
response_format=self.agent_config.response_format,
temperature=temperature,
top_p=top_p,
max_tokens=max_tokens,
max_tokens=max_output_tokens,
stop=stop,
stream=True,
)
# Convert OpenAI stream back to Llama Stack format
response_stream = convert_openai_chat_completion_stream(
openai_stream, enable_incremental_tool_calls=True
)

@@ -644,7 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
json.dumps([json.loads(m.model_copy(deep=True).model_dump_json()) for m in conversation]),
)
output_attr = json.dumps(
{

@@ -671,6 +808,8 @@ class ChatAgent(ShieldRunnerMixin):
tool_calls=tool_calls,
)
assistant_param, finish_reason = await _completion_to_openai_assistant(copy.deepcopy(message))
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(

@@ -682,7 +821,8 @@ class ChatAgent(ShieldRunnerMixin):
# `deepcopy` for now, but this is symptomatic of a deeper issue.
step_id=step_id,
turn_id=turn_id,
model_response=copy.deepcopy(message),
model_response=assistant_param,
finish_reason=finish_reason,
started_at=inference_start_time,
completed_at=datetime.now(UTC).isoformat(),
),

@@ -703,9 +843,10 @@ class ChatAgent(ShieldRunnerMixin):
yield message
break
assistant_param = assistant_param.model_copy(deep=True)
if len(message.tool_calls) == 0:
if stop_reason == StopReason.end_of_turn:
# TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
if len(output_attachments) > 0:
if isinstance(message.content, list):
message.content += output_attachments

@@ -714,18 +855,20 @@ class ChatAgent(ShieldRunnerMixin):
yield message
else:
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
conversation.append(assistant_param)
else:
input_messages = input_messages + [message]
conversation.append(assistant_param)
# Process tool calls in the message
client_tool_calls = []
non_client_tool_calls = []
client_tool_calls_openai = []
# Separate client and non-client tool calls
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
client_tool_calls_openai.append(_legacy_tool_call_to_openai(tool_call))
else:
non_client_tool_calls.append(tool_call)

@@ -781,18 +924,14 @@ class ChatAgent(ShieldRunnerMixin):
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
openai_tool_call = _legacy_tool_call_to_openai(tool_call)
openai_tool_response = _tool_response_message_to_openai(result_message)
tool_execution_step = ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
metadata=tool_result.metadata,
)
],
tool_calls=[openai_tool_call],
tool_responses=[openai_tool_response],
started_at=tool_execution_start_time,
completed_at=datetime.now(UTC).isoformat(),
)

@@ -808,8 +947,8 @@ class ChatAgent(ShieldRunnerMixin):
)
)
# Add the result message to input_messages for the next iteration
input_messages.append(result_message)
# Add the result message to conversation for the next iteration
conversation.append(openai_tool_response)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially

@@ -829,7 +968,7 @@ class ChatAgent(ShieldRunnerMixin):
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=client_tool_calls,
tool_calls=client_tool_calls_openai,
tool_responses=[],
started_at=datetime.now(UTC).isoformat(),
),

@@ -866,19 +1005,15 @@ class ChatAgent(ShieldRunnerMixin):
toolgroup_to_args = toolgroup_to_args or {}
tool_name_to_def = {}
tool_name_to_args = {}
tool_name_to_def: dict[str | BuiltinTool, ToolDefinition] = {}
tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}
for tool_def in self.agent_config.client_tools:
if tool_name_to_def.get(tool_def.name, None):
raise ValueError(f"Tool {tool_def.name} already exists")
# Use input_schema from ToolDef directly
tool_name_to_def[tool_def.name] = ToolDefinition(
tool_name=tool_def.name,
description=tool_def.description,
input_schema=tool_def.input_schema,
)
if self.agent_config.client_tools:
for tool in self.agent_config.client_tools:
tool_definition = _client_tool_to_tool_definition(tool)
if tool_name_to_def.get(tool_definition.tool_name):
raise ValueError(f"Tool {tool_definition.tool_name} already exists")
tool_name_to_def[tool_definition.tool_name] = tool_definition
for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)

@@ -999,12 +1134,7 @@ async def get_raw_document_text(document: Document) -> str:
if isinstance(document.content, URL):
return await load_data_from_url(document.content.uri)
elif isinstance(document.content, str):
return document.content
elif isinstance(document.content, TextContentItem):
return document.content.text
else:
raise ValueError(f"Unexpected document content type: {type(document.content)}")
return _openai_message_content_to_text(document.content)
def _interpret_content_as_attachment(

@@ -1015,7 +1145,7 @@ def _interpret_content_as_attachment(
snippet = match.group(1)
data = json.loads(snippet)
return Attachment(
url=URL(uri="file://" + data["filepath"]),
content=URL(uri="file://" + data["filepath"]),
mime_type=data["mimetype"],
)
@@ -33,9 +33,8 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
OpenAIMessageParam,
OpenAIToolMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime

@@ -156,7 +155,7 @@ class MetaReferenceAgentsImpl(Agents):
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
messages: list[OpenAIMessageParam],
toolgroups: list[AgentToolGroup] | None = None,
documents: list[Document] | None = None,
stream: bool | None = False,

@@ -189,7 +188,7 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
tool_responses: list[OpenAIToolMessageParam],
stream: bool | None = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
@@ -62,14 +62,9 @@ def agent_config(llama_stack_client, text_model_id):
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 0.0001,
"top_p": 0.9,
},
"max_tokens": 512,
},
temperature=0.0001,
top_p=0.9,
max_output_tokens=512,
tools=[],
input_shields=available_shields,
output_shields=available_shields,

@@ -83,14 +78,9 @@ def agent_config_without_safety(text_model_id):
agent_config = dict(
model=text_model_id,
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 0.0001,
"top_p": 0.9,
},
"max_tokens": 512,
},
temperature=0.0001,
top_p=0.9,
max_output_tokens=512,
tools=[],
enable_session_persistence=False,
)

@@ -194,14 +184,9 @@ def test_tool_config(agent_config):
common_params = dict(
model="meta-llama/Llama-3.2-3B-Instruct",
instructions="You are a helpful assistant",
sampling_params={
"strategy": {
"type": "top_p",
"temperature": 1.0,
"top_p": 0.9,
},
"max_tokens": 512,
},
temperature=1.0,
top_p=0.9,
max_output_tokens=512,
toolgroups=[],
enable_session_persistence=False,
)

@@ -212,40 +197,25 @@ def test_tool_config(agent_config):
agent_config = AgentConfig(
**common_params,
tool_choice="auto",
tool_config=ToolConfig(tool_choice="auto"),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig(
**common_params,
tool_choice="auto",
tool_config=ToolConfig(
tool_choice="auto",
),
tool_config=ToolConfig(tool_choice="auto"),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.auto
agent_config = AgentConfig(
**common_params,
tool_config=ToolConfig(
tool_choice="required",
),
tool_config=ToolConfig(tool_choice="required"),
)
server_config = Server__AgentConfig(**agent_config)
assert server_config.tool_config.tool_choice == ToolChoice.required
agent_config = AgentConfig(
**common_params,
tool_choice="required",
tool_config=ToolConfig(
tool_choice="auto",
),
)
with pytest.raises(ValueError, match="tool_choice is deprecated"):
Server__AgentConfig(**agent_config)
def test_builtin_tool_web_search(llama_stack_client, agent_config):
agent_config = {
@@ -7,7 +7,7 @@
import pytest
from llama_stack.apis.agents import AgentConfig, Turn
from llama_stack.apis.inference import SamplingParams, UserMessage
from llama_stack.apis.inference import OpenAIUserMessageParam
from llama_stack.providers.datatypes import Api
from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig

@@ -16,7 +16,7 @@ from llama_stack.providers.utils.kvstore.config import SqliteKVStoreConfig
@pytest.fixture
def sample_messages():
return [
UserMessage(content="What's the weather like today?"),
OpenAIUserMessageParam(content="What's the weather like today?"),
]

@@ -36,7 +36,9 @@ def common_params(inference_model):
model=inference_model,
instructions="You are a helpful assistant.",
enable_session_persistence=True,
sampling_params=SamplingParams(temperature=0.7, top_p=0.95),
temperature=0.7,
top_p=0.95,
max_output_tokens=256,
input_shields=[],
output_shields=[],
tools=[],
@@ -69,30 +69,26 @@ async def agents_impl(config, mock_apis):
@pytest.fixture
def sample_agent_config():
return AgentConfig(
sampling_params={
"strategy": {"type": "greedy"},
"max_tokens": 0,
"repetition_penalty": 1.0,
},
temperature=0.0,
top_p=1.0,
max_output_tokens=0,
input_shields=["string"],
output_shields=["string"],
toolgroups=["mcp::my_mcp_server"],
client_tools=[
{
"type": "function",
"name": "client_tool",
"description": "Client Tool",
"parameters": [
{
"name": "string",
"parameter_type": "string",
"description": "string",
"required": True,
"default": None,
}
],
"metadata": {
"property1": None,
"property2": None,
"parameters": {
"type": "object",
"properties": {
"string": {
"type": "string",
"description": "string",
}
},
"required": ["string"],
},
}
],