From c91e3552a3d2644a9068250594d88764b1cbeaf9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 7 May 2025 14:49:23 +0200 Subject: [PATCH] feat: implementation for agent/session list and describe (#1606) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Create a new agent: ``` curl --request POST \ --url http://localhost:8321/v1/agents \ --header 'Accept: application/json' \ --header 'Content-Type: application/json' \ --data '{ "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } } }' ``` Get agent: ``` curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f {"agent_id":"9abad4ab-2c77-45f9-9d16-46b79d2bea1f","agent_config":{"sampling_params":{"strategy":{"type":"greedy"},"max_tokens":0,"repetition_penalty":1.0},"input_shields":["string"],"output_shields":["string"],"toolgroups":["string"],"client_tools":[{"name":"string","description":"string","parameters":[{"name":"string","parameter_type":"string","description":"string","required":true,"default":null}],"metadata":{"property1":null,"property2":null}}],"tool_choice":"auto","tool_prompt_format":"json","tool_config":{"tool_choice":"auto","tool_prompt_format":"json","system_message_behavior":"append"},"max_infer_iters":10,"model":"string","instructions":"string","enable_session_persistence":false,"response_format":{"type":"json_schema","json_schema":{"property1":null,"property2":null}}},"created_at":"2025-03-12T16:18:28.369144Z"}% ``` List agents: ``` curl http://127.0.0.1:8321/v1/agents|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 1680 100 1680 0 0 498k 0 --:--:-- --:--:-- --:--:-- 546k { "data": [ { "agent_id": "9abad4ab-2c77-45f9-9d16-46b79d2bea1f", "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } }, "created_at": "2025-03-12T16:18:28.369144Z" }, { "agent_id": "a6643aaa-96dd-46db-a405-333dc504b168", "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } }, "created_at": "2025-03-12T16:17:12.811273Z" } ] } ``` Create sessions: ``` curl --request POST \ --url http://localhost:8321/v1/agents/{agent_id}/session \ --header 'Accept: application/json' \ --header 'Content-Type: application/json' \ --data '{ "session_name": "string" }' ``` List sessions: ``` curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f/sessions|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 263 100 263 0 0 90099 0 --:--:-- --:--:-- --:--:-- 128k [ { "session_id": "2b15c4fc-e348-46c1-ae32-f6d424441ac1", "session_name": "string", "turns": [], "started_at": "2025-03-12T17:19:17.784328" }, { "session_id": "9432472d-d483-4b73-b682-7b1d35d64111", "session_name": "string", "turns": [], "started_at": "2025-03-12T17:19:19.885834" } ] ``` Signed-off-by: Sébastien Han --- docs/_static/llama-stack-spec.html | 83 ++++--- docs/_static/llama-stack-spec.yaml | 62 ++--- llama_stack/apis/agents/agents.py | 29 +-- llama_stack/distribution/store/registry.py | 4 +- .../agents/meta_reference/agent_instance.py | 2 + .../inline/agents/meta_reference/agents.py | 108 ++++++-- .../agents/meta_reference/persistence.py | 56 ++++- .../inline/datasetio/localfs/datasetio.py | 2 +- .../inline/eval/meta_reference/eval.py | 2 +- .../providers/inline/vector_io/faiss/faiss.py | 2 +- .../datasetio/huggingface/huggingface.py | 2 +- llama_stack/providers/utils/kvstore/api.py | 4 +- .../providers/utils/kvstore/kvstore.py | 9 +- .../utils/kvstore/mongodb/mongodb.py | 9 +- .../utils/kvstore/postgres/postgres.py | 12 +- .../providers/utils/kvstore/redis/redis.py | 9 +- .../providers/utils/kvstore/sqlite/sqlite.py | 12 +- .../agent/test_meta_reference_agent.py | 230 ++++++++++++++++++ .../agents/test_persistence_access_control.py | 3 + 19 files changed, 510 insertions(+), 130 deletions(-) create mode 100644 tests/unit/providers/agent/test_meta_reference_agent.py diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index 84d4cf646..163d05087 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -307,11 +307,11 @@ "get": { "responses": { "200": { - "description": "A ListAgentsResponse.", + "description": "A PaginatedResponse.", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ListAgentsResponse" + "$ref": "#/components/schemas/PaginatedResponse" } } } @@ -333,7 +333,26 @@ "Agents" ], "description": "List all agents.", - "parameters": [] + "parameters": [ + { + "name": "start_index", + "in": "query", + "description": "The index to start the pagination from.", + "required": false, + "schema": { + "type": "integer" + } + }, + { + "name": "limit", + "in": "query", + "description": "The number of agents to return.", + "required": false, + "schema": { + "type": "integer" + } + } + ] }, "post": { "responses": { @@ -691,7 +710,7 @@ "tags": [ "Agents" ], - "description": "Delete an agent by its ID.", + "description": "Delete an agent by its ID and its associated sessions and turns.", "parameters": [ { "name": "agent_id", @@ -789,7 +808,7 @@ "tags": [ "Agents" ], - "description": "Delete an agent session by its ID.", + "description": "Delete an agent session by its ID and its associated turns.", "parameters": [ { "name": "session_id", @@ -2410,11 +2429,11 @@ "get": { "responses": { "200": { - "description": "A ListAgentSessionsResponse.", + "description": "A PaginatedResponse.", "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/ListAgentSessionsResponse" + "$ref": "#/components/schemas/PaginatedResponse" } } } @@ -2445,6 +2464,24 @@ "schema": { "type": "string" } + }, + { + "name": "start_index", + "in": "query", + "description": "The index to start the pagination from.", + "required": false, + "schema": { + "type": "integer" + } + }, + { + "name": "limit", + "in": "query", + "description": "The number of sessions to return.", + "required": false, + "schema": { + "type": "integer" + } } ] } @@ -8996,38 +9033,6 @@ ], "title": "Job" }, - "ListAgentSessionsResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Session" - } - } - }, - "additionalProperties": false, - "required": [ - "data" - ], - "title": "ListAgentSessionsResponse" - }, - "ListAgentsResponse": { - "type": "object", - "properties": { - "data": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Agent" - } - } - }, - "additionalProperties": false, - "required": [ - "data" - ], - "title": "ListAgentsResponse" - }, "BucketResponse": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 259cb3007..67c1d3dbd 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -197,11 +197,11 @@ paths: get: responses: '200': - description: A ListAgentsResponse. + description: A PaginatedResponse. content: application/json: schema: - $ref: '#/components/schemas/ListAgentsResponse' + $ref: '#/components/schemas/PaginatedResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -215,7 +215,19 @@ paths: tags: - Agents description: List all agents. - parameters: [] + parameters: + - name: start_index + in: query + description: The index to start the pagination from. + required: false + schema: + type: integer + - name: limit + in: query + description: The number of agents to return. + required: false + schema: + type: integer post: responses: '200': @@ -465,7 +477,8 @@ paths: $ref: '#/components/responses/DefaultError' tags: - Agents - description: Delete an agent by its ID. + description: >- + Delete an agent by its ID and its associated sessions and turns. parameters: - name: agent_id in: path @@ -534,7 +547,8 @@ paths: $ref: '#/components/responses/DefaultError' tags: - Agents - description: Delete an agent session by its ID. + description: >- + Delete an agent session by its ID and its associated turns. parameters: - name: session_id in: path @@ -1662,11 +1676,11 @@ paths: get: responses: '200': - description: A ListAgentSessionsResponse. + description: A PaginatedResponse. content: application/json: schema: - $ref: '#/components/schemas/ListAgentSessionsResponse' + $ref: '#/components/schemas/PaginatedResponse' '400': $ref: '#/components/responses/BadRequest400' '429': @@ -1688,6 +1702,18 @@ paths: required: true schema: type: string + - name: start_index + in: query + description: The index to start the pagination from. + required: false + schema: + type: integer + - name: limit + in: query + description: The number of sessions to return. + required: false + schema: + type: integer /v1/eval/benchmarks: get: responses: @@ -6217,28 +6243,6 @@ components: - job_id - status title: Job - ListAgentSessionsResponse: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/Session' - additionalProperties: false - required: - - data - title: ListAgentSessionsResponse - ListAgentsResponse: - type: object - properties: - data: - type: array - items: - $ref: '#/components/schemas/Agent' - additionalProperties: false - required: - - data - title: ListAgentsResponse BucketResponse: type: object properties: diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index 84e3a9057..91e57dbbe 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -13,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable from pydantic import BaseModel, ConfigDict, Field from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent +from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.inference import ( CompletionMessage, ResponseFormat, @@ -241,16 +242,6 @@ class Agent(BaseModel): created_at: datetime -@json_schema_type -class ListAgentsResponse(BaseModel): - data: list[Agent] - - -@json_schema_type -class ListAgentSessionsResponse(BaseModel): - data: list[Session] - - class AgentConfigOverridablePerTurn(AgentConfigCommon): instructions: str | None = None @@ -526,7 +517,7 @@ class Agents(Protocol): session_id: str, agent_id: str, ) -> None: - """Delete an agent session by its ID. + """Delete an agent session by its ID and its associated turns. :param session_id: The ID of the session to delete. :param agent_id: The ID of the agent to delete the session for. @@ -538,17 +529,19 @@ class Agents(Protocol): self, agent_id: str, ) -> None: - """Delete an agent by its ID. + """Delete an agent by its ID and its associated sessions and turns. :param agent_id: The ID of the agent to delete. """ ... @webmethod(route="/agents", method="GET") - async def list_agents(self) -> ListAgentsResponse: + async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse: """List all agents. - :returns: A ListAgentsResponse. + :param start_index: The index to start the pagination from. + :param limit: The number of agents to return. + :returns: A PaginatedResponse. """ ... @@ -565,11 +558,15 @@ class Agents(Protocol): async def list_agent_sessions( self, agent_id: str, - ) -> ListAgentSessionsResponse: + start_index: int | None = None, + limit: int | None = None, + ) -> PaginatedResponse: """List all session(s) of a given agent. :param agent_id: The ID of the agent to list sessions for. - :returns: A ListAgentSessionsResponse. + :param start_index: The index to start the pagination from. + :param limit: The number of sessions to return. + :returns: A PaginatedResponse. """ ... diff --git a/llama_stack/distribution/store/registry.py b/llama_stack/distribution/store/registry.py index ae97f600c..a6b400136 100644 --- a/llama_stack/distribution/store/registry.py +++ b/llama_stack/distribution/store/registry.py @@ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry): async def get_all(self) -> list[RoutableObjectWithProvider]: start_key, end_key = _get_registry_key_range() - values = await self.kvstore.range(start_key, end_key) + values = await self.kvstore.values_in_range(start_key, end_key) return _parse_registry_values(values) async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: @@ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): return start_key, end_key = _get_registry_key_range() - values = await self.kvstore.range(start_key, end_key) + values = await self.kvstore.values_in_range(start_key, end_key) objects = _parse_registry_values(values) async with self._locked_cache() as cache: diff --git a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py index 3807ddb86..2e387e7e8 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agent_instance.py +++ b/llama_stack/providers/inline/agents/meta_reference/agent_instance.py @@ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin): tool_groups_api: ToolGroups, vector_io_api: VectorIO, persistence_store: KVStore, + created_at: str, ): self.agent_id = agent_id self.agent_config = agent_config @@ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin): self.storage = AgentPersistence(agent_id, persistence_store) self.tool_runtime_api = tool_runtime_api self.tool_groups_api = tool_groups_api + self.created_at = created_at ShieldRunnerMixin.__init__( self, diff --git a/llama_stack/providers/inline/agents/meta_reference/agents.py b/llama_stack/providers/inline/agents/meta_reference/agents.py index a87739925..19d60c816 100644 --- a/llama_stack/providers/inline/agents/meta_reference/agents.py +++ b/llama_stack/providers/inline/agents/meta_reference/agents.py @@ -4,10 +4,10 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json import logging import uuid from collections.abc import AsyncGenerator +from datetime import datetime, timezone from llama_stack.apis.agents import ( Agent, @@ -20,14 +20,13 @@ from llama_stack.apis.agents import ( AgentTurnCreateRequest, AgentTurnResumeRequest, Document, - ListAgentSessionsResponse, - ListAgentsResponse, OpenAIResponseInputMessage, OpenAIResponseInputTool, OpenAIResponseObject, Session, Turn, ) +from llama_stack.apis.common.responses import PaginatedResponse from llama_stack.apis.inference import ( Inference, ToolConfig, @@ -38,14 +37,15 @@ from llama_stack.apis.inference import ( from llama_stack.apis.safety import Safety from llama_stack.apis.tools import ToolGroups, ToolRuntime from llama_stack.apis.vector_io import VectorIO +from llama_stack.providers.utils.datasetio.pagination import paginate_records from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig from .openai_responses import OpenAIResponsesImpl +from .persistence import AgentInfo logger = logging.getLogger() -logger.setLevel(logging.INFO) class MetaReferenceAgentsImpl(Agents): @@ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents): agent_config: AgentConfig, ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) + created_at = datetime.now(timezone.utc) + agent_info = AgentInfo( + **agent_config.model_dump(), + created_at=created_at, + ) + + # Store the agent info await self.persistence_store.set( key=f"agent:{agent_id}", - value=agent_config.model_dump_json(), + value=agent_info.model_dump_json(), ) + return AgentCreateResponse( agent_id=agent_id, ) async def _get_agent_impl(self, agent_id: str) -> ChatAgent: - agent_config = await self.persistence_store.get( + agent_info_json = await self.persistence_store.get( key=f"agent:{agent_id}", ) - if not agent_config: - raise ValueError(f"Could not find agent config for {agent_id}") + if not agent_info_json: + raise ValueError(f"Could not find agent info for {agent_id}") try: - agent_config = json.loads(agent_config) - except json.JSONDecodeError as e: - raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e - - try: - agent_config = AgentConfig(**agent_config) + agent_info = AgentInfo.model_validate_json(agent_info_json) except Exception as e: - raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e + raise ValueError(f"Could not validate agent info for {agent_id}") from e return ChatAgent( agent_id=agent_id, - agent_config=agent_config, + agent_config=agent_info, inference_api=self.inference_api, safety_api=self.safety_api, vector_io_api=self.vector_io_api, tool_runtime_api=self.tool_runtime_api, tool_groups_api=self.tool_groups_api, persistence_store=( - self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store + self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store ), + created_at=agent_info.created_at, ) async def create_agent_session( @@ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents): turn_ids: list[str] | None = None, ) -> Session: agent = await self._get_agent_impl(agent_id) + session_info = await agent.storage.get_session_info(session_id) if session_info is None: raise ValueError(f"Session {session_id} not found") @@ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents): ) async def delete_agents_session(self, agent_id: str, session_id: str) -> None: - await self.persistence_store.delete(f"session:{agent_id}:{session_id}") + agent = await self._get_agent_impl(agent_id) + session_info = await agent.storage.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + # Delete turns first, then the session + await agent.storage.delete_session_turns(session_id) + await agent.storage.delete_session(session_id) async def delete_agent(self, agent_id: str) -> None: + # First get all sessions for this agent + agent = await self._get_agent_impl(agent_id) + sessions = await agent.storage.list_sessions() + + # Delete all sessions + for session in sessions: + await self.delete_agents_session(agent_id, session.session_id) + + # Finally delete the agent itself await self.persistence_store.delete(f"agent:{agent_id}") - async def shutdown(self) -> None: - pass + async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse: + agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff") + agent_list: list[Agent] = [] + for agent_key in agent_keys: + agent_id = agent_key.split(":")[1] - async def list_agents(self) -> ListAgentsResponse: - pass + # Get the agent info using the key + agent_info_json = await self.persistence_store.get(agent_key) + if not agent_info_json: + logger.error(f"Could not find agent info for key {agent_key}") + continue + + try: + agent_info = AgentInfo.model_validate_json(agent_info_json) + agent_list.append( + Agent( + agent_id=agent_id, + agent_config=agent_info, + created_at=agent_info.created_at, + ) + ) + except Exception as e: + logger.error(f"Error parsing agent info for {agent_id}: {e}") + continue + + # Convert Agent objects to dictionaries + agent_dicts = [agent.model_dump() for agent in agent_list] + return paginate_records(agent_dicts, start_index, limit) async def get_agent(self, agent_id: str) -> Agent: - pass + chat_agent = await self._get_agent_impl(agent_id) + agent = Agent( + agent_id=agent_id, + agent_config=chat_agent.agent_config, + created_at=chat_agent.created_at, + ) + return agent async def list_agent_sessions( - self, - agent_id: str, - ) -> ListAgentSessionsResponse: + self, agent_id: str, start_index: int | None = None, limit: int | None = None + ) -> PaginatedResponse: + agent = await self._get_agent_impl(agent_id) + sessions = await agent.storage.list_sessions() + # Convert Session objects to dictionaries + session_dicts = [session.model_dump() for session in sessions] + return paginate_records(session_dicts, start_index, limit) + + async def shutdown(self) -> None: pass # OpenAI responses diff --git a/llama_stack/providers/inline/agents/meta_reference/persistence.py b/llama_stack/providers/inline/agents/meta_reference/persistence.py index 60ce91395..5031a4a90 100644 --- a/llama_stack/providers/inline/agents/meta_reference/persistence.py +++ b/llama_stack/providers/inline/agents/meta_reference/persistence.py @@ -9,9 +9,7 @@ import logging import uuid from datetime import datetime, timezone -from pydantic import BaseModel - -from llama_stack.apis.agents import ToolExecutionStep, Turn +from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn from llama_stack.distribution.access_control import check_access from llama_stack.distribution.datatypes import AccessAttributes from llama_stack.distribution.request_headers import get_auth_attributes @@ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore log = logging.getLogger(__name__) -class AgentSessionInfo(BaseModel): - session_id: str - session_name: str +class AgentSessionInfo(Session): # TODO: is this used anywhere? vector_db_id: str | None = None started_at: datetime access_attributes: AccessAttributes | None = None +class AgentInfo(AgentConfig): + created_at: datetime + + class AgentPersistence: def __init__(self, agent_id: str, kvstore: KVStore): self.agent_id = agent_id @@ -46,6 +46,7 @@ class AgentPersistence: session_name=name, started_at=datetime.now(timezone.utc), access_attributes=access_attributes, + turns=[], ) await self.kvstore.set( @@ -109,7 +110,7 @@ class AgentPersistence: if not await self.get_session_if_accessible(session_id): raise ValueError(f"Session {session_id} not found or access denied") - values = await self.kvstore.range( + values = await self.kvstore.values_in_range( start_key=f"session:{self.agent_id}:{session_id}:", end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", ) @@ -121,7 +122,6 @@ class AgentPersistence: except Exception as e: log.error(f"Error parsing turn: {e}") continue - turns.sort(key=lambda x: (x.completed_at or datetime.min)) return turns async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None: @@ -170,3 +170,43 @@ class AgentPersistence: key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", ) return int(value) if value else None + + async def list_sessions(self) -> list[Session]: + values = await self.kvstore.values_in_range( + start_key=f"session:{self.agent_id}:", + end_key=f"session:{self.agent_id}:\xff\xff\xff\xff", + ) + sessions = [] + for value in values: + try: + session_info = Session(**json.loads(value)) + sessions.append(session_info) + except Exception as e: + log.error(f"Error parsing session info: {e}") + continue + return sessions + + async def delete_session_turns(self, session_id: str) -> None: + """Delete all turns and their associated data for a session. + + Args: + session_id: The ID of the session whose turns should be deleted. + """ + turns = await self.get_session_turns(session_id) + for turn in turns: + await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}") + + async def delete_session(self, session_id: str) -> None: + """Delete a session and all its associated turns. + + Args: + session_id: The ID of the session to delete. + + Raises: + ValueError: If the session does not exist. + """ + session_info = await self.get_session_info(session_id) + if session_info is None: + raise ValueError(f"Session {session_id} not found") + + await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}") diff --git a/llama_stack/providers/inline/datasetio/localfs/datasetio.py b/llama_stack/providers/inline/datasetio/localfs/datasetio.py index 260a640bd..5640df6ab 100644 --- a/llama_stack/providers/inline/datasetio/localfs/datasetio.py +++ b/llama_stack/providers/inline/datasetio/localfs/datasetio.py @@ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): # Load existing datasets from kvstore start_key = DATASETS_PREFIX end_key = f"{DATASETS_PREFIX}\xff" - stored_datasets = await self.kvstore.range(start_key, end_key) + stored_datasets = await self.kvstore.values_in_range(start_key, end_key) for dataset in stored_datasets: dataset = Dataset.model_validate_json(dataset) diff --git a/llama_stack/providers/inline/eval/meta_reference/eval.py b/llama_stack/providers/inline/eval/meta_reference/eval.py index 12dde66b5..bc0898dc5 100644 --- a/llama_stack/providers/inline/eval/meta_reference/eval.py +++ b/llama_stack/providers/inline/eval/meta_reference/eval.py @@ -58,7 +58,7 @@ class MetaReferenceEvalImpl( # Load existing benchmarks from kvstore start_key = EVAL_TASKS_PREFIX end_key = f"{EVAL_TASKS_PREFIX}\xff" - stored_benchmarks = await self.kvstore.range(start_key, end_key) + stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key) for benchmark in stored_benchmarks: benchmark = Benchmark.model_validate_json(benchmark) diff --git a/llama_stack/providers/inline/vector_io/faiss/faiss.py b/llama_stack/providers/inline/vector_io/faiss/faiss.py index 5d5b1da35..d3dc7e694 100644 --- a/llama_stack/providers/inline/vector_io/faiss/faiss.py +++ b/llama_stack/providers/inline/vector_io/faiss/faiss.py @@ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): # Load existing banks from kvstore start_key = VECTOR_DBS_PREFIX end_key = f"{VECTOR_DBS_PREFIX}\xff" - stored_vector_dbs = await self.kvstore.range(start_key, end_key) + stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) for vector_db_data in stored_vector_dbs: vector_db = VectorDB.model_validate_json(vector_db_data) diff --git a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py index baecf45e9..e329c88a7 100644 --- a/llama_stack/providers/remote/datasetio/huggingface/huggingface.py +++ b/llama_stack/providers/remote/datasetio/huggingface/huggingface.py @@ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): # Load existing datasets from kvstore start_key = DATASETS_PREFIX end_key = f"{DATASETS_PREFIX}\xff" - stored_datasets = await self.kvstore.range(start_key, end_key) + stored_datasets = await self.kvstore.values_in_range(start_key, end_key) for dataset in stored_datasets: dataset = Dataset.model_validate_json(dataset) diff --git a/llama_stack/providers/utils/kvstore/api.py b/llama_stack/providers/utils/kvstore/api.py index 8efde4ea9..d17dc66e1 100644 --- a/llama_stack/providers/utils/kvstore/api.py +++ b/llama_stack/providers/utils/kvstore/api.py @@ -16,4 +16,6 @@ class KVStore(Protocol): async def delete(self, key: str) -> None: ... - async def range(self, start_key: str, end_key: str) -> list[str]: ... + async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ... + + async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ... diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index 0eb969b65..3a1ee8a26 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore): async def set(self, key: str, value: str) -> None: self._store[key] = value - async def range(self, start_key: str, end_key: str) -> list[str]: + async def values_in_range(self, start_key: str, end_key: str) -> list[str]: return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key] + async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: + """Get all keys in the given range.""" + return [key for key in self._store.keys() if key >= start_key and key < end_key] + + async def delete(self, key: str) -> None: + del self._store[key] + async def kvstore_impl(config: KVStoreConfig) -> KVStore: if config.type == KVStoreType.redis.value: diff --git a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py index 330a079bf..3842773d9 100644 --- a/llama_stack/providers/utils/kvstore/mongodb/mongodb.py +++ b/llama_stack/providers/utils/kvstore/mongodb/mongodb.py @@ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore): key = self._namespaced_key(key) await self.collection.delete_one({"key": key}) - async def range(self, start_key: str, end_key: str) -> list[str]: + async def values_in_range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) query = { @@ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore): async for doc in cursor: result.append(doc["value"]) return result + + async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + query = {"key": {"$gte": start_key, "$lt": end_key}} + cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1) + return [doc["key"] for doc in cursor] diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py index 6bfbc4f81..bd35decfc 100644 --- a/llama_stack/providers/utils/kvstore/postgres/postgres.py +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore): (key,), ) - async def range(self, start_key: str, end_key: str) -> list[str]: + async def values_in_range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) @@ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore): (start_key, end_key), ) return [row[0] for row in self.cursor.fetchall()] + + async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + + self.cursor.execute( + f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s", + (start_key, end_key), + ) + return [row[0] for row in self.cursor.fetchall()] diff --git a/llama_stack/providers/utils/kvstore/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py index d95de0291..3d2d956c3 100644 --- a/llama_stack/providers/utils/kvstore/redis/redis.py +++ b/llama_stack/providers/utils/kvstore/redis/redis.py @@ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore): key = self._namespaced_key(key) await self.redis.delete(key) - async def range(self, start_key: str, end_key: str) -> list[str]: + async def values_in_range(self, start_key: str, end_key: str) -> list[str]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) cursor = 0 @@ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore): ] return [] + + async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: + """Get all keys in the given range.""" + matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}") + if not matching_keys: + return [] + return [k.decode("utf-8") for k in matching_keys] diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index faca77887..4e49e4d8c 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore): await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.commit() - async def range(self, start_key: str, end_key: str) -> list[str]: + async def values_in_range(self, start_key: str, end_key: str) -> list[str]: async with aiosqlite.connect(self.db_path) as db: async with db.execute( f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", @@ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore): _, value, _ = row result.append(value) return result + + async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: + """Get all keys in the given range.""" + async with aiosqlite.connect(self.db_path) as db: + cursor = await db.execute( + f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) + rows = await cursor.fetchall() + return [row[0] for row in rows] diff --git a/tests/unit/providers/agent/test_meta_reference_agent.py b/tests/unit/providers/agent/test_meta_reference_agent.py new file mode 100644 index 000000000..bef24e123 --- /dev/null +++ b/tests/unit/providers/agent/test_meta_reference_agent.py @@ -0,0 +1,230 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime +from unittest.mock import AsyncMock + +import pytest +import pytest_asyncio + +from llama_stack.apis.agents import ( + Agent, + AgentConfig, + AgentCreateResponse, +) +from llama_stack.apis.common.responses import PaginatedResponse +from llama_stack.apis.inference import Inference +from llama_stack.apis.safety import Safety +from llama_stack.apis.tools import ToolGroups, ToolRuntime +from llama_stack.apis.vector_io import VectorIO +from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl +from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig +from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo + + +@pytest.fixture +def mock_apis(): + return { + "inference_api": AsyncMock(spec=Inference), + "vector_io_api": AsyncMock(spec=VectorIO), + "safety_api": AsyncMock(spec=Safety), + "tool_runtime_api": AsyncMock(spec=ToolRuntime), + "tool_groups_api": AsyncMock(spec=ToolGroups), + } + + +@pytest.fixture +def config(tmp_path): + return MetaReferenceAgentsImplConfig( + persistence_store={ + "type": "sqlite", + "db_path": str(tmp_path / "test.db"), + }, + ) + + +@pytest_asyncio.fixture +async def agents_impl(config, mock_apis): + impl = MetaReferenceAgentsImpl( + config, + mock_apis["inference_api"], + mock_apis["vector_io_api"], + mock_apis["safety_api"], + mock_apis["tool_runtime_api"], + mock_apis["tool_groups_api"], + ) + await impl.initialize() + yield impl + await impl.shutdown() + + +@pytest.fixture +def sample_agent_config(): + return AgentConfig( + sampling_params={ + "strategy": {"type": "greedy"}, + "max_tokens": 0, + "repetition_penalty": 1.0, + }, + input_shields=["string"], + output_shields=["string"], + toolgroups=["string"], + client_tools=[ + { + "name": "string", + "description": "string", + "parameters": [ + { + "name": "string", + "parameter_type": "string", + "description": "string", + "required": True, + "default": None, + } + ], + "metadata": { + "property1": None, + "property2": None, + }, + } + ], + tool_choice="auto", + tool_prompt_format="json", + tool_config={ + "tool_choice": "auto", + "tool_prompt_format": "json", + "system_message_behavior": "append", + }, + max_infer_iters=10, + model="string", + instructions="string", + enable_session_persistence=False, + response_format={ + "type": "json_schema", + "json_schema": { + "property1": None, + "property2": None, + }, + }, + ) + + +@pytest.mark.asyncio +async def test_create_agent(agents_impl, sample_agent_config): + response = await agents_impl.create_agent(sample_agent_config) + + assert isinstance(response, AgentCreateResponse) + assert response.agent_id is not None + + stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}") + assert stored_agent is not None + agent_info = AgentInfo.model_validate_json(stored_agent) + assert agent_info.model == sample_agent_config.model + assert agent_info.created_at is not None + assert isinstance(agent_info.created_at, datetime) + + +@pytest.mark.asyncio +async def test_get_agent(agents_impl, sample_agent_config): + create_response = await agents_impl.create_agent(sample_agent_config) + agent_id = create_response.agent_id + + agent = await agents_impl.get_agent(agent_id) + + assert isinstance(agent, Agent) + assert agent.agent_id == agent_id + assert agent.agent_config.model == sample_agent_config.model + assert agent.created_at is not None + assert isinstance(agent.created_at, datetime) + + +@pytest.mark.asyncio +async def test_list_agents(agents_impl, sample_agent_config): + agent1_response = await agents_impl.create_agent(sample_agent_config) + agent2_response = await agents_impl.create_agent(sample_agent_config) + + response = await agents_impl.list_agents() + + assert isinstance(response, PaginatedResponse) + assert len(response.data) == 2 + agent_ids = {agent["agent_id"] for agent in response.data} + assert agent1_response.agent_id in agent_ids + assert agent2_response.agent_id in agent_ids + + +@pytest.mark.asyncio +@pytest.mark.parametrize("enable_session_persistence", [True, False]) +async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence): + # Create an agent with specified persistence setting + config = sample_agent_config.model_copy() + config.enable_session_persistence = enable_session_persistence + response = await agents_impl.create_agent(config) + agent_id = response.agent_id + + # Create a session + session_response = await agents_impl.create_agent_session(agent_id, "test_session") + assert session_response.session_id is not None + + # Verify the session was stored + session = await agents_impl.get_agents_session(agent_id, session_response.session_id) + assert session.session_name == "test_session" + assert session.session_id == session_response.session_id + assert session.started_at is not None + assert session.turns == [] + + # Delete the session + await agents_impl.delete_agents_session(agent_id, session_response.session_id) + + # Verify the session was deleted + with pytest.raises(ValueError): + await agents_impl.get_agents_session(agent_id, session_response.session_id) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("enable_session_persistence", [True, False]) +async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence): + # Create an agent with specified persistence setting + config = sample_agent_config.model_copy() + config.enable_session_persistence = enable_session_persistence + response = await agents_impl.create_agent(config) + agent_id = response.agent_id + + # Create multiple sessions + session1 = await agents_impl.create_agent_session(agent_id, "session1") + session2 = await agents_impl.create_agent_session(agent_id, "session2") + + # List sessions + sessions = await agents_impl.list_agent_sessions(agent_id) + assert len(sessions.data) == 2 + session_ids = {s["session_id"] for s in sessions.data} + assert session1.session_id in session_ids + assert session2.session_id in session_ids + + # Delete one session + await agents_impl.delete_agents_session(agent_id, session1.session_id) + + # Verify the session was deleted + with pytest.raises(ValueError): + await agents_impl.get_agents_session(agent_id, session1.session_id) + + # List sessions again + sessions = await agents_impl.list_agent_sessions(agent_id) + assert len(sessions.data) == 1 + assert session2.session_id in {s["session_id"] for s in sessions.data} + + +@pytest.mark.asyncio +async def test_delete_agent(agents_impl, sample_agent_config): + # Create an agent + response = await agents_impl.create_agent(sample_agent_config) + agent_id = response.agent_id + + # Delete the agent + await agents_impl.delete_agent(agent_id) + + # Verify the agent was deleted + with pytest.raises(ValueError): + await agents_impl.get_agent(agent_id) diff --git a/tests/unit/providers/agents/test_persistence_access_control.py b/tests/unit/providers/agents/test_persistence_access_control.py index 06e1a778a..48fa647a8 100644 --- a/tests/unit/providers/agents/test_persistence_access_control.py +++ b/tests/unit/providers/agents/test_persistence_access_control.py @@ -54,6 +54,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup): session_name="Restricted Session", started_at=datetime.now(), access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), + turns=[], ) await agent_persistence.kvstore.set( @@ -85,6 +86,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): session_name="Restricted Session", started_at=datetime.now(), access_attributes=AccessAttributes(roles=["admin"]), + turns=[], ) await agent_persistence.kvstore.set( @@ -137,6 +139,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes session_name="Restricted Session", started_at=datetime.now(), access_attributes=AccessAttributes(roles=["admin"]), + turns=[], ) await agent_persistence.kvstore.set(