diff --git a/pyproject.toml b/pyproject.toml
index 999c3d9a3..e99299dab 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -284,12 +284,6 @@ exclude = [
     "^src/llama_stack/models/llama/llama3/interface\\.py$",
     "^src/llama_stack/models/llama/llama3/tokenizer\\.py$",
     "^src/llama_stack/models/llama/llama3/tool_utils\\.py$",
-    "^src/llama_stack/providers/inline/agents/meta_reference/agents\\.py$",
-    "^src/llama_stack/providers/inline/agents/meta_reference/agent_instance\\.py$",
-    "^src/llama_stack/providers/inline/agents/meta_reference/config\\.py$",
-    "^src/llama_stack/providers/inline/agents/meta_reference/persistence\\.py$",
-    "^src/llama_stack/providers/inline/agents/meta_reference/safety\\.py$",
-    "^src/llama_stack/providers/inline/agents/meta_reference/__init__\\.py$",
     "^src/llama_stack/providers/inline/datasetio/localfs/",
     "^src/llama_stack/providers/inline/eval/meta_reference/eval\\.py$",
     "^src/llama_stack/providers/inline/inference/meta_reference/inference\\.py$",
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
index 55bf31f57..b6fad553a 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/agent_instance.py
@@ -11,7 +11,7 @@
 import uuid
 import warnings
 from collections.abc import AsyncGenerator
 from datetime import UTC, datetime
-from typing import Any
+from typing import Any, cast

 import httpx
@@ -363,7 +363,7 @@ class ChatAgent(ShieldRunnerMixin):

         if self.input_shields:
             async for res in self.run_multiple_shields_wrapper(
-                turn_id, input_messages, self.input_shields, "user-input"
+                turn_id, cast(list[OpenAIMessageParam], input_messages), self.input_shields, "user-input"
             ):
                 if isinstance(res, bool):
                     return
@@ -392,7 +392,7 @@ class ChatAgent(ShieldRunnerMixin):

         if self.output_shields:
             async for res in self.run_multiple_shields_wrapper(
-                turn_id, messages, self.output_shields, "assistant-output"
+                turn_id, cast(list[OpenAIMessageParam], messages), self.output_shields, "assistant-output"
             ):
                 if isinstance(res, bool):
                     return
@@ -404,7 +404,7 @@ class ChatAgent(ShieldRunnerMixin):
     async def run_multiple_shields_wrapper(
         self,
         turn_id: str,
-        messages: list[Message],
+        messages: list[OpenAIMessageParam],
         shields: list[str],
         touchpoint: str,
     ) -> AsyncGenerator:
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/agents.py b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
index b4b77bacd..cd6158e50 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/agents.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/agents.py
@@ -21,6 +21,7 @@ from llama_stack.apis.agents import (
     Document,
     ListOpenAIResponseInputItem,
     ListOpenAIResponseObject,
+    OpenAIDeleteResponseObject,
     OpenAIResponseInput,
     OpenAIResponseInputTool,
     OpenAIResponseObject,
@@ -141,7 +142,7 @@ class MetaReferenceAgentsImpl(Agents):
             persistence_store=(
                 self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store
             ),
-            created_at=agent_info.created_at,
+            created_at=agent_info.created_at.isoformat(),
             policy=self.policy,
             telemetry_enabled=self.telemetry_enabled,
         )
@@ -163,9 +164,9 @@ class MetaReferenceAgentsImpl(Agents):
         agent_id: str,
         session_id: str,
         messages: list[UserMessage | ToolResponseMessage],
-        toolgroups: list[AgentToolGroup] | None = None,
-        documents: list[Document] | None = None,
         stream: bool | None = False,
+        documents: list[Document] | None = None,
+        toolgroups: list[AgentToolGroup] | None = None,
         tool_config: ToolConfig | None = None,
     ) -> AsyncGenerator:
         request = AgentTurnCreateRequest(
@@ -221,6 +222,8 @@ class MetaReferenceAgentsImpl(Agents):
     async def get_agents_turn(self, agent_id: str, session_id: str, turn_id: str) -> Turn:
         agent = await self._get_agent_impl(agent_id)
         turn = await agent.storage.get_session_turn(session_id, turn_id)
+        if turn is None:
+            raise ValueError(f"Turn {turn_id} not found in session {session_id}")
         return turn

     async def get_agents_step(self, agent_id: str, session_id: str, turn_id: str, step_id: str) -> AgentStepResponse:
@@ -232,13 +235,15 @@ class MetaReferenceAgentsImpl(Agents):

     async def get_agents_session(
         self,
-        agent_id: str,
         session_id: str,
+        agent_id: str,
         turn_ids: list[str] | None = None,
     ) -> Session:
         agent = await self._get_agent_impl(agent_id)
         session_info = await agent.storage.get_session_info(session_id)
+        if session_info is None:
+            raise ValueError(f"Session {session_id} not found")
         turns = await agent.storage.get_session_turns(session_id)
         if turn_ids:
             turns = [turn for turn in turns if turn.turn_id in turn_ids]
@@ -249,7 +254,7 @@
             started_at=session_info.started_at,
         )

-    async def delete_agents_session(self, agent_id: str, session_id: str) -> None:
+    async def delete_agents_session(self, session_id: str, agent_id: str) -> None:
         agent = await self._get_agent_impl(agent_id)

         # Delete turns first, then the session
@@ -302,7 +307,7 @@ class MetaReferenceAgentsImpl(Agents):
         agent = Agent(
             agent_id=agent_id,
             agent_config=chat_agent.agent_config,
-            created_at=chat_agent.created_at,
+            created_at=datetime.fromisoformat(chat_agent.created_at),
         )
         return agent
@@ -323,6 +328,7 @@ class MetaReferenceAgentsImpl(Agents):
         self,
         response_id: str,
     ) -> OpenAIResponseObject:
+        assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
         return await self.openai_responses_impl.get_openai_response(response_id)

     async def create_openai_response(
         self,
         input: str | list[OpenAIResponseInput],
         model: str,
         prompt: OpenAIResponsePromptParam | None = None,
         instructions: str | None = None,
         previous_response_id: str | None = None,
         conversation: str | None = None,
         store: bool | None = True,
         stream: bool | None = False,
         temperature: float | None = None,
         text: OpenAIResponseText | None = None,
         tools: list[OpenAIResponseInputTool] | None = None,
         include: list[str] | None = None,
@@ -342,7 +348,12 @@
         max_infer_iters: int | None = 10,
         guardrails: list[ResponseGuardrail] | None = None,
     ) -> OpenAIResponseObject:
-        return await self.openai_responses_impl.create_openai_response(
+        assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
+        from llama_stack.apis.agents.agents import ResponseGuardrailSpec
+        from typing import cast as typing_cast
+        # Cast guardrails to the more specific type expected by the implementation
+        guardrails_spec = typing_cast(list[ResponseGuardrailSpec] | None, guardrails)
+        result = await self.openai_responses_impl.create_openai_response(
             input,
             model,
             prompt,
@@ -356,8 +367,9 @@
             tools,
             include,
             max_infer_iters,
-            guardrails,
+            guardrails_spec,
         )
+        return typing_cast(OpenAIResponseObject, result)

     async def list_openai_responses(
         self,
         after: str | None = None,
         limit: int | None = 50,
@@ -366,6 +378,7 @@
         model: str | None = None,
         order: Order | None = Order.desc,
     ) -> ListOpenAIResponseObject:
+        assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
         return await self.openai_responses_impl.list_openai_responses(after, limit, model, order)

     async def list_openai_response_input_items(
         self,
         response_id: str,
         after: str | None = None,
         before: str | None = None,
         include: list[str] | None = None,
@@ -377,9 +390,11 @@
         limit: int | None = 20,
         order: Order | None = Order.desc,
     ) -> ListOpenAIResponseInputItem:
+        assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
         return await self.openai_responses_impl.list_openai_response_input_items(
             response_id, after, before, include, limit, order
         )

-    async def delete_openai_response(self, response_id: str) -> None:
+    async def delete_openai_response(self, response_id: str) -> OpenAIDeleteResponseObject:
+        assert self.openai_responses_impl is not None, "OpenAI responses not initialized"
         return await self.openai_responses_impl.delete_openai_response(response_id)
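
Note on the created_at changes above: ChatAgent now keeps created_at as an ISO 8601 string, so agents.py serializes with datetime.isoformat() when constructing the agent and rebuilds a datetime with datetime.fromisoformat() when returning it to API callers. A minimal round-trip sketch of that conversion (standalone code, not part of llama_stack):

    from datetime import UTC, datetime

    # Serialize an aware datetime to an ISO 8601 string for storage.
    created_at = datetime.now(UTC)
    stored = created_at.isoformat()  # e.g. "2025-01-01T12:00:00+00:00"

    # Rebuild the datetime when handing the agent back to API callers.
    restored = datetime.fromisoformat(stored)
    assert restored == created_at and restored.tzinfo is not None

The round trip is lossless for timezone-aware datetimes, which is what makes a plain string a safe storage representation here.
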
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py
index 26a2151e3..3a2affbd8 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/persistence.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/persistence.py
@@ -7,11 +7,12 @@
 import json
 import uuid
 from datetime import UTC, datetime
+from typing import cast

 from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn
 from llama_stack.apis.common.errors import SessionNotFoundError
 from llama_stack.core.access_control.access_control import AccessDeniedError, is_action_allowed
-from llama_stack.core.access_control.datatypes import AccessRule
+from llama_stack.core.access_control.datatypes import AccessRule, Action
 from llama_stack.core.datatypes import User
 from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.log import get_logger
@@ -53,8 +54,12 @@
             turns=[],
             identifier=name,  # should this be qualified in any way?
         )
-        if not is_action_allowed(self.policy, "create", session_info, user):
-            raise AccessDeniedError("create", session_info, user)
+        # Both identifier and owner are set above, safe to use for access control
+        assert session_info.identifier is not None and session_info.owner is not None
+        from llama_stack.core.access_control.conditions import ProtectedResource
+        resource = cast(ProtectedResource, session_info)
+        if not is_action_allowed(self.policy, Action.CREATE, resource, user):
+            raise AccessDeniedError(Action.CREATE, resource, user)

         await self.kvstore.set(
             key=f"session:{self.agent_id}:{session_id}",
@@ -62,7 +67,7 @@
         )
         return session_id

-    async def get_session_info(self, session_id: str) -> AgentSessionInfo:
+    async def get_session_info(self, session_id: str) -> AgentSessionInfo | None:
         value = await self.kvstore.get(
             key=f"session:{self.agent_id}:{session_id}",
         )
@@ -83,7 +88,15 @@
         if not hasattr(session_info, "access_attributes") and not hasattr(session_info, "owner"):
             return True

-        return is_action_allowed(self.policy, "read", session_info, get_authenticated_user())
+        # Access control requires identifier and owner to be set
+        if session_info.identifier is None or session_info.owner is None:
+            return True
+
+        # At this point, both identifier and owner are guaranteed to be non-None
+        assert session_info.identifier is not None and session_info.owner is not None
+        from llama_stack.core.access_control.conditions import ProtectedResource
+        resource = cast(ProtectedResource, session_info)
+        return is_action_allowed(self.policy, Action.READ, resource, get_authenticated_user())

     async def get_session_if_accessible(self, session_id: str) -> AgentSessionInfo | None:
         """Get session info if the user has access to it. For internal use by sub-session methods."""
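
Note on the guard-then-cast pattern above: typing.cast performs no runtime conversion — it only tells the type checker to treat session_info as a ProtectedResource, whose identifier and owner fields are non-optional. That is why the None checks and asserts come first: the guard establishes the runtime invariant, the cast communicates it to mypy. A self-contained sketch of the idea (the classes below are simplified stand-ins, not the real llama_stack types):

    from typing import Protocol, cast


    class User:
        def __init__(self, principal: str) -> None:
            self.principal = principal


    class ProtectedResource(Protocol):
        # Non-optional fields: anything passed as a ProtectedResource
        # must have these set.
        type: str
        identifier: str
        owner: User


    class SessionInfo:
        # Optional fields on the stored model are what make mypy reject
        # this type where a ProtectedResource is expected.
        def __init__(self, identifier: str | None, owner: User | None) -> None:
            self.type = "session"
            self.identifier = identifier
            self.owner = owner


    def is_action_allowed(resource: ProtectedResource) -> bool:
        return resource.owner.principal == "alice"


    session = SessionInfo(identifier="sess-1", owner=User("alice"))
    # Guard first: the assert proves the optional fields are set at runtime.
    assert session.identifier is not None and session.owner is not None
    # Then narrow: cast() is a no-op at runtime but satisfies the checker.
    print(is_action_allowed(cast(ProtectedResource, session)))
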
diff --git a/src/llama_stack/providers/inline/agents/meta_reference/safety.py b/src/llama_stack/providers/inline/agents/meta_reference/safety.py
index 9baf5a14d..f0ae51423 100644
--- a/src/llama_stack/providers/inline/agents/meta_reference/safety.py
+++ b/src/llama_stack/providers/inline/agents/meta_reference/safety.py
@@ -6,7 +6,7 @@

 import asyncio

-from llama_stack.apis.inference import Message
+from llama_stack.apis.inference import OpenAIMessageParam
 from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
 from llama_stack.core.telemetry import tracing
 from llama_stack.log import get_logger
@@ -31,7 +31,7 @@
         self.input_shields = input_shields
         self.output_shields = output_shields

-    async def run_multiple_shields(self, messages: list[Message], identifiers: list[str]) -> None:
+    async def run_multiple_shields(self, messages: list[OpenAIMessageParam], identifiers: list[str]) -> None:
         async def run_shield_with_span(identifier: str):
             async with tracing.span(f"run_shield_{identifier}"):
                 return await self.safety_api.run_shield(
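
The visible shape of run_multiple_shields — an inner run_shield_with_span coroutine wrapping each safety_api.run_shield call in a tracing span — suggests the shields fan out concurrently. A hedged, self-contained sketch of that fan-out pattern; span_ctx and run_shield here are illustrative stand-ins, and the asyncio.gather usage is an assumption rather than the confirmed llama_stack implementation:

    import asyncio
    from contextlib import asynccontextmanager


    @asynccontextmanager
    async def span_ctx(name: str):
        # Stand-in for a tracing span: real code would record timing here.
        print(f"enter {name}")
        try:
            yield
        finally:
            print(f"exit {name}")


    async def run_shield(identifier: str, messages: list[str]) -> str:
        await asyncio.sleep(0.01)  # pretend to call the safety provider
        return f"{identifier}: ok ({len(messages)} messages)"


    async def run_multiple_shields(messages: list[str], identifiers: list[str]) -> None:
        async def run_shield_with_span(identifier: str) -> str:
            async with span_ctx(f"run_shield_{identifier}"):
                return await run_shield(identifier, messages)

        # Fan out: one task per shield, awaited together so one slow shield
        # does not serialize the others.
        results = await asyncio.gather(*(run_shield_with_span(i) for i in identifiers))
        for result in results:
            print(result)


    asyncio.run(run_multiple_shields(["hi"], ["llama_guard", "prompt_guard"]))
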