Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-22 16:23:08 +00:00)
feat(agents)!: changing agents API signatures to use OpenAI types
Replace legacy Message/SamplingParams usage with OpenAI chat message structures across agents: schemas, meta-reference implementation, and tests now rely on OpenAI message/tool payloads and generation knobs.
parent 548ccff368
commit c56b2deb7d
6 changed files with 392 additions and 305 deletions
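A minimal caller-side sketch (not part of this commit) of the new payload shapes the agents API expects after this change. The message classes are the ones imported from llama_stack.apis.inference in the diff below; the keyword arguments mirror OpenAI chat message fields and are assumptions beyond what the diff itself shows.

    from llama_stack.apis.inference import (
        OpenAISystemMessageParam,
        OpenAIToolMessageParam,
        OpenAIUserMessageParam,
    )

    # Creating a turn now takes OpenAI-style messages instead of legacy
    # UserMessage / ToolResponseMessage objects.
    messages = [
        OpenAISystemMessageParam(content="You are a helpful assistant."),
        OpenAIUserMessageParam(content="Summarize the attached document."),
    ]

    # Resuming a turn now supplies OpenAI tool messages instead of ToolResponse objects.
    tool_responses = [
        OpenAIToolMessageParam(tool_call_id="call_123", content='{"result": "ok"}'),
    ]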
@@ -10,12 +10,14 @@ import re
import uuid
import warnings
from collections.abc import AsyncGenerator
from typing import Any
from datetime import UTC, datetime

import httpx

from llama_stack.apis.agents import (
    AgentConfig,
    OpenAIResponseInputTool,
    AgentToolGroup,
    AgentToolGroupWithArgs,
    AgentTurnCreateRequest,
@@ -32,16 +34,12 @@ from llama_stack.apis.agents import (
    Document,
    InferenceStep,
    ShieldCallStep,
    Step,
    StepType,
    ToolExecutionStep,
    Turn,
)
from llama_stack.apis.common.content_types import (
    URL,
    TextContentItem,
    ToolCallDelta,
    ToolCallParseStatus,
)
from llama_stack.apis.common.content_types import URL, ToolCallDelta, ToolCallParseStatus
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.apis.inference import (
    ChatCompletionResponseEventType,
@@ -50,20 +48,24 @@ from llama_stack.apis.inference import (
    Message,
    OpenAIAssistantMessageParam,
    OpenAIDeveloperMessageParam,
    OpenAIChatCompletionContentPartImageParam,
    OpenAIChatCompletionContentPartTextParam,
    OpenAIChatCompletionMessageContent,
    OpenAIChatCompletionToolCall,
    OpenAIChatCompletionToolCallFunction,
    OpenAIImageURL,
    OpenAIMessageParam,
    OpenAISystemMessageParam,
    OpenAIToolMessageParam,
    OpenAIUserMessageParam,
    SamplingParams,
    StopReason,
    SystemMessage,
    ToolDefinition,
    ToolResponse,
    ToolResponseMessage,
    UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.tools import ToolDef, ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
@@ -90,6 +92,167 @@ RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents::meta_reference")


def _map_finish_reason_to_stop_reason(finish_reason: str | None) -> StopReason:
    if finish_reason == "length":
        return StopReason.out_of_tokens
    if finish_reason == "tool_calls":
        return StopReason.end_of_message
    # Default to end_of_turn for unknown or "stop"
    return StopReason.end_of_turn


def _map_stop_reason_to_finish_reason(stop_reason: StopReason | None) -> str | None:
    if stop_reason == StopReason.out_of_tokens:
        return "length"
    if stop_reason == StopReason.end_of_message:
        return "tool_calls"
    if stop_reason == StopReason.end_of_turn:
        return "stop"
    return None


def _openai_tool_call_to_legacy(tool_call: OpenAIChatCompletionToolCall) -> ToolCall:
    name = None
    if tool_call.function and tool_call.function.name:
        name = tool_call.function.name
    return ToolCall(
        call_id=tool_call.id or f"call_{uuid.uuid4()}",
        tool_name=name or "",
        arguments=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "{}",
    )


def _legacy_tool_call_to_openai(tool_call: ToolCall, index: int | None = None) -> OpenAIChatCompletionToolCall:
    function_name = tool_call.tool_name if not isinstance(tool_call.tool_name, BuiltinTool) else tool_call.tool_name.value
    return OpenAIChatCompletionToolCall(
        index=index,
        id=tool_call.call_id,
        function=OpenAIChatCompletionToolCallFunction(
            name=function_name,
            arguments=tool_call.arguments,
        ),
    )


def _tool_response_message_to_openai(response: ToolResponseMessage) -> OpenAIToolMessageParam:
    content = _coerce_to_text(response.content)
    return OpenAIToolMessageParam(
        tool_call_id=response.call_id,
        content=content,
    )


def _openai_message_content_to_text(
    content: OpenAIChatCompletionMessageContent,
) -> str:
    if isinstance(content, str):
        return content
    parts = []
    for item in content:
        if isinstance(item, OpenAIChatCompletionContentPartTextParam):
            parts.append(item.text)
        elif isinstance(item, OpenAIChatCompletionContentPartImageParam) and item.image_url:
            if item.image_url.url:
                parts.append(item.image_url.url)
    return "\n".join(parts)


def _append_text_to_openai_message(message: OpenAIMessageParam, text: str) -> None:
    if not text:
        return
    if isinstance(message, OpenAIUserMessageParam):
        content = message.content
        if content is None or content == "":
            message.content = text
        elif isinstance(content, str):
            message.content = f"{content}\n{text}"
        else:
            content.append(OpenAIChatCompletionContentPartTextParam(text=text))


def _coerce_to_text(content: Any) -> str:
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        return "\n".join(_coerce_to_text(item) for item in content)
    if hasattr(content, "text"):
        return getattr(content, "text")
    if hasattr(content, "image"):
        image = getattr(content, "image")
        if hasattr(image, "url") and image.url:
            return getattr(image.url, "uri", "")
    return str(content)


def _openai_message_param_to_legacy(message: OpenAIMessageParam) -> Message:
    if isinstance(message, OpenAIUserMessageParam):
        return UserMessage(content=_openai_message_content_to_text(message.content))
    if isinstance(message, OpenAISystemMessageParam):
        return SystemMessage(content=_openai_message_content_to_text(message.content))
    if isinstance(message, OpenAIToolMessageParam):
        return ToolResponseMessage(
            call_id=message.tool_call_id,
            content=_openai_message_content_to_text(message.content),
        )
    if isinstance(message, OpenAIDeveloperMessageParam):
        # Map developer messages to user role for legacy compatibility
        return UserMessage(content=_openai_message_content_to_text(message.content))
    if isinstance(message, OpenAIAssistantMessageParam):
        tool_calls = [
            _openai_tool_call_to_legacy(tool_call)
            for tool_call in message.tool_calls or []
        ]
        return CompletionMessage(
            content=_openai_message_content_to_text(message.content) if message.content is not None else "",
            stop_reason=StopReason.end_of_turn,
            tool_calls=tool_calls,
        )
    raise ValueError(f"Unsupported OpenAI message type: {type(message)}")


async def _legacy_message_to_openai(message: Message) -> OpenAIMessageParam:
    openai_dict = await convert_message_to_openai_dict_new(message)
    role = openai_dict.get("role")
    if role == "user":
        return OpenAIUserMessageParam(**openai_dict)
    if role == "system":
        return OpenAISystemMessageParam(**openai_dict)
    if role == "assistant":
        return OpenAIAssistantMessageParam(**openai_dict)
    if role == "tool":
        return OpenAIToolMessageParam(**openai_dict)
    if role == "developer":
        return OpenAIDeveloperMessageParam(**openai_dict)
    raise ValueError(f"Unsupported OpenAI message role: {role}")


async def _completion_to_openai_assistant(
    completion: CompletionMessage,
) -> tuple[OpenAIAssistantMessageParam, str | None]:
    assistant_param = await _legacy_message_to_openai(completion)
    assert isinstance(assistant_param, OpenAIAssistantMessageParam)
    finish_reason = _map_stop_reason_to_finish_reason(completion.stop_reason)
    return assistant_param, finish_reason


def _client_tool_to_tool_definition(tool: OpenAIResponseInputTool | ToolDef) -> ToolDefinition:
    if isinstance(tool, ToolDef):
        return ToolDefinition(
            tool_name=tool.name,
            description=tool.description,
            input_schema=tool.input_schema,
        )
    if getattr(tool, "type", None) == "function":
        return ToolDefinition(
            tool_name=tool.name,
            description=getattr(tool, "description", None),
            input_schema=getattr(tool, "parameters", None),
        )
    raise ValueError(f"Unsupported client tool type '{getattr(tool, 'type', None)}' for agent configuration")


class ChatAgent(ShieldRunnerMixin):
    def __init__(
        self,
@@ -123,59 +286,70 @@ class ChatAgent(ShieldRunnerMixin):
            output_shields=agent_config.output_shields,
        )

    def turn_to_messages(self, turn: Turn) -> list[Message]:
        messages = []
    def _resolve_generation_options(
        self,
        request: AgentTurnCreateRequest | AgentTurnResumeRequest,
    ) -> dict[str, Any]:
        def _pick(attr: str) -> Any:
            value = getattr(request, attr, None)
            if value is not None:
                return value
            return getattr(self.agent_config, attr)

        # NOTE: if a toolcall response is in a step, we do not add it when processing the input messages
        tool_call_ids = set()
        for step in turn.steps:
            if step.step_type == StepType.tool_execution.value:
                for response in step.tool_responses:
                    tool_call_ids.add(response.call_id)
        temperature = _pick("temperature")
        top_p = _pick("top_p")
        max_output_tokens = _pick("max_output_tokens")
        stop = _pick("stop")

        for m in turn.input_messages:
            msg = m.model_copy()
            # We do not want to keep adding RAG context to the input messages
            # May be this should be a parameter of the agentic instance
            # that can define its behavior in a custom way
            if isinstance(msg, UserMessage):
                msg.context = None
            if isinstance(msg, ToolResponseMessage):
                if msg.call_id in tool_call_ids:
                    # NOTE: do not add ToolResponseMessage here, we'll add them in tool_execution steps
                    continue
        # Ensure we don't share mutable defaults
        if isinstance(stop, list):
            stop = list(stop)

            messages.append(msg)
        return {
            "temperature": temperature,
            "top_p": top_p,
            "max_output_tokens": max_output_tokens,
            "stop": stop,
        }

    def turn_to_messages(self, turn: Turn) -> list[OpenAIMessageParam]:
        messages: list[OpenAIMessageParam] = []

        tool_response_ids = {
            response.tool_call_id
            for step in turn.steps
            if step.step_type == StepType.tool_execution.value
            for response in step.tool_responses
        }

        for message in turn.input_messages:
            copied = message.model_copy(deep=True)
            if isinstance(copied, OpenAIToolMessageParam) and copied.tool_call_id in tool_response_ids:
                # Skip tool responses; they will be reintroduced from the execution step
                continue
            messages.append(copied)

        for step in turn.steps:
            if step.step_type == StepType.inference.value:
                messages.append(step.model_response)
                messages.append(step.model_response.model_copy(deep=True))
            elif step.step_type == StepType.tool_execution.value:
                for response in step.tool_responses:
                    messages.append(
                        ToolResponseMessage(
                            call_id=response.call_id,
                            content=response.content,
                        )
                    )
            elif step.step_type == StepType.shield_call.value:
                if step.violation:
                    # CompletionMessage itself in the ShieldResponse
                    messages.append(
                        CompletionMessage(
                            content=step.violation.user_message,
                            stop_reason=StopReason.end_of_turn,
                        )
                    )
                    messages.append(response.model_copy(deep=True))
            elif step.step_type == StepType.shield_call.value and step.violation:
                assistant_msg = OpenAIAssistantMessageParam(
                    content=str(step.violation.user_message),
                )
                messages.append(assistant_msg)

        return messages

    async def create_session(self, name: str) -> str:
        return await self.storage.create_session(name)

    async def get_messages_from_turns(self, turns: list[Turn]) -> list[Message]:
        messages = []
        if self.agent_config.instructions != "":
            messages.append(SystemMessage(content=self.agent_config.instructions))
    async def get_messages_from_turns(self, turns: list[Turn]) -> list[OpenAIMessageParam]:
        messages: list[OpenAIMessageParam] = []
        if self.agent_config.instructions:
            messages.append(OpenAISystemMessageParam(content=self.agent_config.instructions))

        for turn in turns:
            messages.extend(self.turn_to_messages(turn))
@@ -228,26 +402,19 @@ class ChatAgent(ShieldRunnerMixin):
        if is_resume and len(turns) == 0:
            raise ValueError("No turns found for session")

        steps = []
        messages = await self.get_messages_from_turns(turns)
        steps: list[Step] = []
        history_openai = await self.get_messages_from_turns(turns)

        if turn_id is None:
            turn_id = request.turn_id

        if is_resume:
            tool_response_messages = [
                ToolResponseMessage(call_id=x.call_id, content=x.content) for x in request.tool_responses
            ]
            messages.extend(tool_response_messages)
            tool_response_messages = [resp.model_copy(deep=True) for resp in request.tool_responses]
            history_openai.extend(tool_response_messages)

            last_turn = turns[-1]
            last_turn_messages = self.turn_to_messages(last_turn)
            last_turn_messages = [
                x for x in last_turn_messages if isinstance(x, UserMessage) or isinstance(x, ToolResponseMessage)
            ]
            last_turn_messages.extend(tool_response_messages)
            steps = list(last_turn.steps)

            # get steps from the turn
            steps = last_turn.steps

            # mark tool execution step as complete
            # if there's no tool execution in progress step (due to storage, or tool call parsing on client),
            # we'll create a new tool execution step with current time
            in_progress_tool_call_step = await self.storage.get_in_progress_tool_call_step(
                request.session_id, request.turn_id
            )
@@ -256,7 +423,7 @@ class ChatAgent(ShieldRunnerMixin):
                step_id=(in_progress_tool_call_step.step_id if in_progress_tool_call_step else str(uuid.uuid4())),
                turn_id=request.turn_id,
                tool_calls=(in_progress_tool_call_step.tool_calls if in_progress_tool_call_step else []),
                tool_responses=request.tool_responses,
                tool_responses=tool_response_messages,
                completed_at=now,
                started_at=(in_progress_tool_call_step.started_at if in_progress_tool_call_step else now),
            )
@@ -270,26 +437,34 @@ class ChatAgent(ShieldRunnerMixin):
                    )
                )
            )
            input_messages = last_turn.input_messages

            turn_id = request.turn_id
            input_messages_openai = [msg.model_copy(deep=True) for msg in last_turn.input_messages]
            start_time = last_turn.started_at
        else:
            messages.extend(request.messages)
            new_messages = [msg.model_copy(deep=True) for msg in request.messages]
            history_openai.extend(new_messages)
            input_messages_openai = new_messages
            start_time = datetime.now(UTC).isoformat()
            input_messages = request.messages

        output_message = None
        generation_options = self._resolve_generation_options(request)

        output_completion: CompletionMessage | None = None
        output_finish_reason: str | None = None
        output_assistant_message: OpenAIAssistantMessageParam | None = None
        async for chunk in self.run(
            session_id=request.session_id,
            turn_id=turn_id,
            input_messages=messages,
            sampling_params=self.agent_config.sampling_params,
            input_messages=history_openai,
            stream=request.stream,
            documents=request.documents if not is_resume else None,
            temperature=generation_options["temperature"],
            top_p=generation_options["top_p"],
            max_output_tokens=generation_options["max_output_tokens"],
            stop=generation_options["stop"],
        ):
            if isinstance(chunk, CompletionMessage):
                output_message = chunk
                output_completion = chunk
                output_assistant_message, output_finish_reason = await _completion_to_openai_assistant(chunk)
                continue

            assert isinstance(chunk, AgentTurnResponseStreamChunk), f"Unexpected type {type(chunk)}"
@@ -299,19 +474,21 @@ class ChatAgent(ShieldRunnerMixin):

            yield chunk

        assert output_message is not None
        assert output_completion is not None
        assert output_assistant_message is not None

        turn = Turn(
            turn_id=turn_id,
            session_id=request.session_id,
            input_messages=input_messages,
            output_message=output_message,
            input_messages=input_messages_openai,
            output_message=output_assistant_message,
            finish_reason=output_finish_reason,
            started_at=start_time,
            completed_at=datetime.now(UTC).isoformat(),
            steps=steps,
        )
        await self.storage.add_turn_to_session(request.session_id, turn)
        if output_message.tool_calls:
        if output_assistant_message.tool_calls:
            chunk = AgentTurnResponseStreamChunk(
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseTurnAwaitingInputPayload(
@@ -334,10 +511,13 @@ class ChatAgent(ShieldRunnerMixin):
        self,
        session_id: str,
        turn_id: str,
        input_messages: list[Message],
        sampling_params: SamplingParams,
        input_messages: list[OpenAIMessageParam],
        stream: bool = False,
        documents: list[Document] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        max_output_tokens: int | None = None,
        stop: list[str] | None = None,
    ) -> AsyncGenerator:
        # Doing async generators makes downstream code much simpler and everything amenable to
        # streaming. However, it also makes things complicated here because AsyncGenerators cannot
@@ -357,9 +537,12 @@ class ChatAgent(ShieldRunnerMixin):
            session_id,
            turn_id,
            input_messages,
            sampling_params,
            stream,
            documents,
            temperature,
            top_p,
            max_output_tokens,
            stop,
        ):
            if isinstance(res, bool):
                return
@@ -370,8 +553,9 @@ class ChatAgent(ShieldRunnerMixin):
                yield res

        assert final_response is not None
        final_assistant, final_finish_reason = await _completion_to_openai_assistant(copy.deepcopy(final_response))
        # for output shields run on the full input and output combination
        messages = input_messages + [final_response]
        messages = input_messages + [final_assistant.model_copy(deep=True)]

        if len(self.output_shields) > 0:
            async for res in self.run_multiple_shields_wrapper(
@@ -387,7 +571,7 @@ class ChatAgent(ShieldRunnerMixin):
    async def run_multiple_shields_wrapper(
        self,
        turn_id: str,
        messages: list[Message],
        messages: list[OpenAIMessageParam],
        shields: list[str],
        touchpoint: str,
    ) -> AsyncGenerator:
@@ -412,7 +596,8 @@ class ChatAgent(ShieldRunnerMixin):
                    )
                )
            )
            await self.run_multiple_shields(messages, shields)
            legacy_messages = [_openai_message_param_to_legacy(m) for m in messages]
            await self.run_multiple_shields(legacy_messages, shields)

        except SafetyException as e:
            yield AgentTurnResponseStreamChunk(
@@ -461,29 +646,25 @@ class ChatAgent(ShieldRunnerMixin):
        self,
        session_id: str,
        turn_id: str,
        input_messages: list[Message],
        sampling_params: SamplingParams,
        input_messages: list[OpenAIMessageParam],
        stream: bool = False,
        documents: list[Document] | None = None,
        temperature: float | None = None,
        top_p: float | None = None,
        max_output_tokens: int | None = None,
        stop: list[str] | None = None,
    ) -> AsyncGenerator:
        # if document is passed in a turn, we parse the raw text of the document
        # and sent it as a user message
        if documents:
            contexts = []
        conversation = [msg.model_copy(deep=True) for msg in input_messages]

        # if document is passed in a turn, hydrate the last user message with the context
        if documents and conversation:
            appended_texts = []
            for document in documents:
                raw_document_text = await get_raw_document_text(document)
                contexts.append(raw_document_text)

            attached_context = "\n".join(contexts)
            if isinstance(input_messages[-1].content, str):
                input_messages[-1].content += attached_context
            elif isinstance(input_messages[-1].content, list):
                input_messages[-1].content.append(TextContentItem(text=attached_context))
            else:
                input_messages[-1].content = [
                    input_messages[-1].content,
                    TextContentItem(text=attached_context),
                ]
                if raw_document_text:
                    appended_texts.append(raw_document_text)
            if appended_texts:
                _append_text_to_openai_message(conversation[-1], "\n".join(appended_texts))

        session_info = await self.storage.get_session_info(session_id)
        # if the session has a memory bank id, let the memory tool use it
@@ -500,9 +681,13 @@ class ChatAgent(ShieldRunnerMixin):
        n_iter = await self.storage.get_num_infer_iters_in_turn(session_id, turn_id) or 0

        # Build a map of custom tools to their definitions for faster lookup
        client_tools = {}
        for tool in self.agent_config.client_tools:
            client_tools[tool.name] = tool
        client_tools: dict[str, OpenAIResponseInputTool | ToolDef] = {}
        if self.agent_config.client_tools:
            for tool in self.agent_config.client_tools:
                if isinstance(tool, ToolDef) and tool.name:
                    client_tools[tool.name] = tool
                elif getattr(tool, "type", None) == "function" and getattr(tool, "name", None):
                    client_tools[tool.name] = tool
        while True:
            step_id = str(uuid.uuid4())
            inference_start_time = datetime.now(UTC).isoformat()
@@ -520,81 +705,33 @@ class ChatAgent(ShieldRunnerMixin):
            stop_reason: StopReason | None = None

            async with tracing.span("inference") as span:
                if self.telemetry_enabled and span is not None:
                    if self.agent_config.name:
                        span.set_attribute("agent_name", self.agent_config.name)
                if self.telemetry_enabled and span is not None and self.agent_config.name:
                    span.set_attribute("agent_name", self.agent_config.name)

                def _serialize_nested(value):
                    """Recursively serialize nested Pydantic models to dicts."""
                    from pydantic import BaseModel

                    if isinstance(value, BaseModel):
                        return value.model_dump(mode="json")
                    elif isinstance(value, dict):
                        return {k: _serialize_nested(v) for k, v in value.items()}
                    elif isinstance(value, list):
                        return [_serialize_nested(item) for item in value]
                    else:
                        return value

                def _add_type(openai_msg: dict) -> OpenAIMessageParam:
                    # Serialize any nested Pydantic models to plain dicts
                    openai_msg = _serialize_nested(openai_msg)

                    role = openai_msg.get("role")
                    if role == "user":
                        return OpenAIUserMessageParam(**openai_msg)
                    elif role == "system":
                        return OpenAISystemMessageParam(**openai_msg)
                    elif role == "assistant":
                        return OpenAIAssistantMessageParam(**openai_msg)
                    elif role == "tool":
                        return OpenAIToolMessageParam(**openai_msg)
                    elif role == "developer":
                        return OpenAIDeveloperMessageParam(**openai_msg)
                    else:
                        raise ValueError(f"Unknown message role: {role}")

                # Convert messages to OpenAI format
                openai_messages: list[OpenAIMessageParam] = [
                    _add_type(await convert_message_to_openai_dict_new(message)) for message in input_messages
                ]

                # Convert tool definitions to OpenAI format
                openai_tools = [convert_tooldef_to_openai_tool(x) for x in (self.tool_defs or [])]

                # Extract tool_choice from tool_config for OpenAI compatibility
                # Note: tool_choice can only be provided when tools are also provided
                tool_choice = None
                if openai_tools and self.agent_config.tool_config and self.agent_config.tool_config.tool_choice:
                    tc = self.agent_config.tool_config.tool_choice
                    tool_choice_str = tc.value if hasattr(tc, "value") else str(tc)
                    # Convert tool_choice to OpenAI format
                    if tool_choice_str in ("auto", "none", "required"):
                        tool_choice = tool_choice_str
                    else:
                        # It's a specific tool name, wrap it in the proper format
                        tool_choice = {"type": "function", "function": {"name": tool_choice_str}}

                # Convert sampling params to OpenAI format (temperature, top_p, max_tokens)
                temperature = getattr(getattr(sampling_params, "strategy", None), "temperature", None)
                top_p = getattr(getattr(sampling_params, "strategy", None), "top_p", None)
                max_tokens = getattr(sampling_params, "max_tokens", None)

                # Use OpenAI chat completion
                openai_stream = await self.inference_api.openai_chat_completion(
                    model=self.agent_config.model,
                    messages=openai_messages,
                    messages=[msg.model_copy(deep=True) for msg in conversation],
                    tools=openai_tools if openai_tools else None,
                    tool_choice=tool_choice,
                    response_format=self.agent_config.response_format,
                    temperature=temperature,
                    top_p=top_p,
                    max_tokens=max_tokens,
                    max_tokens=max_output_tokens,
                    stop=stop,
                    stream=True,
                )

                # Convert OpenAI stream back to Llama Stack format
                response_stream = convert_openai_chat_completion_stream(
                    openai_stream, enable_incremental_tool_calls=True
                )
@@ -644,7 +781,7 @@ class ChatAgent(ShieldRunnerMixin):
                    span.set_attribute("stop_reason", stop_reason or StopReason.end_of_turn)
                    span.set_attribute(
                        "input",
                        json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
                        json.dumps([json.loads(m.model_copy(deep=True).model_dump_json()) for m in conversation]),
                    )
                    output_attr = json.dumps(
                        {
@@ -671,6 +808,8 @@ class ChatAgent(ShieldRunnerMixin):
                tool_calls=tool_calls,
            )

            assistant_param, finish_reason = await _completion_to_openai_assistant(copy.deepcopy(message))

            yield AgentTurnResponseStreamChunk(
                event=AgentTurnResponseEvent(
                    payload=AgentTurnResponseStepCompletePayload(
@@ -682,7 +821,8 @@ class ChatAgent(ShieldRunnerMixin):
                        # `deepcopy` for now, but this is symptomatic of a deeper issue.
                        step_id=step_id,
                        turn_id=turn_id,
                        model_response=copy.deepcopy(message),
                        model_response=assistant_param,
                        finish_reason=finish_reason,
                        started_at=inference_start_time,
                        completed_at=datetime.now(UTC).isoformat(),
                    ),
@@ -703,9 +843,10 @@ class ChatAgent(ShieldRunnerMixin):
                    yield message
                break

            assistant_param = assistant_param.model_copy(deep=True)

            if len(message.tool_calls) == 0:
                if stop_reason == StopReason.end_of_turn:
                    # TODO: UPDATE RETURN TYPE TO SEND A TUPLE OF (MESSAGE, ATTACHMENTS)
                    if len(output_attachments) > 0:
                        if isinstance(message.content, list):
                            message.content += output_attachments
@@ -714,18 +855,20 @@ class ChatAgent(ShieldRunnerMixin):
                    yield message
                else:
                    logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
                    input_messages = input_messages + [message]
                    conversation.append(assistant_param)
            else:
                input_messages = input_messages + [message]
                conversation.append(assistant_param)

                # Process tool calls in the message
                client_tool_calls = []
                non_client_tool_calls = []
                client_tool_calls_openai = []

                # Separate client and non-client tool calls
                for tool_call in message.tool_calls:
                    if tool_call.tool_name in client_tools:
                        client_tool_calls.append(tool_call)
                        client_tool_calls_openai.append(_legacy_tool_call_to_openai(tool_call))
                    else:
                        non_client_tool_calls.append(tool_call)
@@ -781,18 +924,14 @@ class ChatAgent(ShieldRunnerMixin):
                            span.set_attribute("output", result_message.model_dump_json())

                    # Store tool execution step
                    openai_tool_call = _legacy_tool_call_to_openai(tool_call)
                    openai_tool_response = _tool_response_message_to_openai(result_message)

                    tool_execution_step = ToolExecutionStep(
                        step_id=step_id,
                        turn_id=turn_id,
                        tool_calls=[tool_call],
                        tool_responses=[
                            ToolResponse(
                                call_id=tool_call.call_id,
                                tool_name=tool_call.tool_name,
                                content=tool_result.content,
                                metadata=tool_result.metadata,
                            )
                        ],
                        tool_calls=[openai_tool_call],
                        tool_responses=[openai_tool_response],
                        started_at=tool_execution_start_time,
                        completed_at=datetime.now(UTC).isoformat(),
                    )
@@ -808,8 +947,8 @@ class ChatAgent(ShieldRunnerMixin):
                        )
                    )

                    # Add the result message to input_messages for the next iteration
                    input_messages.append(result_message)
                    # Add the result message to conversation for the next iteration
                    conversation.append(openai_tool_response)

                    # TODO: add tool-input touchpoint and a "start" event for this step also
                    # but that needs a lot more refactoring of Tool code potentially
@@ -829,7 +968,7 @@ class ChatAgent(ShieldRunnerMixin):
                        ToolExecutionStep(
                            step_id=step_id,
                            turn_id=turn_id,
                            tool_calls=client_tool_calls,
                            tool_calls=client_tool_calls_openai,
                            tool_responses=[],
                            started_at=datetime.now(UTC).isoformat(),
                        ),
@@ -866,19 +1005,15 @@ class ChatAgent(ShieldRunnerMixin):

        toolgroup_to_args = toolgroup_to_args or {}

        tool_name_to_def = {}
        tool_name_to_args = {}
        tool_name_to_def: dict[str | BuiltinTool, ToolDefinition] = {}
        tool_name_to_args: dict[str | BuiltinTool, dict[str, Any]] = {}

        for tool_def in self.agent_config.client_tools:
            if tool_name_to_def.get(tool_def.name, None):
                raise ValueError(f"Tool {tool_def.name} already exists")

            # Use input_schema from ToolDef directly
            tool_name_to_def[tool_def.name] = ToolDefinition(
                tool_name=tool_def.name,
                description=tool_def.description,
                input_schema=tool_def.input_schema,
            )
        if self.agent_config.client_tools:
            for tool in self.agent_config.client_tools:
                tool_definition = _client_tool_to_tool_definition(tool)
                if tool_name_to_def.get(tool_definition.tool_name):
                    raise ValueError(f"Tool {tool_definition.tool_name} already exists")
                tool_name_to_def[tool_definition.tool_name] = tool_definition
        for toolgroup_name_with_maybe_tool_name in agent_config_toolgroups:
            toolgroup_name, input_tool_name = self._parse_toolgroup_name(toolgroup_name_with_maybe_tool_name)
            tools = await self.tool_groups_api.list_tools(toolgroup_id=toolgroup_name)
@@ -999,12 +1134,7 @@ async def get_raw_document_text(document: Document) -> str:

    if isinstance(document.content, URL):
        return await load_data_from_url(document.content.uri)
    elif isinstance(document.content, str):
        return document.content
    elif isinstance(document.content, TextContentItem):
        return document.content.text
    else:
        raise ValueError(f"Unexpected document content type: {type(document.content)}")
    return _openai_message_content_to_text(document.content)


def _interpret_content_as_attachment(
@@ -1015,7 +1145,7 @@ def _interpret_content_as_attachment(
    snippet = match.group(1)
    data = json.loads(snippet)
    return Attachment(
        url=URL(uri="file://" + data["filepath"]),
        content=URL(uri="file://" + data["filepath"]),
        mime_type=data["mimetype"],
    )
@@ -33,9 +33,8 @@ from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.inference import (
    Inference,
    ToolConfig,
    ToolResponse,
    ToolResponseMessage,
    UserMessage,
    OpenAIMessageParam,
    OpenAIToolMessageParam,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
@@ -156,7 +155,7 @@ class MetaReferenceAgentsImpl(Agents):
        self,
        agent_id: str,
        session_id: str,
        messages: list[UserMessage | ToolResponseMessage],
        messages: list[OpenAIMessageParam],
        toolgroups: list[AgentToolGroup] | None = None,
        documents: list[Document] | None = None,
        stream: bool | None = False,
@@ -189,7 +188,7 @@ class MetaReferenceAgentsImpl(Agents):
        agent_id: str,
        session_id: str,
        turn_id: str,
        tool_responses: list[ToolResponse],
        tool_responses: list[OpenAIToolMessageParam],
        stream: bool | None = False,
    ) -> AsyncGenerator:
        request = AgentTurnResumeRequest(