mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
chore!: remove the agents (sessions and turns) API (#4055)
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Pre-commit / pre-commit (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 2s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 1m10s
Some checks failed
SqlStore Integration Tests / test-postgres (3.12) (push) Failing after 0s
Integration Auth Tests / test-matrix (oauth2_token) (push) Failing after 1s
Test External Providers Installed via Module / test-external-providers-from-module (venv) (push) Has been skipped
Pre-commit / pre-commit (push) Failing after 3s
Python Package Build Test / build (3.12) (push) Failing after 2s
Python Package Build Test / build (3.13) (push) Failing after 2s
Vector IO Integration Tests / test-matrix (push) Failing after 4s
Integration Tests (Replay) / Integration Tests (, , , client=, ) (push) Failing after 5s
Test External API and Providers / test-external (venv) (push) Failing after 5s
SqlStore Integration Tests / test-postgres (3.13) (push) Failing after 9s
Unit Tests / unit-tests (3.13) (push) Failing after 5s
Unit Tests / unit-tests (3.12) (push) Failing after 6s
API Conformance Tests / check-schema-compatibility (push) Successful in 13s
UI Tests / ui-tests (22) (push) Successful in 1m10s
- Removes the deprecated agents (sessions and turns) API that was marked alpha in 0.3.0 - Cleans up unused imports and orphaned types after the API removal - Removes `SessionNotFoundError` and `AgentTurnInputType` which are no longer needed The agents API is completely superseded by the Responses + Conversations APIs, and the client SDK Agent class already uses those implementations. Corresponding client-side PR: https://github.com/llamastack/llama-stack-client-python/pull/295
This commit is contained in:
parent
a6ddbae0ed
commit
a8a8aa56c0
1037 changed files with 393 additions and 309806 deletions
|
|
@ -5,30 +5,13 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal, Protocol, runtime_checkable
|
||||
from typing import Annotated, Protocol, runtime_checkable
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
|
||||
from llama_stack.apis.common.responses import Order, PaginatedResponse
|
||||
from llama_stack.apis.inference import (
|
||||
CompletionMessage,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
ToolCall,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
ToolPromptFormat,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import SafetyViolation
|
||||
from llama_stack.apis.tools import ToolDef
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1, LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, register_schema, webmethod
|
||||
from llama_stack.apis.common.responses import Order
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1
|
||||
from llama_stack.schema_utils import ExtraBodyField, json_schema_type, webmethod
|
||||
|
||||
from .openai_responses import (
|
||||
ListOpenAIResponseInputItem,
|
||||
|
|
@ -57,658 +40,12 @@ class ResponseGuardrailSpec(BaseModel):
|
|||
ResponseGuardrail = str | ResponseGuardrailSpec
|
||||
|
||||
|
||||
class Attachment(BaseModel):
|
||||
"""An attachment to an agent turn.
|
||||
|
||||
:param content: The content of the attachment.
|
||||
:param mime_type: The MIME type of the attachment.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class Document(BaseModel):
|
||||
"""A document to be used by an agent.
|
||||
|
||||
:param content: The content of the document.
|
||||
:param mime_type: The MIME type of the document.
|
||||
"""
|
||||
|
||||
content: InterleavedContent | URL
|
||||
mime_type: str
|
||||
|
||||
|
||||
class StepCommon(BaseModel):
|
||||
"""A common step in an agent turn.
|
||||
|
||||
:param turn_id: The ID of the turn.
|
||||
:param step_id: The ID of the step.
|
||||
:param started_at: The time the step started.
|
||||
:param completed_at: The time the step completed.
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
step_id: str
|
||||
started_at: datetime | None = None
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
class StepType(StrEnum):
|
||||
"""Type of the step in an agent turn.
|
||||
|
||||
:cvar inference: The step is an inference step that calls an LLM.
|
||||
:cvar tool_execution: The step is a tool execution step that executes a tool call.
|
||||
:cvar shield_call: The step is a shield call step that checks for safety violations.
|
||||
:cvar memory_retrieval: The step is a memory retrieval step that retrieves context for vector dbs.
|
||||
"""
|
||||
|
||||
inference = "inference"
|
||||
tool_execution = "tool_execution"
|
||||
shield_call = "shield_call"
|
||||
memory_retrieval = "memory_retrieval"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class InferenceStep(StepCommon):
|
||||
"""An inference step in an agent turn.
|
||||
|
||||
:param model_response: The response from the LLM.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
step_type: Literal[StepType.inference] = StepType.inference
|
||||
model_response: CompletionMessage
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolExecutionStep(StepCommon):
|
||||
"""A tool execution step in an agent turn.
|
||||
|
||||
:param tool_calls: The tool calls to execute.
|
||||
:param tool_responses: The tool responses from the tool calls.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.tool_execution] = StepType.tool_execution
|
||||
tool_calls: list[ToolCall]
|
||||
tool_responses: list[ToolResponse]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ShieldCallStep(StepCommon):
|
||||
"""A shield call step in an agent turn.
|
||||
|
||||
:param violation: The violation from the shield call.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.shield_call] = StepType.shield_call
|
||||
violation: SafetyViolation | None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class MemoryRetrievalStep(StepCommon):
|
||||
"""A memory retrieval step in an agent turn.
|
||||
|
||||
:param vector_store_ids: The IDs of the vector databases to retrieve context from.
|
||||
:param inserted_context: The context retrieved from the vector databases.
|
||||
"""
|
||||
|
||||
step_type: Literal[StepType.memory_retrieval] = StepType.memory_retrieval
|
||||
# TODO: should this be List[str]?
|
||||
vector_store_ids: str
|
||||
inserted_context: InterleavedContent
|
||||
|
||||
|
||||
Step = Annotated[
|
||||
InferenceStep | ToolExecutionStep | ShieldCallStep | MemoryRetrievalStep,
|
||||
Field(discriminator="step_type"),
|
||||
]
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Turn(BaseModel):
|
||||
"""A single turn in an interaction with an Agentic System.
|
||||
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param input_messages: List of messages that initiated this turn
|
||||
:param steps: Ordered list of processing steps executed during this turn
|
||||
:param output_message: The model's generated response containing content and metadata
|
||||
:param output_attachments: (Optional) Files or media attached to the agent's response
|
||||
:param started_at: Timestamp when the turn began
|
||||
:param completed_at: (Optional) Timestamp when the turn finished, if completed
|
||||
"""
|
||||
|
||||
turn_id: str
|
||||
session_id: str
|
||||
input_messages: list[UserMessage | ToolResponseMessage]
|
||||
steps: list[Step]
|
||||
output_message: CompletionMessage
|
||||
output_attachments: list[Attachment] | None = Field(default_factory=lambda: [])
|
||||
|
||||
started_at: datetime
|
||||
completed_at: datetime | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Session(BaseModel):
|
||||
"""A single session of an interaction with an Agentic System.
|
||||
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param session_name: Human-readable name for the session
|
||||
:param turns: List of all turns that have occurred in this session
|
||||
:param started_at: Timestamp when the session was created
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
session_name: str
|
||||
turns: list[Turn]
|
||||
started_at: datetime
|
||||
|
||||
|
||||
class AgentToolGroupWithArgs(BaseModel):
|
||||
name: str
|
||||
args: dict[str, Any]
|
||||
|
||||
|
||||
AgentToolGroup = str | AgentToolGroupWithArgs
|
||||
register_schema(AgentToolGroup, name="AgentTool")
|
||||
|
||||
|
||||
class AgentConfigCommon(BaseModel):
|
||||
sampling_params: SamplingParams | None = Field(default_factory=SamplingParams)
|
||||
|
||||
input_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
output_shields: list[str] | None = Field(default_factory=lambda: [])
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
client_tools: list[ToolDef] | None = Field(default_factory=lambda: [])
|
||||
tool_choice: ToolChoice | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_prompt_format: ToolPromptFormat | None = Field(default=None, deprecated="use tool_config instead")
|
||||
tool_config: ToolConfig | None = Field(default=None)
|
||||
|
||||
max_infer_iters: int | None = 10
|
||||
|
||||
def model_post_init(self, __context):
|
||||
if self.tool_config:
|
||||
if self.tool_choice and self.tool_config.tool_choice != self.tool_choice:
|
||||
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
|
||||
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
|
||||
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
|
||||
else:
|
||||
params = {}
|
||||
if self.tool_choice:
|
||||
params["tool_choice"] = self.tool_choice
|
||||
if self.tool_prompt_format:
|
||||
params["tool_prompt_format"] = self.tool_prompt_format
|
||||
self.tool_config = ToolConfig(**params)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentConfig(AgentConfigCommon):
|
||||
"""Configuration for an agent.
|
||||
|
||||
:param model: The model identifier to use for the agent
|
||||
:param instructions: The system instructions for the agent
|
||||
:param name: Optional name for the agent, used in telemetry and identification
|
||||
:param enable_session_persistence: Optional flag indicating whether session data has to be persisted
|
||||
:param response_format: Optional response format configuration
|
||||
"""
|
||||
|
||||
model: str
|
||||
instructions: str
|
||||
name: str | None = None
|
||||
enable_session_persistence: bool | None = False
|
||||
response_format: ResponseFormat | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Agent(BaseModel):
|
||||
"""An agent instance with configuration and metadata.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param agent_config: Configuration settings for the agent
|
||||
:param created_at: Timestamp when the agent was created
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
agent_config: AgentConfig
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class AgentConfigOverridablePerTurn(AgentConfigCommon):
|
||||
instructions: str | None = None
|
||||
|
||||
|
||||
class AgentTurnResponseEventType(StrEnum):
|
||||
step_start = "step_start"
|
||||
step_complete = "step_complete"
|
||||
step_progress = "step_progress"
|
||||
|
||||
turn_start = "turn_start"
|
||||
turn_complete = "turn_complete"
|
||||
turn_awaiting_input = "turn_awaiting_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepStartPayload(BaseModel):
|
||||
"""Payload for step start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param metadata: (Optional) Additional metadata for the step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_start] = AgentTurnResponseEventType.step_start
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
metadata: dict[str, Any] | None = Field(default_factory=lambda: {})
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepCompletePayload(BaseModel):
|
||||
"""Payload for step completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param step_details: Complete details of the executed step
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_complete] = AgentTurnResponseEventType.step_complete
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
step_details: Step
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStepProgressPayload(BaseModel):
|
||||
"""Payload for step progress events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param step_type: Type of step being executed
|
||||
:param step_id: Unique identifier for the step within a turn
|
||||
:param delta: Incremental content changes during step execution
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.step_progress] = AgentTurnResponseEventType.step_progress
|
||||
step_type: StepType
|
||||
step_id: str
|
||||
|
||||
delta: ContentDelta
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnStartPayload(BaseModel):
|
||||
"""Payload for turn start events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_start] = AgentTurnResponseEventType.turn_start
|
||||
turn_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnCompletePayload(BaseModel):
|
||||
"""Payload for turn completion events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Complete turn data including all steps and results
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_complete] = AgentTurnResponseEventType.turn_complete
|
||||
turn: Turn
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseTurnAwaitingInputPayload(BaseModel):
|
||||
"""Payload for turn awaiting input events in agent turn responses.
|
||||
|
||||
:param event_type: Type of event being reported
|
||||
:param turn: Turn data when waiting for external tool responses
|
||||
"""
|
||||
|
||||
event_type: Literal[AgentTurnResponseEventType.turn_awaiting_input] = AgentTurnResponseEventType.turn_awaiting_input
|
||||
turn: Turn
|
||||
|
||||
|
||||
AgentTurnResponseEventPayload = Annotated[
|
||||
AgentTurnResponseStepStartPayload
|
||||
| AgentTurnResponseStepProgressPayload
|
||||
| AgentTurnResponseStepCompletePayload
|
||||
| AgentTurnResponseTurnStartPayload
|
||||
| AgentTurnResponseTurnCompletePayload
|
||||
| AgentTurnResponseTurnAwaitingInputPayload,
|
||||
Field(discriminator="event_type"),
|
||||
]
|
||||
register_schema(AgentTurnResponseEventPayload, name="AgentTurnResponseEventPayload")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseEvent(BaseModel):
|
||||
"""An event in an agent turn response stream.
|
||||
|
||||
:param payload: Event-specific payload containing event data
|
||||
"""
|
||||
|
||||
payload: AgentTurnResponseEventPayload
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent.
|
||||
|
||||
:param agent_id: Unique identifier for the created agent
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentSessionCreateResponse(BaseModel):
|
||||
"""Response returned when creating a new agent session.
|
||||
|
||||
:param session_id: Unique identifier for the created session
|
||||
"""
|
||||
|
||||
session_id: str
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnCreateRequest(AgentConfigOverridablePerTurn):
|
||||
"""Request to create a new turn for an agent.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param messages: List of messages to start the turn with
|
||||
:param documents: (Optional) List of documents to provide to the agent
|
||||
:param toolgroups: (Optional) List of tool groups to make available for this turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
:param tool_config: (Optional) Tool configuration to override agent defaults
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
|
||||
# TODO: figure out how we can simplify this and make why
|
||||
# ToolResponseMessage needs to be here (it is function call
|
||||
# execution from outside the system)
|
||||
messages: list[UserMessage | ToolResponseMessage]
|
||||
|
||||
documents: list[Document] | None = None
|
||||
toolgroups: list[AgentToolGroup] | None = Field(default_factory=lambda: [])
|
||||
|
||||
stream: bool | None = False
|
||||
tool_config: ToolConfig | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResumeRequest(BaseModel):
|
||||
"""Request to resume an agent turn with tool responses.
|
||||
|
||||
:param agent_id: Unique identifier for the agent
|
||||
:param session_id: Unique identifier for the conversation session
|
||||
:param turn_id: Unique identifier for the turn within a session
|
||||
:param tool_responses: List of tool responses to submit to continue the turn
|
||||
:param stream: (Optional) Whether to stream the response
|
||||
"""
|
||||
|
||||
agent_id: str
|
||||
session_id: str
|
||||
turn_id: str
|
||||
tool_responses: list[ToolResponse]
|
||||
stream: bool | None = False
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnResponseStreamChunk(BaseModel):
|
||||
"""Streamed agent turn completion response.
|
||||
|
||||
:param event: Individual event in the agent turn response stream
|
||||
"""
|
||||
|
||||
event: AgentTurnResponseEvent
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentStepResponse(BaseModel):
|
||||
"""Response containing details of a specific agent step.
|
||||
|
||||
:param step: The complete step data and execution details
|
||||
"""
|
||||
|
||||
step: Step
|
||||
|
||||
|
||||
@runtime_checkable
|
||||
class Agents(Protocol):
|
||||
"""Agents
|
||||
|
||||
APIs for creating and interacting with agentic systems."""
|
||||
|
||||
@webmethod(
|
||||
route="/agents",
|
||||
method="POST",
|
||||
descriptive_name="create_agent",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse:
|
||||
"""Create an agent with the given configuration.
|
||||
|
||||
:param agent_config: The configuration for the agent.
|
||||
:returns: An AgentCreateResponse with the agent ID.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Create a new turn for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the turn for.
|
||||
:param session_id: The ID of the session to create the turn for.
|
||||
:param messages: List of messages to start the turn with.
|
||||
:param stream: (Optional) If True, generate an SSE event stream of the response. Defaults to False.
|
||||
:param documents: (Optional) List of documents to create the turn with.
|
||||
:param toolgroups: (Optional) List of toolgroups to create the turn with, will be used in addition to the agent's config toolgroups for the request.
|
||||
:param tool_config: (Optional) The tool configuration to create the turn with, will be used to override the agent's tool_config.
|
||||
:returns: If stream=False, returns a Turn object.
|
||||
If stream=True, returns an SSE event stream of AgentTurnResponseStreamChunk.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/resume",
|
||||
method="POST",
|
||||
descriptive_name="resume_agent_turn",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> Turn | AsyncIterator[AgentTurnResponseStreamChunk]:
|
||||
"""Resume an agent turn with executed tool call responses.
|
||||
|
||||
When a Turn has the status `awaiting_input` due to pending input from client side tool calls, this endpoint can be used to submit the outputs from the tool calls once they are ready.
|
||||
|
||||
:param agent_id: The ID of the agent to resume.
|
||||
:param session_id: The ID of the session to resume.
|
||||
:param turn_id: The ID of the turn to resume.
|
||||
:param tool_responses: The tool call responses to resume the turn with.
|
||||
:param stream: Whether to stream the response.
|
||||
:returns: A Turn object if stream is False, otherwise an AsyncIterator of AgentTurnResponseStreamChunk objects.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
) -> Turn:
|
||||
"""Retrieve an agent turn by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the turn for.
|
||||
:param session_id: The ID of the session to get the turn for.
|
||||
:param turn_id: The ID of the turn to get.
|
||||
:returns: A Turn.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}/step/{step_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_step(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
step_id: str,
|
||||
) -> AgentStepResponse:
|
||||
"""Retrieve an agent step by its ID.
|
||||
|
||||
:param agent_id: The ID of the agent to get the step for.
|
||||
:param session_id: The ID of the session to get the step for.
|
||||
:param turn_id: The ID of the turn to get the step for.
|
||||
:param step_id: The ID of the step to get.
|
||||
:returns: An AgentStepResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session",
|
||||
method="POST",
|
||||
descriptive_name="create_agent_session",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
"""Create a new session for an agent.
|
||||
|
||||
:param agent_id: The ID of the agent to create the session for.
|
||||
:param session_name: The name of the session to create.
|
||||
:returns: An AgentSessionCreateResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="GET",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
"""Retrieve an agent session by its ID.
|
||||
|
||||
:param session_id: The ID of the session to get.
|
||||
:param agent_id: The ID of the agent to get the session for.
|
||||
:param turn_ids: (Optional) List of turn IDs to filter the session by.
|
||||
:returns: A Session.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(
|
||||
route="/agents/{agent_id}/session/{session_id}",
|
||||
method="DELETE",
|
||||
level=LLAMA_STACK_API_V1ALPHA,
|
||||
)
|
||||
async def delete_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent session by its ID and its associated turns.
|
||||
|
||||
:param session_id: The ID of the session to delete.
|
||||
:param agent_id: The ID of the agent to delete the session for.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="DELETE", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def delete_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
) -> None:
|
||||
"""Delete an agent by its ID and its associated sessions and turns.
|
||||
|
||||
:param agent_id: The ID of the agent to delete.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
"""List all agents.
|
||||
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of agents to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
"""Describe an agent by its ID.
|
||||
|
||||
:param agent_id: ID of the agent.
|
||||
:returns: An Agent of the agent.
|
||||
"""
|
||||
...
|
||||
|
||||
@webmethod(route="/agents/{agent_id}/sessions", method="GET", level=LLAMA_STACK_API_V1ALPHA)
|
||||
async def list_agent_sessions(
|
||||
self,
|
||||
agent_id: str,
|
||||
start_index: int | None = None,
|
||||
limit: int | None = None,
|
||||
) -> PaginatedResponse:
|
||||
"""List all session(s) of a given agent.
|
||||
|
||||
:param agent_id: The ID of the agent to list sessions for.
|
||||
:param start_index: The index to start the pagination from.
|
||||
:param limit: The number of sessions to return.
|
||||
:returns: A PaginatedResponse.
|
||||
"""
|
||||
...
|
||||
|
||||
# We situate the OpenAI Responses API in the Agents API just like we did things
|
||||
# for Inference. The Responses API, in its intent, serves the same purpose as
|
||||
# the Agents API above -- it is essentially a lightweight "agentic loop" with
|
||||
|
|
|
|||
|
|
@ -56,14 +56,6 @@ class ToolGroupNotFoundError(ResourceNotFoundError):
|
|||
super().__init__(toolgroup_name, "Tool Group", "client.toolgroups.list()")
|
||||
|
||||
|
||||
class SessionNotFoundError(ValueError):
|
||||
"""raised when Llama Stack cannot find a referenced session or access is denied"""
|
||||
|
||||
def __init__(self, session_name: str) -> None:
|
||||
message = f"Session '{session_name}' not found or access denied."
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
class ModelTypeError(TypeError):
|
||||
"""raised when a model is present but not the correct type"""
|
||||
|
||||
|
|
|
|||
|
|
@ -103,17 +103,6 @@ class CompletionInputType(BaseModel):
|
|||
type: Literal["completion_input"] = "completion_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentTurnInputType(BaseModel):
|
||||
"""Parameter type for agent turn input.
|
||||
|
||||
:param type: Discriminator type. Always "agent_turn_input"
|
||||
"""
|
||||
|
||||
# expects List[Message] for messages (may also include attachments?)
|
||||
type: Literal["agent_turn_input"] = "agent_turn_input"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class DialogType(BaseModel):
|
||||
"""Parameter type for dialog data with semantic output labels.
|
||||
|
|
@ -135,8 +124,7 @@ ParamType = Annotated[
|
|||
| JsonType
|
||||
| UnionType
|
||||
| ChatCompletionInputType
|
||||
| CompletionInputType
|
||||
| AgentTurnInputType,
|
||||
| CompletionInputType,
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(ParamType, name="ParamType")
|
||||
|
|
|
|||
|
|
@ -4,17 +4,16 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Annotated, Any, Literal, Protocol
|
||||
from typing import Any, Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig
|
||||
from llama_stack.apis.common.job_types import Job
|
||||
from llama_stack.apis.inference import SamplingParams, SystemMessage
|
||||
from llama_stack.apis.scoring import ScoringResult
|
||||
from llama_stack.apis.scoring_functions import ScoringFnParams
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_V1ALPHA
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
from llama_stack.schema_utils import json_schema_type, webmethod
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
@ -32,19 +31,7 @@ class ModelCandidate(BaseModel):
|
|||
system_message: SystemMessage | None = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class AgentCandidate(BaseModel):
|
||||
"""An agent candidate for evaluation.
|
||||
|
||||
:param config: The configuration for the agent candidate.
|
||||
"""
|
||||
|
||||
type: Literal["agent"] = "agent"
|
||||
config: AgentConfig
|
||||
|
||||
|
||||
EvalCandidate = Annotated[ModelCandidate | AgentCandidate, Field(discriminator="type")]
|
||||
register_schema(EvalCandidate, name="EvalCandidate")
|
||||
EvalCandidate = ModelCandidate
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -4,21 +4,9 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import (
|
||||
Agent,
|
||||
AgentConfig,
|
||||
AgentCreateResponse,
|
||||
Agents,
|
||||
AgentSessionCreateResponse,
|
||||
AgentStepResponse,
|
||||
AgentToolGroup,
|
||||
AgentTurnCreateRequest,
|
||||
AgentTurnResumeRequest,
|
||||
Document,
|
||||
ListOpenAIResponseInputItem,
|
||||
ListOpenAIResponseObject,
|
||||
OpenAIDeleteResponseObject,
|
||||
|
|
@ -26,19 +14,12 @@ from llama_stack.apis.agents import (
|
|||
OpenAIResponseInputTool,
|
||||
OpenAIResponseObject,
|
||||
Order,
|
||||
Session,
|
||||
Turn,
|
||||
)
|
||||
from llama_stack.apis.agents.agents import ResponseGuardrail
|
||||
from llama_stack.apis.agents.openai_responses import OpenAIResponsePrompt, OpenAIResponseText
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.apis.conversations import Conversations
|
||||
from llama_stack.apis.inference import (
|
||||
Inference,
|
||||
ToolConfig,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.safety import Safety
|
||||
from llama_stack.apis.tools import ToolGroups, ToolRuntime
|
||||
|
|
@ -46,12 +27,9 @@ from llama_stack.apis.vector_io import VectorIO
|
|||
from llama_stack.core.datatypes import AccessRule
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl
|
||||
from llama_stack.providers.utils.pagination import paginate_records
|
||||
from llama_stack.providers.utils.responses.responses_store import ResponsesStore
|
||||
|
||||
from .agent_instance import ChatAgent
|
||||
from .config import MetaReferenceAgentsImplConfig
|
||||
from .persistence import AgentInfo
|
||||
from .responses.openai_responses import OpenAIResponsesImpl
|
||||
|
||||
logger = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
|
@ -97,229 +75,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
conversations_api=self.conversations_api,
|
||||
)
|
||||
|
||||
async def create_agent(
|
||||
self,
|
||||
agent_config: AgentConfig,
|
||||
) -> AgentCreateResponse:
|
||||
agent_id = str(uuid.uuid4())
|
||||
created_at = datetime.now(UTC)
|
||||
|
||||
agent_info = AgentInfo(
|
||||
**agent_config.model_dump(),
|
||||
created_at=created_at,
|
||||
)
|
||||
|
||||
# Store the agent info
|
||||
await self.persistence_store.set(
|
||||
key=f"agent:{agent_id}",
|
||||
value=agent_info.model_dump_json(),
|
||||
)
|
||||
|
||||
return AgentCreateResponse(
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
async def _get_agent_impl(self, agent_id: str) -> ChatAgent:
|
||||
agent_info_json = await self.persistence_store.get(
|
||||
key=f"agent:{agent_id}",
|
||||
)
|
||||
if not agent_info_json:
|
||||
raise ValueError(f"Could not find agent info for {agent_id}")
|
||||
|
||||
try:
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Could not validate agent info for {agent_id}") from e
|
||||
|
||||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_info,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
tool_runtime_api=self.tool_runtime_api,
|
||||
tool_groups_api=self.tool_groups_api,
|
||||
persistence_store=(
|
||||
self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
|
||||
),
|
||||
created_at=agent_info.created_at.isoformat(),
|
||||
policy=self.policy,
|
||||
telemetry_enabled=self.telemetry_enabled,
|
||||
)
|
||||
|
||||
async def create_agent_session(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_name: str,
|
||||
) -> AgentSessionCreateResponse:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_id = await agent.create_session(session_name)
|
||||
return AgentSessionCreateResponse(
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def create_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
messages: list[UserMessage | ToolResponseMessage],
|
||||
stream: bool | None = False,
|
||||
documents: list[Document] | None = None,
|
||||
toolgroups: list[AgentToolGroup] | None = None,
|
||||
tool_config: ToolConfig | None = None,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnCreateRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=messages,
|
||||
stream=True,
|
||||
toolgroups=toolgroups,
|
||||
documents=documents,
|
||||
tool_config=tool_config,
|
||||
)
|
||||
if stream:
|
||||
return self._create_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _create_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnCreateRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.create_and_execute_turn(request):
|
||||
yield event
|
||||
|
||||
async def resume_agent_turn(
|
||||
self,
|
||||
agent_id: str,
|
||||
session_id: str,
|
||||
turn_id: str,
|
||||
tool_responses: list[ToolResponse],
|
||||
stream: bool | None = False,
|
||||
) -> AsyncGenerator:
|
||||
request = AgentTurnResumeRequest(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
turn_id=turn_id,
|
||||
tool_responses=tool_responses,
|
||||
stream=stream,
|
||||
)
|
||||
if stream:
|
||||
return self._continue_agent_turn_streaming(request)
|
||||
else:
|
||||
raise NotImplementedError("Non-streaming agent turns not yet implemented")
|
||||
|
||||
async def _continue_agent_turn_streaming(
|
||||
self,
|
||||
request: AgentTurnResumeRequest,
|
||||
) -> AsyncGenerator:
|
||||
agent = await self._get_agent_impl(request.agent_id)
|
||||
async for event in agent.resume_turn(request):
|
||||
yield event
|
||||
|
||||
async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
turn = await agent.storage.get_session_turn(session_id, turn_id)
|
||||
if turn is None:
|
||||
raise ValueError(f"Turn {turn_id} not found in session {session_id}")
|
||||
return turn
|
||||
|
||||
async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
|
||||
turn = await self.get_agents_turn(agent_id, session_id, turn_id)
|
||||
for step in turn.steps:
|
||||
if step.step_id == step_id:
|
||||
return AgentStepResponse(step=step)
|
||||
raise ValueError(f"Provided step_id {step_id} could not be found")
|
||||
|
||||
async def get_agents_session(
|
||||
self,
|
||||
session_id: str,
|
||||
agent_id: str,
|
||||
turn_ids: list[str] | None = None,
|
||||
) -> Session:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
session_info = await agent.storage.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise ValueError(f"Session {session_id} not found")
|
||||
turns = await agent.storage.get_session_turns(session_id)
|
||||
if turn_ids:
|
||||
turns = [turn for turn in turns if turn.turn_id in turn_ids]
|
||||
return Session(
|
||||
session_name=session_info.session_name,
|
||||
session_id=session_id,
|
||||
turns=turns,
|
||||
started_at=session_info.started_at,
|
||||
)
|
||||
|
||||
async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
|
||||
# Delete turns first, then the session
|
||||
await agent.storage.delete_session_turns(session_id)
|
||||
await agent.storage.delete_session(session_id)
|
||||
|
||||
async def delete_agent(self, agent_id: str) -> None:
|
||||
# First get all sessions for this agent
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
|
||||
# Delete all sessions
|
||||
for session in sessions:
|
||||
await self.delete_agents_session(agent_id, session.session_id)
|
||||
|
||||
# Finally delete the agent itself
|
||||
await self.persistence_store.delete(f"agent:{agent_id}")
|
||||
|
||||
async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse:
|
||||
agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff")
|
||||
agent_list: list[Agent] = []
|
||||
for agent_key in agent_keys:
|
||||
agent_id = agent_key.split(":")[1]
|
||||
|
||||
# Get the agent info using the key
|
||||
agent_info_json = await self.persistence_store.get(agent_key)
|
||||
if not agent_info_json:
|
||||
logger.error(f"Could not find agent info for key {agent_key}")
|
||||
continue
|
||||
|
||||
try:
|
||||
agent_info = AgentInfo.model_validate_json(agent_info_json)
|
||||
agent_list.append(
|
||||
Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_info,
|
||||
created_at=agent_info.created_at,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error parsing agent info for {agent_id}: {e}")
|
||||
continue
|
||||
|
||||
# Convert Agent objects to dictionaries
|
||||
agent_dicts = [agent.model_dump() for agent in agent_list]
|
||||
return paginate_records(agent_dicts, start_index, limit)
|
||||
|
||||
async def get_agent(self, agent_id: str) -> Agent:
|
||||
chat_agent = await self._get_agent_impl(agent_id)
|
||||
agent = Agent(
|
||||
agent_id=agent_id,
|
||||
agent_config=chat_agent.agent_config,
|
||||
created_at=datetime.fromisoformat(chat_agent.created_at),
|
||||
)
|
||||
return agent
|
||||
|
||||
async def list_agent_sessions(
|
||||
self, agent_id: str, start_index: int | None = None, limit: int | None = None
|
||||
) -> PaginatedResponse:
|
||||
agent = await self._get_agent_impl(agent_id)
|
||||
sessions = await agent.storage.list_sessions()
|
||||
# Convert Session objects to dictionaries
|
||||
session_dicts = [session.model_dump() for session in sessions]
|
||||
return paginate_records(session_dicts, start_index, limit)
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
pass
|
||||
|
||||
|
|
|
|||
|
|
@ -1,261 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
|
||||
from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
|
||||
from llama_stack.apis.common.errors import SessionNotFoundError
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
|
||||
from llama_stack.core.access_control.conditions import User as ProtocolUser
|
||||
from llama_stack.core.access_control.datatypes import AccessRule, Action
|
||||
from llama_stack.core.datatypes import User
|
||||
from llama_stack.core.request_headers import get_authenticated_user
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
|
||||
log = get_logger(name=__name__, category="agents::meta_reference")
|
||||
|
||||
|
||||
class AgentSessionInfo(Session):
|
||||
# TODO: is this used anywhere?
|
||||
vector_store_id: str | None = None
|
||||
started_at: datetime
|
||||
owner: User | None = None
|
||||
identifier: str | None = None
|
||||
type: str = "session"
|
||||
|
||||
|
||||
class AgentInfo(AgentConfig):
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@dataclass
|
||||
class SessionResource:
|
||||
"""Concrete implementation of ProtectedResource for session access control."""
|
||||
|
||||
type: str
|
||||
identifier: str
|
||||
owner: ProtocolUser # Use the protocol type for structural compatibility
|
||||
|
||||
|
||||
class AgentPersistence:
|
||||
def __init__(self, agent_id: str, kvstore: KVStore, policy: list[AccessRule]):
|
||||
self.agent_id = agent_id
|
||||
self.kvstore = kvstore
|
||||
self.policy = policy
|
||||
|
||||
async def create_session(self, name: str) -> str:
|
||||
session_id = str(uuid.uuid4())
|
||||
|
||||
# Get current user's auth attributes for new sessions
|
||||
user = get_authenticated_user()
|
||||
|
||||
session_info = AgentSessionInfo(
|
||||
session_id=session_id,
|
||||
session_name=name,
|
||||
started_at=datetime.now(UTC),
|
||||
owner=user,
|
||||
turns=[],
|
||||
identifier=name, # should this be qualified in any way?
|
||||
)
|
||||
# Only perform access control if we have an authenticated user
|
||||
if user is not None and session_info.identifier is not None:
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=user,
|
||||
)
|
||||
if not is_action_allowed(self.policy, Action.CREATE, resource, user):
|
||||
raise AccessDeniedError(Action.CREATE, resource, user)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
return session_id
|
||||
|
||||
async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
)
|
||||
if not value:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
session_info = AgentSessionInfo(**json.loads(value))
|
||||
|
||||
# Check access to session
|
||||
if not self._check_session_access(session_info):
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
def _check_session_access(self, session_info: AgentSessionInfo) -> bool:
|
||||
"""Check if current user has access to the session."""
|
||||
# Handle backward compatibility for old sessions without access control
|
||||
if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
|
||||
return True
|
||||
|
||||
# Get current user - if None, skip access control (e.g., in tests)
|
||||
user = get_authenticated_user()
|
||||
if user is None:
|
||||
return True
|
||||
|
||||
# Access control requires identifier and owner to be set
|
||||
if session_info.identifier is None or session_info.owner is None:
|
||||
return True
|
||||
|
||||
# At this point, both identifier and owner are guaranteed to be non-None
|
||||
resource = SessionResource(
|
||||
type=session_info.type,
|
||||
identifier=session_info.identifier,
|
||||
owner=session_info.owner,
|
||||
)
|
||||
return is_action_allowed(self.policy, Action.READ, resource, user)
|
||||
|
||||
async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
|
||||
"""Get session info if the user has access to it. For internal use by sub-session methods."""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if not session_info:
|
||||
return None
|
||||
|
||||
return session_info
|
||||
|
||||
async def add_vector_db_to_session(self, session_id: str, vector_store_id: str):
|
||||
session_info = await self.get_session_if_accessible(session_id)
|
||||
if session_info is None:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
session_info.vector_store_id = vector_store_id
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}",
|
||||
value=session_info.model_dump_json(),
|
||||
)
|
||||
|
||||
async def add_turn_to_session(self, session_id: str, turn: Turn):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
|
||||
value=turn.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_session_turns(self, session_id: str) -> list[Turn]:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:{session_id}:",
|
||||
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
turns = []
|
||||
for value in values:
|
||||
try:
|
||||
turn = Turn(**json.loads(value))
|
||||
turns.append(turn)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing turn: {e}")
|
||||
continue
|
||||
|
||||
# The kvstore does not guarantee order, so we sort by started_at
|
||||
# to ensure consistent ordering of turns.
|
||||
turns.sort(key=lambda t: t.started_at)
|
||||
|
||||
return turns
|
||||
|
||||
async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"session:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
if not value:
|
||||
return None
|
||||
return Turn(**json.loads(value))
|
||||
|
||||
async def set_in_progress_tool_call_step(self, session_id: str, turn_id: str, step: ToolExecutionStep):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=step.model_dump_json(),
|
||||
)
|
||||
|
||||
async def get_in_progress_tool_call_step(self, session_id: str, turn_id: str) -> ToolExecutionStep | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"in_progress_tool_call_step:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return ToolExecutionStep(**json.loads(value)) if value else None
|
||||
|
||||
async def set_num_infer_iters_in_turn(self, session_id: str, turn_id: str, num_infer_iters: int):
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.set(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
value=str(num_infer_iters),
|
||||
)
|
||||
|
||||
async def get_num_infer_iters_in_turn(self, session_id: str, turn_id: str) -> int | None:
|
||||
if not await self.get_session_if_accessible(session_id):
|
||||
return None
|
||||
|
||||
value = await self.kvstore.get(
|
||||
key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}",
|
||||
)
|
||||
return int(value) if value else None
|
||||
|
||||
async def list_sessions(self) -> list[Session]:
|
||||
values = await self.kvstore.values_in_range(
|
||||
start_key=f"session:{self.agent_id}:",
|
||||
end_key=f"session:{self.agent_id}:\xff\xff\xff\xff",
|
||||
)
|
||||
sessions = []
|
||||
for value in values:
|
||||
try:
|
||||
data = json.loads(value)
|
||||
if "turn_id" in data:
|
||||
continue
|
||||
|
||||
session_info = Session(**data)
|
||||
sessions.append(session_info)
|
||||
except Exception as e:
|
||||
log.error(f"Error parsing session info: {e}")
|
||||
continue
|
||||
return sessions
|
||||
|
||||
async def delete_session_turns(self, session_id: str) -> None:
|
||||
"""Delete all turns and their associated data for a session.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session whose turns should be deleted.
|
||||
"""
|
||||
turns = await self.get_session_turns(session_id)
|
||||
for turn in turns:
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}")
|
||||
|
||||
async def delete_session(self, session_id: str) -> None:
|
||||
"""Delete a session and all its associated turns.
|
||||
|
||||
Args:
|
||||
session_id: The ID of the session to delete.
|
||||
|
||||
Raises:
|
||||
ValueError: If the session does not exist.
|
||||
"""
|
||||
session_info = await self.get_session_info(session_id)
|
||||
if session_info is None:
|
||||
raise SessionNotFoundError(session_id)
|
||||
|
||||
await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}")
|
||||
|
|
@ -8,7 +8,7 @@ from typing import Any
|
|||
|
||||
from tqdm import tqdm
|
||||
|
||||
from llama_stack.apis.agents import Agents, StepType
|
||||
from llama_stack.apis.agents import Agents
|
||||
from llama_stack.apis.benchmarks import Benchmark
|
||||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
|
|
@ -18,13 +18,9 @@ from llama_stack.apis.inference import (
|
|||
OpenAICompletionRequestWithExtraBody,
|
||||
OpenAISystemMessageParam,
|
||||
OpenAIUserMessageParam,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.apis.scoring import Scoring
|
||||
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
|
||||
MEMORY_QUERY_TOOL,
|
||||
)
|
||||
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
|
||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||
|
||||
|
|
@ -118,49 +114,6 @@ class MetaReferenceEvalImpl(
|
|||
self.jobs[job_id] = res
|
||||
return Job(job_id=job_id, status=JobStatus.completed)
|
||||
|
||||
async def _run_agent_generation(
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
create_response = await self.agents_api.create_agent(candidate.config)
|
||||
agent_id = create_response.agent_id
|
||||
|
||||
generations = []
|
||||
for i, x in tqdm(enumerate(input_rows)):
|
||||
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
|
||||
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
|
||||
input_messages = [UserMessage(**x) for x in input_messages if x["role"] == "user"]
|
||||
|
||||
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
|
||||
session_create_response = await self.agents_api.create_agent_session(agent_id, f"session-{i}")
|
||||
session_id = session_create_response.session_id
|
||||
|
||||
turn_request = dict(
|
||||
agent_id=agent_id,
|
||||
session_id=session_id,
|
||||
messages=input_messages,
|
||||
stream=True,
|
||||
)
|
||||
turn_response = [chunk async for chunk in await self.agents_api.create_agent_turn(**turn_request)]
|
||||
final_event = turn_response[-1].event.payload
|
||||
|
||||
# check if there's a memory retrieval step and extract the context
|
||||
memory_rag_context = None
|
||||
for step in final_event.turn.steps:
|
||||
if step.step_type == StepType.tool_execution.value:
|
||||
for tool_response in step.tool_responses:
|
||||
if tool_response.tool_name == MEMORY_QUERY_TOOL:
|
||||
memory_rag_context = " ".join(x.text for x in tool_response.content)
|
||||
|
||||
agent_generation = {}
|
||||
agent_generation[ColumnName.generated_answer.value] = final_event.turn.output_message.content
|
||||
if memory_rag_context:
|
||||
agent_generation[ColumnName.context.value] = memory_rag_context
|
||||
|
||||
generations.append(agent_generation)
|
||||
|
||||
return generations
|
||||
|
||||
async def _run_model_generation(
|
||||
self, input_rows: list[dict[str, Any]], benchmark_config: BenchmarkConfig
|
||||
) -> list[dict[str, Any]]:
|
||||
|
|
@ -215,9 +168,8 @@ class MetaReferenceEvalImpl(
|
|||
benchmark_config: BenchmarkConfig,
|
||||
) -> EvaluateResponse:
|
||||
candidate = benchmark_config.eval_candidate
|
||||
if candidate.type == "agent":
|
||||
generations = await self._run_agent_generation(input_rows, benchmark_config)
|
||||
elif candidate.type == "model":
|
||||
# Agent evaluation removed
|
||||
if candidate.type == "model":
|
||||
generations = await self._run_model_generation(input_rows, benchmark_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid candidate type: {candidate.type}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue