chore!: remove the agents (sessions and turns) API (#4055)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Pre-commit / pre-commit (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 2s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 1m10s

- Removes the deprecated agents (sessions and turns) API that was marked
alpha in 0.3.0
- Cleans up unused imports and orphaned types after the API removal
- Removes `SessionNotFoundError` and `AgentTurnInputType` which are no
longer needed

The agents API is completely superseded by the Responses + Conversations
APIs, and the client SDK Agent class already uses those implementations.

Corresponding client-side PR:
https://github.com/llamastack/llama-stack-client-python/pull/295
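For reference, a rough migration sketch (hypothetical, not part of this commit) of how a former create_agent / create_agent_session / create_agent_turn flow maps onto the Responses + Conversations APIs. Method and field names here (`client.conversations.create`, `client.responses.create`, the `conversation` parameter, `output_text`, the model id) are assumptions based on the OpenAI-compatible surface and may differ in the installed client SDK:

```python
# Hypothetical sketch: a Conversation carries the state an agent session used
# to carry, and each Response plays the role of one agent turn.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Previously: create_agent() -> create_agent_session() -> create_agent_turn()
conversation = client.conversations.create()

response = client.responses.create(
    model="llama3.2:3b",           # placeholder model id
    conversation=conversation.id,  # threads this turn into the conversation
    input="What is the capital of France?",
)
print(response.output_text)
```

Since the client SDK `Agent` class already sits on top of these APIs (per the linked client-side PR), code using that class should be unaffected by this removal.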
Ashwin Bharambe 2025-11-04 09:38:39 -08:00 committed by GitHub
parent a6ddbae0ed
commit a8a8aa56c0
1037 changed files with 393 additions and 309806 deletions

@@ -5,30 +5,13 @@
# the root directory of this source tree.
from collections.abc import AsyncIterator
from datetime import datetime
from enum import StrEnum
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
from typing import Annotated, Protocol, runtime_checkable
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.common.responses import Order, PaginatedResponse
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
from llama_stack.apis.common.responses import Order
from llama_stack.apis.version import LLAMA_STACK_API_V1
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, webmethod
from .openai_responses import (
ListOpenAIResponseInputItem,
@@ -57,658 +40,12 @@ class ResponseGuardrailSpec(BaseModel):
ResponseGuardrail = str | ResponseGuardrailSpec
class Attachment(BaseModel):
"""An attachment to an agent turn.
:param content: The content of the attachment.
:param mime_type: The MIME type of the attachment.
"""
content: InterleavedContent | URL
mime_type: str
class Document(BaseModel):
"""A document to be used by an agent.
:param content: The content of the document.
:param mime_type: The MIME type of the document.
"""
content: InterleavedContent | URL
mime_type: str
class StepCommon(BaseModel):
"""A common step in an agent turn.
:param turn_id: The ID of the turn.
:param step_id: The ID of the step.
:param started_at: The time the step started.
:param completed_at: The time the step completed.
"""
turn_id: str
step_id: str
started_at: datetime | None = None
completed_at: datetime | None = None
class StepType(StrEnum):
"""Type of the step in an agent turn.
:cvar inference: The step is an inference step that calls an LLM.
:cvar tool_execution: The step is a tool execution step that executes a tool call.
:cvar shield_call: The step is a shield call step that checks for safety violations.
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context from vector DBs.
"""
inference = "inference"
tool_execution = "tool_execution"
shield_call = "shield_call"
memory_retrieval = "memory_retrieval"
@json_schema_type
class InferenceStep(StepCommon):
"""An inference step in an agent turn.
:param model_response: The response from the LLM.
"""
model_config = ConfigDict(protected_namespaces=())
step_type: Literal[StepType.inference] = StepType.inference
model_response: CompletionMessage
@json_schema_type
class ToolExecutionStep(StepCommon):
"""A tool execution step in an agent turn.
:param tool_calls: The tool calls to execute.
:param tool_responses: The tool responses from the tool calls.
"""
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
tool_calls: list[ToolCall]
tool_responses: list[ToolResponse]
@json_schema_type
class ShieldCallStep(StepCommon):
"""A shield call step in an agent turn.
:param violation: The violation from the shield call.
"""
step_type: Literal[StepType.shield_call] = StepType.shield_call
violation: SafetyViolation | None
@json_schema_type
class MemoryRetrievalStep(StepCommon):
"""A memory retrieval step in an agent turn.
:param vector_store_ids: The IDs of the vector databases to retrieve context from.
:param inserted_context: The context retrieved from the vector databases.
"""
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
# TODO: should this be List[str]?
vector_store_ids: str
inserted_context: InterleavedContent
Step = Annotated[
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
Field(discriminator="step_type"),
]
@json_schema_type
class Turn(BaseModel):
"""A single turn in an interaction with an Agentic System.
:param turn_id: Unique identifier for the turn within a session
:param session_id: Unique identifier for the conversation session
:param input_messages: List of messages that initiated this turn
:param steps: Ordered list of processing steps executed during this turn
:param output_message: The model's generated response containing content and metadata
:param output_attachments: (Optional) Files or media attached to the agent's response
:param started_at: Timestamp when the turn began
:param completed_at: (Optional) Timestamp when the turn finished, if completed
"""
turn_id: str
session_id: str
input_messages: list[UserMessage | ToolResponseMessage]
steps: list[Step]
output_message: CompletionMessage
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
started_at: datetime
completed_at: datetime | None = None
@json_schema_type
class Session(BaseModel):
"""A single session of an interaction with an Agentic System.
:param session_id: Unique identifier for the conversation session
:param session_name: Human-readable name for the session
:param turns: List of all turns that have occurred in this session
:param started_at: Timestamp when the session was created
"""
session_id: str
session_name: str
turns: list[Turn]
started_at: datetime
class AgentToolGroupWithArgs(BaseModel):
name: str
args: dict[str, Any]
AgentToolGroup = str | AgentToolGroupWithArgs
register_schema(AgentToolGroup, name="AgentTool")
class AgentConfigCommon(BaseModel):
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
input_shields: list[str] | None = Field(default_factory=lambda: [])
output_shields: list[str] | None = Field(default_factory=lambda: [])
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
tool_config: ToolConfig | None = Field(default=None)
max_infer_iters: int | None = 10
def model_post_init(self, __context):
if self.tool_config:
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type
class AgentConfig(AgentConfigCommon):
"""Configuration for an agent.
:param model: The model identifier to use for the agent
:param instructions: The system instructions for the agent
:param name: Optional name for the agent, used in telemetry and identification
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
:param response_format: Optional response format configuration
"""
model: str
instructions: str
name: str | None = None
enable_session_persistence: bool | None = False
response_format: ResponseFormat | None = None
@json_schema_type
class Agent(BaseModel):
"""An agent instance with configuration and metadata.
:param agent_id: Unique identifier for the agent
:param agent_config: Configuration settings for the agent
:param created_at: Timestamp when the agent was created
"""
agent_id: str
agent_config: AgentConfig
created_at: datetime
class AgentConfigOverridablePerTurn(AgentConfigCommon):
instructions: str | None = None
class AgentTurnResponseEventType(StrEnum):
step_start = "step_start"
step_complete = "step_complete"
step_progress = "step_progress"
turn_start = "turn_start"
turn_complete = "turn_complete"
turn_awaiting_input = "turn_awaiting_input"
@json_schema_type
class AgentTurnResponseStepStartPayload(BaseModel):
"""Payload for step start events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param metadata: (Optional) Additional metadata for the step
"""
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
step_type: StepType
step_id: str
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
@json_schema_type
class AgentTurnResponseStepCompletePayload(BaseModel):
"""Payload for step completion events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param step_details: Complete details of the executed step
"""
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
step_type: StepType
step_id: str
step_details: Step
@json_schema_type
class AgentTurnResponseStepProgressPayload(BaseModel):
"""Payload for step progress events in agent turn responses.
:param event_type: Type of event being reported
:param step_type: Type of step being executed
:param step_id: Unique identifier for the step within a turn
:param delta: Incremental content changes during step execution
"""
model_config = ConfigDict(protected_namespaces=())
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
step_type: StepType
step_id: str
delta: ContentDelta
@json_schema_type
class AgentTurnResponseTurnStartPayload(BaseModel):
"""Payload for turn start events in agent turn responses.
:param event_type: Type of event being reported
:param turn_id: Unique identifier for the turn within a session
"""
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
turn_id: str
@json_schema_type
class AgentTurnResponseTurnCompletePayload(BaseModel):
"""Payload for turn completion events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Complete turn data including all steps and results
"""
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
turn: Turn
@json_schema_type
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
"""Payload for turn awaiting input events in agent turn responses.
:param event_type: Type of event being reported
:param turn: Turn data when waiting for external tool responses
"""
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
turn: Turn
AgentTurnResponseEventPayload = Annotated[
AgentTurnResponseStepStartPayload
| AgentTurnResponseStepProgressPayload
| AgentTurnResponseStepCompletePayload
| AgentTurnResponseTurnStartPayload
| AgentTurnResponseTurnCompletePayload
| AgentTurnResponseTurnAwaitingInputPayload,
Field(discriminator="event_type"),
]
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
@json_schema_type
class AgentTurnResponseEvent(BaseModel):
"""An event in an agent turn response stream.
:param payload: Event-specific payload containing event data
"""
payload: AgentTurnResponseEventPayload
@json_schema_type
class AgentCreateResponse(BaseModel):
"""Response returned when creating a new agent.
:param agent_id: Unique identifier for the created agent
"""
agent_id: str
@json_schema_type
class AgentSessionCreateResponse(BaseModel):
"""Response returned when creating a new agent session.
:param session_id: Unique identifier for the created session
"""
session_id: str
@json_schema_type
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
"""Request to create a new turn for an agent.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param messages: List of messages to start the turn with
:param documents: (Optional) List of documents to provide to the agent
:param toolgroups: (Optional) List of tool groups to make available for this turn
:param stream: (Optional) Whether to stream the response
:param tool_config: (Optional) Tool configuration to override agent defaults
"""
agent_id: str
session_id: str
# TODO: figure out how we can simplify this and explain why
# ToolResponseMessage needs to be here (it is function call
# execution from outside the system)
messages: list[UserMessage | ToolResponseMessage]
documents: list[Document] | None = None
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
stream: bool | None = False
tool_config: ToolConfig | None = None
@json_schema_type
class AgentTurnResumeRequest(BaseModel):
"""Request to resume an agent turn with tool responses.
:param agent_id: Unique identifier for the agent
:param session_id: Unique identifier for the conversation session
:param turn_id: Unique identifier for the turn within a session
:param tool_responses: List of tool responses to submit to continue the turn
:param stream: (Optional) Whether to stream the response
"""
agent_id: str
session_id: str
turn_id: str
tool_responses: list[ToolResponse]
stream: bool | None = False
@json_schema_type
class AgentTurnResponseStreamChunk(BaseModel):
"""Streamed agent turn completion response.
:param event: Individual event in the agent turn response stream
"""
event: AgentTurnResponseEvent
@json_schema_type
class AgentStepResponse(BaseModel):
"""Response containing details of a specific agent step.
:param step: The complete step data and execution details
"""
step: Step
@runtime_checkable
class Agents(Protocol):
"""Agents
APIs for creating and interacting with agentic systems."""
@webmethod(
route="/agents",
method="POST",
descriptive_name="create_agent",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
"""Create an agent with the given configuration.
:param agent_config: The configuration for the agent.
:returns: An AgentCreateResponse with the agent ID.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn",
method="POST",
descriptive_name="create_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Create a new turn for an agent.
:param agent_id: The ID of the agent to create the turn for.
:param session_id: The ID of the session to create the turn for.
:param messages: List of messages to start the turn with.
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
:param documents: (Optional) List of documents to create the turn with.
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
:returns: If stream=False, returns a Turn object.
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
method="POST",
descriptive_name="resume_agent_turn",
level=LLAMA_STACK_API_V1ALPHA,
)
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
"""Resume an agent turn with executed tool call responses.
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
:param agent_id: The ID of the agent to resume.
:param session_id: The ID of the session to resume.
:param turn_id: The ID of the turn to resume.
:param tool_responses: The tool call responses to resume the turn with.
:param stream: Whether to stream the response.
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
) -> Turn:
"""Retrieve an agent turn by its ID.
:param agent_id: The ID of the agent to get the turn for.
:param session_id: The ID of the session to get the turn for.
:param turn_id: The ID of the turn to get.
:returns: A Turn.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_step(
self,
agent_id: str,
session_id: str,
turn_id: str,
step_id: str,
) -> AgentStepResponse:
"""Retrieve an agent step by its ID.
:param agent_id: The ID of the agent to get the step for.
:param session_id: The ID of the session to get the step for.
:param turn_id: The ID of the turn to get the step for.
:param step_id: The ID of the step to get.
:returns: An AgentStepResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session",
method="POST",
descriptive_name="create_agent_session",
level=LLAMA_STACK_API_V1ALPHA,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
"""Create a new session for an agent.
:param agent_id: The ID of the agent to create the session for.
:param session_name: The name of the session to create.
:returns: An AgentSessionCreateResponse.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="GET",
level=LLAMA_STACK_API_V1ALPHA,
)
async def get_agents_session(
self,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
"""Retrieve an agent session by its ID.
:param session_id: The ID of the session to get.
:param agent_id: The ID of the agent to get the session for.
:param turn_ids: (Optional) List of turn IDs to filter the session by.
:returns: A Session.
"""
...
@webmethod(
route="/agents/{agent_id}/session/{session_id}",
method="DELETE",
level=LLAMA_STACK_API_V1ALPHA,
)
async def delete_agents_session(
self,
session_id: str,
agent_id: str,
) -> None:
"""Delete an agent session by its ID and its associated turns.
:param session_id: The ID of the session to delete.
:param agent_id: The ID of the agent to delete the session for.
"""
...
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
async def delete_agent(
self,
agent_id: str,
) -> None:
"""Delete an agent by its ID and its associated sessions and turns.
:param agent_id: The ID of the agent to delete.
"""
...
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
"""List all agents.
:param start_index: The index to start the pagination from.
:param limit: The number of agents to return.
:returns: A PaginatedResponse.
"""
...
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def get_agent(self, agent_id: str) -> Agent:
"""Describe an agent by its ID.
:param agent_id: ID of the agent.
:returns: An Agent object describing the agent.
"""
...
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
async def list_agent_sessions(
self,
agent_id: str,
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""List all session(s) of a given agent.
:param agent_id: The ID of the agent to list sessions for.
:param start_index: The index to start the pagination from.
:param limit: The number of sessions to return.
:returns: A PaginatedResponse.
"""
...
# We situate the OpenAI Responses API in the Agents API just like we did things
# for Inference. The Responses API, in its intent, serves the same purpose as
# the Agents API above -- it is essentially a lightweight "agentic loop" with

@@ -56,14 +56,6 @@ class ToolGroupNotFoundError(ResourceNotFoundError):
super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")
class SessionNotFoundError(ValueError):
"""raised when Llama Stack cannot find a referenced session or access is denied"""
def __init__(self, session_name: str) -> None:
message = f"Session '{session_name}' not found or access denied."
super().__init__(message)
class ModelTypeError(TypeError):
"""raised when a model is present but not the correct type"""

@@ -103,17 +103,6 @@ class CompletionInputType(BaseModel):
type: Literal["completion_input"] = "completion_input"
@json_schema_type
class AgentTurnInputType(BaseModel):
"""Parameter type for agent turn input.
:param type: Discriminator type. Always "agent_turn_input"
"""
# expects List[Message] for messages (may also include attachments?)
type: Literal["agent_turn_input"] = "agent_turn_input"
@json_schema_type
class DialogType(BaseModel):
"""Parameter type for dialog data with semantic output labels.
@@ -135,8 +124,7 @@ ParamType = Annotated[
| JsonType
| UnionType
| ChatCompletionInputType
| CompletionInputType
| AgentTurnInputType,
| CompletionInputType,
Field(discriminator="type"),
]
register_schema(ParamType, name="ParamType")

@@ -4,17 +4,16 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Annotated, Any, Literal, Protocol
from typing import Any, Literal, Protocol
from pydantic import BaseModel, Field
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
@@ -32,19 +31,7 @@ class ModelCandidate(BaseModel):
system_message: SystemMessage | None = None
@json_schema_type
class AgentCandidate(BaseModel):
"""An agent candidate for evaluation.
:param config: The configuration for the agent candidate.
"""
type: Literal["agent"] = "agent"
config: AgentConfig
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
register_schema(EvalCandidate, name="EvalCandidate")
EvalCandidate = ModelCandidate
@json_schema_type

@@ -4,21 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import uuid
from collections.abc import AsyncGenerator
from datetime import UTC, datetime
from llama_stack.apis.agents import (
Agent,
AgentConfig,
AgentCreateResponse,
Agents,
AgentSessionCreateResponse,
AgentStepResponse,
AgentToolGroup,
AgentTurnCreateRequest,
AgentTurnResumeRequest,
Document,
ListOpenAIResponseInputItem,
ListOpenAIResponseObject,
OpenAIDeleteResponseObject,
@@ -26,19 +14,12 @@ from llama_stack.apis.agents import (
OpenAIResponseInputTool,
OpenAIResponseObject,
Order,
Session,
Turn,
)
from llama_stack.apis.agents.agents import ResponseGuardrail
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.conversations import Conversations
from llama_stack.apis.inference import (
Inference,
ToolConfig,
ToolResponse,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import ToolGroups, ToolRuntime
@@ -46,12 +27,9 @@ from llama_stack.apis.vector_io import VectorIO
from llama_stack.core.datatypes import AccessRule
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
from llama_stack.providers.utils.pagination import paginate_records
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents::meta_reference")
@@ -97,229 +75,6 @@ class MetaReferenceAgentsImpl(Agents):
conversations_api=self.conversations_api,
)
async def create_agent(
self,
agent_config: AgentConfig,
) -> AgentCreateResponse:
agent_id = str(uuid.uuid4())
created_at = datetime.now(UTC)
agent_info = AgentInfo(
**agent_config.model_dump(),
created_at=created_at,
)
# Store the agent info
await self.persistence_store.set(
key=f"agent:{agent_id}",
value=agent_info.model_dump_json(),
)
return AgentCreateResponse(
agent_id=agent_id,
)
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
agent_info_json = await self.persistence_store.get(
key=f"agent:{agent_id}",
)
if not agent_info_json:
raise ValueError(f"Could not find agent info for {agent_id}")
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
except Exception as e:
raise ValueError(f"Could not validate agent info for {agent_id}") from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_info,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
tool_runtime_api=self.tool_runtime_api,
tool_groups_api=self.tool_groups_api,
persistence_store=(
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
),
created_at=agent_info.created_at.isoformat(),
policy=self.policy,
telemetry_enabled=self.telemetry_enabled,
)
async def create_agent_session(
self,
agent_id: str,
session_name: str,
) -> AgentSessionCreateResponse:
agent = await self._get_agent_impl(agent_id)
session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse(
session_id=session_id,
)
async def create_agent_turn(
self,
agent_id: str,
session_id: str,
messages: list[UserMessage | ToolResponseMessage],
stream: bool | None = False,
documents: list[Document] | None = None,
toolgroups: list[AgentToolGroup] | None = None,
tool_config: ToolConfig | None = None,
) -> AsyncGenerator:
request = AgentTurnCreateRequest(
agent_id=agent_id,
session_id=session_id,
messages=messages,
stream=True,
toolgroups=toolgroups,
documents=documents,
tool_config=tool_config,
)
if stream:
return self._create_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _create_agent_turn_streaming(
self,
request: AgentTurnCreateRequest,
) -> AsyncGenerator:
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.create_and_execute_turn(request):
yield event
async def resume_agent_turn(
self,
agent_id: str,
session_id: str,
turn_id: str,
tool_responses: list[ToolResponse],
stream: bool | None = False,
) -> AsyncGenerator:
request = AgentTurnResumeRequest(
agent_id=agent_id,
session_id=session_id,
turn_id=turn_id,
tool_responses=tool_responses,
stream=stream,
)
if stream:
return self._continue_agent_turn_streaming(request)
else:
raise NotImplementedError("Non-streaming agent turns not yet implemented")
async def _continue_agent_turn_streaming(
self,
request: AgentTurnResumeRequest,
) -> AsyncGenerator:
agent = await self._get_agent_impl(request.agent_id)
async for event in agent.resume_turn(request):
yield event
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
agent = await self._get_agent_impl(agent_id)
turn = await agent.storage.get_session_turn(session_id, turn_id)
if turn is None:
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
return turn
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
for step in turn.steps:
if step.step_id == step_id:
return AgentStepResponse(step=step)
raise ValueError(f"Provided step_id {step_id} could not be found")
async def get_agents_session(
self,
session_id: str,
agent_id: str,
turn_ids: list[str] | None = None,
) -> Session:
agent = await self._get_agent_impl(agent_id)
session_info = await agent.storage.get_session_info(session_id)
if session_info is None:
raise ValueError(f"Session {session_id} not found")
turns = await agent.storage.get_session_turns(session_id)
if turn_ids:
turns = [turn for turn in turns if turn.turn_id in turn_ids]
return Session(
session_name=session_info.session_name,
session_id=session_id,
turns=turns,
started_at=session_info.started_at,
)
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
agent = await self._get_agent_impl(agent_id)
# Delete turns first, then the session
await agent.storage.delete_session_turns(session_id)
await agent.storage.delete_session(session_id)
async def delete_agent(self, agent_id: str) -> None:
# First get all sessions for this agent
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Delete all sessions
for session in sessions:
await self.delete_agents_session(agent_id, session.session_id)
# Finally delete the agent itself
await self.persistence_store.delete(f"agent:{agent_id}")
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
agent_list: list[Agent] = []
for agent_key in agent_keys:
agent_id = agent_key.split(":")[1]
# Get the agent info using the key
agent_info_json = await self.persistence_store.get(agent_key)
if not agent_info_json:
logger.error(f"Could not find agent info for key {agent_key}")
continue
try:
agent_info = AgentInfo.model_validate_json(agent_info_json)
agent_list.append(
Agent(
agent_id=agent_id,
agent_config=agent_info,
created_at=agent_info.created_at,
)
)
except Exception as e:
logger.error(f"Error parsing agent info for {agent_id}: {e}")
continue
# Convert Agent objects to dictionaries
agent_dicts = [agent.model_dump() for agent in agent_list]
return paginate_records(agent_dicts, start_index, limit)
async def get_agent(self, agent_id: str) -> Agent:
chat_agent = await self._get_agent_impl(agent_id)
agent = Agent(
agent_id=agent_id,
agent_config=chat_agent.agent_config,
created_at=datetime.fromisoformat(chat_agent.created_at),
)
return agent
async def list_agent_sessions(
self, agent_id: str, start_index: int | None = None, limit: int | None = None
) -> PaginatedResponse:
agent = await self._get_agent_impl(agent_id)
sessions = await agent.storage.list_sessions()
# Convert Session objects to dictionaries
session_dicts = [session.model_dump() for session in sessions]
return paginate_records(session_dicts, start_index, limit)
async def shutdown(self) -> None:
pass

@@ -1,261 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import uuid
from dataclasses import dataclass
from datetime import UTC, datetime
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
from llama_stack.apis.common.errors import SessionNotFoundError
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
from llama_stack.core.access_control.conditions import User as ProtocolUser
from llama_stack.core.access_control.datatypes import AccessRule, Action
from llama_stack.core.datatypes import User
from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):
# TODO: is this used anywhere?
vector_store_id: str | None = None
started_at: datetime
owner: User | None = None
identifier: str | None = None
type: str = "session"
class AgentInfo(AgentConfig):
created_at: datetime
@dataclass
class SessionResource:
"""Concrete implementation of ProtectedResource for session access control."""
type: str
identifier: str
owner: ProtocolUser # Use the protocol type for structural compatibility
class AgentPersistence:
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
self.agent_id = agent_id
self.kvstore = kvstore
self.policy = policy
async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4())
# Get current user's auth attributes for new sessions
user = get_authenticated_user()
session_info = AgentSessionInfo(
session_id=session_id,
session_name=name,
started_at=datetime.now(UTC),
owner=user,
turns=[],
identifier=name, # should this be qualified in any way?
)
# Only perform access control if we have an authenticated user
if user is not None and session_info.identifier is not None:
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=user,
)
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
raise AccessDeniedError(Action.CREATE, resource, user)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
return session_id
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
raise SessionNotFoundError(session_id)
session_info = AgentSessionInfo(**json.loads(value))
# Check access to session
if not self._check_session_access(session_info):
return None
return session_info
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
"""Check if current user has access to the session."""
# Handle backward compatibility for old sessions without access control
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
return True
# Get current user - if None, skip access control (e.g., in tests)
user = get_authenticated_user()
if user is None:
return True
# Access control requires identifier and owner to be set
if session_info.identifier is None or session_info.owner is None:
return True
# At this point, both identifier and owner are guaranteed to be non-None
resource = SessionResource(
type=session_info.type,
identifier=session_info.identifier,
owner=session_info.owner,
)
return is_action_allowed(self.policy, Action.READ, resource, user)
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
"""Get session info if the user has access to it. For internal use by sub-session methods."""
session_info = await self.get_session_info(session_id)
if not session_info:
return None
return session_info
async def add_vector_db_to_session(self, session_id: str, vector_store_id: str):
session_info = await self.get_session_if_accessible(session_id)
if session_info is None:
raise SessionNotFoundError(session_id)
session_info.vector_store_id = vector_store_id
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}",
value=session_info.model_dump_json(),
)
async def add_turn_to_session(self, session_id: str, turn: Turn):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=turn.model_dump_json(),
)
async def get_session_turns(self, session_id: str) -> list[Turn]:
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
log.error(f"Error parsing turn: {e}")
continue
# The kvstore does not guarantee order, so we sort by started_at
# to ensure consistent ordering of turns.
turns.sort(key=lambda t: t.started_at)
return turns
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
value = await self.kvstore.get(
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
)
if not value:
return None
return Turn(**json.loads(value))
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
value=step.model_dump_json(),
)
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
)
return ToolExecutionStep(**json.loads(value)) if value else None
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
if not await self.get_session_if_accessible(session_id):
raise SessionNotFoundError(session_id)
await self.kvstore.set(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
value=str(num_infer_iters),
)
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
if not await self.get_session_if_accessible(session_id):
return None
value = await self.kvstore.get(
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
)
return int(value) if value else None
async def list_sessions(self) -> list[Session]:
values = await self.kvstore.values_in_range(
start_key=f"session:{self.agent_id}:",
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
)
sessions = []
for value in values:
try:
data = json.loads(value)
if "turn_id" in data:
continue
session_info = Session(**data)
sessions.append(session_info)
except Exception as e:
log.error(f"Error parsing session info: {e}")
continue
return sessions
async def delete_session_turns(self, session_id: str) -> None:
"""Delete all turns and their associated data for a session.
Args:
session_id: The ID of the session whose turns should be deleted.
"""
turns = await self.get_session_turns(session_id)
for turn in turns:
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
async def delete_session(self, session_id: str) -> None:
"""Delete a session and all its associated turns.
Args:
session_id: The ID of the session to delete.
Raises:
ValueError: If the session does not exist.
"""
session_info = await self.get_session_info(session_id)
if session_info is None:
raise SessionNotFoundError(session_id)
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")

@@ -8,7 +8,7 @@ from typing import Any
from tqdm import tqdm
from llama_stack.apis.agents import Agents, StepType
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
@@ -18,13 +18,9 @@ from llama_stack.apis.inference import (
OpenAICompletionRequestWithExtraBody,
OpenAISystemMessageParam,
OpenAIUserMessageParam,
UserMessage,
)
from llama_stack.apis.scoring import Scoring
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
@@ -118,49 +114,6 @@ class MetaReferenceEvalImpl(
self.jobs[job_id] = res
return Job(job_id=job_id, status=JobStatus.completed)
async def _run_agent_generation(
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
candidate = benchmark_config.eval_candidate
create_response = await self.agents_api.create_agent(candidate.config)
agent_id = create_response.agent_id
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
session_id = session_create_response.session_id
turn_request = dict(
agent_id=agent_id,
session_id=session_id,
messages=input_messages,
stream=True,
)
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
final_event = turn_response[-1].event.payload
# check if there's a memory retrieval step and extract the context
memory_rag_context = None
for step in final_event.turn.steps:
if step.step_type == StepType.tool_execution.value:
for tool_response in step.tool_responses:
if tool_response.tool_name == MEMORY_QUERY_TOOL:
memory_rag_context = " ".join(x.text for x in tool_response.content)
agent_generation = {}
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
if memory_rag_context:
agent_generation[ColumnName.context.value] = memory_rag_context
generations.append(agent_generation)
return generations
async def _run_model_generation(
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
) -> list[dict[str, Any]]:
@@ -215,9 +168,8 @@ class MetaReferenceEvalImpl(
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
candidate = benchmark_config.eval_candidate
if candidate.type == "agent":
generations = await self._run_agent_generation(input_rows, benchmark_config)
elif candidate.type == "model":
# Agent evaluation removed
if candidate.type == "model":
generations = await self._run_model_generation(input_rows, benchmark_config)
else:
raise ValueError(f"Invalid candidate type: {candidate.type}")