forked from phoenix-oss/llama-stack-mirror
		
	feat: implementation for agent/session list and describe (#1606)
Create a new agent: ``` curl --request POST \ --url http://localhost:8321/v1/agents \ --header 'Accept: application/json' \ --header 'Content-Type: application/json' \ --data '{ "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } } }' ``` Get agent: ``` curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f {"agent_id":"9abad4ab-2c77-45f9-9d16-46b79d2bea1f","agent_config":{"sampling_params":{"strategy":{"type":"greedy"},"max_tokens":0,"repetition_penalty":1.0},"input_shields":["string"],"output_shields":["string"],"toolgroups":["string"],"client_tools":[{"name":"string","description":"string","parameters":[{"name":"string","parameter_type":"string","description":"string","required":true,"default":null}],"metadata":{"property1":null,"property2":null}}],"tool_choice":"auto","tool_prompt_format":"json","tool_config":{"tool_choice":"auto","tool_prompt_format":"json","system_message_behavior":"append"},"max_infer_iters":10,"model":"string","instructions":"string","enable_session_persistence":false,"response_format":{"type":"json_schema","json_schema":{"property1":null,"property2":null}}},"created_at":"2025-03-12T16:18:28.369144Z"}% ``` List agents: ``` curl http://127.0.0.1:8321/v1/agents|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 1680 100 1680 0 0 498k 0 --:--:-- --:--:-- --:--:-- 546k { "data": [ { "agent_id": "9abad4ab-2c77-45f9-9d16-46b79d2bea1f", "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } }, "created_at": "2025-03-12T16:18:28.369144Z" }, { "agent_id": "a6643aaa-96dd-46db-a405-333dc504b168", "agent_config": { "sampling_params": { "strategy": { "type": "greedy" }, "max_tokens": 0, "repetition_penalty": 1.0 }, "input_shields": [ "string" ], "output_shields": [ "string" ], "toolgroups": [ "string" ], "client_tools": [ { "name": "string", "description": "string", "parameters": [ { "name": "string", "parameter_type": "string", "description": "string", "required": true, "default": null } ], "metadata": { "property1": null, "property2": null } } ], "tool_choice": "auto", "tool_prompt_format": "json", "tool_config": { "tool_choice": "auto", "tool_prompt_format": "json", "system_message_behavior": "append" }, "max_infer_iters": 10, "model": "string", "instructions": "string", "enable_session_persistence": false, "response_format": { "type": "json_schema", "json_schema": { "property1": null, "property2": null } } }, "created_at": "2025-03-12T16:17:12.811273Z" } ] } ``` Create sessions: ``` curl --request POST \ --url http://localhost:8321/v1/agents/{agent_id}/session \ --header 'Accept: application/json' \ --header 'Content-Type: application/json' \ --data '{ "session_name": "string" }' ``` List sessions: ``` curl http://127.0.0.1:8321/v1/agents/9abad4ab-2c77-45f9-9d16-46b79d2bea1f/sessions|jq % Total % Received % Xferd Average Speed Time Time Time Current Dload Upload Total Spent Left Speed 100 263 100 263 0 0 90099 0 --:--:-- --:--:-- --:--:-- 128k [ { "session_id": "2b15c4fc-e348-46c1-ae32-f6d424441ac1", "session_name": "string", "turns": [], "started_at": "2025-03-12T17:19:17.784328" }, { "session_id": "9432472d-d483-4b73-b682-7b1d35d64111", "session_name": "string", "turns": [], "started_at": "2025-03-12T17:19:19.885834" } ] ``` Signed-off-by: Sébastien Han <seb@redhat.com>
This commit is contained in:
		
							parent
							
								
									40e71758d9
								
							
						
					
					
						commit
						c91e3552a3
					
				
					 19 changed files with 510 additions and 130 deletions
				
			
		
							
								
								
									
										83
									
								
								docs/_static/llama-stack-spec.html
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										83
									
								
								docs/_static/llama-stack-spec.html
									
										
									
									
										vendored
									
									
								
							|  | @ -307,11 +307,11 @@ | |||
|             "get": { | ||||
|                 "responses": { | ||||
|                     "200": { | ||||
|                         "description": "A ListAgentsResponse.", | ||||
|                         "description": "A PaginatedResponse.", | ||||
|                         "content": { | ||||
|                             "application/json": { | ||||
|                                 "schema": { | ||||
|                                     "$ref": "#/components/schemas/ListAgentsResponse" | ||||
|                                     "$ref": "#/components/schemas/PaginatedResponse" | ||||
|                                 } | ||||
|                             } | ||||
|                         } | ||||
|  | @ -333,7 +333,26 @@ | |||
|                     "Agents" | ||||
|                 ], | ||||
|                 "description": "List all agents.", | ||||
|                 "parameters": [] | ||||
|                 "parameters": [ | ||||
|                     { | ||||
|                         "name": "start_index", | ||||
|                         "in": "query", | ||||
|                         "description": "The index to start the pagination from.", | ||||
|                         "required": false, | ||||
|                         "schema": { | ||||
|                             "type": "integer" | ||||
|                         } | ||||
|                     }, | ||||
|                     { | ||||
|                         "name": "limit", | ||||
|                         "in": "query", | ||||
|                         "description": "The number of agents to return.", | ||||
|                         "required": false, | ||||
|                         "schema": { | ||||
|                             "type": "integer" | ||||
|                         } | ||||
|                     } | ||||
|                 ] | ||||
|             }, | ||||
|             "post": { | ||||
|                 "responses": { | ||||
|  | @ -691,7 +710,7 @@ | |||
|                 "tags": [ | ||||
|                     "Agents" | ||||
|                 ], | ||||
|                 "description": "Delete an agent by its ID.", | ||||
|                 "description": "Delete an agent by its ID and its associated sessions and turns.", | ||||
|                 "parameters": [ | ||||
|                     { | ||||
|                         "name": "agent_id", | ||||
|  | @ -789,7 +808,7 @@ | |||
|                 "tags": [ | ||||
|                     "Agents" | ||||
|                 ], | ||||
|                 "description": "Delete an agent session by its ID.", | ||||
|                 "description": "Delete an agent session by its ID and its associated turns.", | ||||
|                 "parameters": [ | ||||
|                     { | ||||
|                         "name": "session_id", | ||||
|  | @ -2410,11 +2429,11 @@ | |||
|             "get": { | ||||
|                 "responses": { | ||||
|                     "200": { | ||||
|                         "description": "A ListAgentSessionsResponse.", | ||||
|                         "description": "A PaginatedResponse.", | ||||
|                         "content": { | ||||
|                             "application/json": { | ||||
|                                 "schema": { | ||||
|                                     "$ref": "#/components/schemas/ListAgentSessionsResponse" | ||||
|                                     "$ref": "#/components/schemas/PaginatedResponse" | ||||
|                                 } | ||||
|                             } | ||||
|                         } | ||||
|  | @ -2445,6 +2464,24 @@ | |||
|                         "schema": { | ||||
|                             "type": "string" | ||||
|                         } | ||||
|                     }, | ||||
|                     { | ||||
|                         "name": "start_index", | ||||
|                         "in": "query", | ||||
|                         "description": "The index to start the pagination from.", | ||||
|                         "required": false, | ||||
|                         "schema": { | ||||
|                             "type": "integer" | ||||
|                         } | ||||
|                     }, | ||||
|                     { | ||||
|                         "name": "limit", | ||||
|                         "in": "query", | ||||
|                         "description": "The number of sessions to return.", | ||||
|                         "required": false, | ||||
|                         "schema": { | ||||
|                             "type": "integer" | ||||
|                         } | ||||
|                     } | ||||
|                 ] | ||||
|             } | ||||
|  | @ -8996,38 +9033,6 @@ | |||
|                 ], | ||||
|                 "title": "Job" | ||||
|             }, | ||||
|             "ListAgentSessionsResponse": { | ||||
|                 "type": "object", | ||||
|                 "properties": { | ||||
|                     "data": { | ||||
|                         "type": "array", | ||||
|                         "items": { | ||||
|                             "$ref": "#/components/schemas/Session" | ||||
|                         } | ||||
|                     } | ||||
|                 }, | ||||
|                 "additionalProperties": false, | ||||
|                 "required": [ | ||||
|                     "data" | ||||
|                 ], | ||||
|                 "title": "ListAgentSessionsResponse" | ||||
|             }, | ||||
|             "ListAgentsResponse": { | ||||
|                 "type": "object", | ||||
|                 "properties": { | ||||
|                     "data": { | ||||
|                         "type": "array", | ||||
|                         "items": { | ||||
|                             "$ref": "#/components/schemas/Agent" | ||||
|                         } | ||||
|                     } | ||||
|                 }, | ||||
|                 "additionalProperties": false, | ||||
|                 "required": [ | ||||
|                     "data" | ||||
|                 ], | ||||
|                 "title": "ListAgentsResponse" | ||||
|             }, | ||||
|             "BucketResponse": { | ||||
|                 "type": "object", | ||||
|                 "properties": { | ||||
|  |  | |||
							
								
								
									
										62
									
								
								docs/_static/llama-stack-spec.yaml
									
										
									
									
										vendored
									
									
								
							
							
						
						
									
										62
									
								
								docs/_static/llama-stack-spec.yaml
									
										
									
									
										vendored
									
									
								
							|  | @ -197,11 +197,11 @@ paths: | |||
|     get: | ||||
|       responses: | ||||
|         '200': | ||||
|           description: A ListAgentsResponse. | ||||
|           description: A PaginatedResponse. | ||||
|           content: | ||||
|             application/json: | ||||
|               schema: | ||||
|                 $ref: '#/components/schemas/ListAgentsResponse' | ||||
|                 $ref: '#/components/schemas/PaginatedResponse' | ||||
|         '400': | ||||
|           $ref: '#/components/responses/BadRequest400' | ||||
|         '429': | ||||
|  | @ -215,7 +215,19 @@ paths: | |||
|       tags: | ||||
|         - Agents | ||||
|       description: List all agents. | ||||
|       parameters: [] | ||||
|       parameters: | ||||
|         - name: start_index | ||||
|           in: query | ||||
|           description: The index to start the pagination from. | ||||
|           required: false | ||||
|           schema: | ||||
|             type: integer | ||||
|         - name: limit | ||||
|           in: query | ||||
|           description: The number of agents to return. | ||||
|           required: false | ||||
|           schema: | ||||
|             type: integer | ||||
|     post: | ||||
|       responses: | ||||
|         '200': | ||||
|  | @ -465,7 +477,8 @@ paths: | |||
|           $ref: '#/components/responses/DefaultError' | ||||
|       tags: | ||||
|         - Agents | ||||
|       description: Delete an agent by its ID. | ||||
|       description: >- | ||||
|         Delete an agent by its ID and its associated sessions and turns. | ||||
|       parameters: | ||||
|         - name: agent_id | ||||
|           in: path | ||||
|  | @ -534,7 +547,8 @@ paths: | |||
|           $ref: '#/components/responses/DefaultError' | ||||
|       tags: | ||||
|         - Agents | ||||
|       description: Delete an agent session by its ID. | ||||
|       description: >- | ||||
|         Delete an agent session by its ID and its associated turns. | ||||
|       parameters: | ||||
|         - name: session_id | ||||
|           in: path | ||||
|  | @ -1662,11 +1676,11 @@ paths: | |||
|     get: | ||||
|       responses: | ||||
|         '200': | ||||
|           description: A ListAgentSessionsResponse. | ||||
|           description: A PaginatedResponse. | ||||
|           content: | ||||
|             application/json: | ||||
|               schema: | ||||
|                 $ref: '#/components/schemas/ListAgentSessionsResponse' | ||||
|                 $ref: '#/components/schemas/PaginatedResponse' | ||||
|         '400': | ||||
|           $ref: '#/components/responses/BadRequest400' | ||||
|         '429': | ||||
|  | @ -1688,6 +1702,18 @@ paths: | |||
|           required: true | ||||
|           schema: | ||||
|             type: string | ||||
|         - name: start_index | ||||
|           in: query | ||||
|           description: The index to start the pagination from. | ||||
|           required: false | ||||
|           schema: | ||||
|             type: integer | ||||
|         - name: limit | ||||
|           in: query | ||||
|           description: The number of sessions to return. | ||||
|           required: false | ||||
|           schema: | ||||
|             type: integer | ||||
|   /v1/eval/benchmarks: | ||||
|     get: | ||||
|       responses: | ||||
|  | @ -6217,28 +6243,6 @@ components: | |||
|         - job_id | ||||
|         - status | ||||
|       title: Job | ||||
|     ListAgentSessionsResponse: | ||||
|       type: object | ||||
|       properties: | ||||
|         data: | ||||
|           type: array | ||||
|           items: | ||||
|             $ref: '#/components/schemas/Session' | ||||
|       additionalProperties: false | ||||
|       required: | ||||
|         - data | ||||
|       title: ListAgentSessionsResponse | ||||
|     ListAgentsResponse: | ||||
|       type: object | ||||
|       properties: | ||||
|         data: | ||||
|           type: array | ||||
|           items: | ||||
|             $ref: '#/components/schemas/Agent' | ||||
|       additionalProperties: false | ||||
|       required: | ||||
|         - data | ||||
|       title: ListAgentsResponse | ||||
|     BucketResponse: | ||||
|       type: object | ||||
|       properties: | ||||
|  |  | |||
|  | @ -13,6 +13,7 @@ from typing import Annotated, Any, Literal, Protocol, runtime_checkable | |||
| from pydantic import BaseModel, ConfigDict, Field | ||||
| 
 | ||||
| from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent | ||||
| from llama_stack.apis.common.responses import PaginatedResponse | ||||
| from llama_stack.apis.inference import ( | ||||
|     CompletionMessage, | ||||
|     ResponseFormat, | ||||
|  | @ -241,16 +242,6 @@ class Agent(BaseModel): | |||
|     created_at: datetime | ||||
| 
 | ||||
| 
 | ||||
| @json_schema_type | ||||
| class ListAgentsResponse(BaseModel): | ||||
|     data: list[Agent] | ||||
| 
 | ||||
| 
 | ||||
| @json_schema_type | ||||
| class ListAgentSessionsResponse(BaseModel): | ||||
|     data: list[Session] | ||||
| 
 | ||||
| 
 | ||||
| class AgentConfigOverridablePerTurn(AgentConfigCommon): | ||||
|     instructions: str | None = None | ||||
| 
 | ||||
|  | @ -526,7 +517,7 @@ class Agents(Protocol): | |||
|         session_id: str, | ||||
|         agent_id: str, | ||||
|     ) -> None: | ||||
|         """Delete an agent session by its ID. | ||||
|         """Delete an agent session by its ID and its associated turns. | ||||
| 
 | ||||
|         :param session_id: The ID of the session to delete. | ||||
|         :param agent_id: The ID of the agent to delete the session for. | ||||
|  | @ -538,17 +529,19 @@ class Agents(Protocol): | |||
|         self, | ||||
|         agent_id: str, | ||||
|     ) -> None: | ||||
|         """Delete an agent by its ID. | ||||
|         """Delete an agent by its ID and its associated sessions and turns. | ||||
| 
 | ||||
|         :param agent_id: The ID of the agent to delete. | ||||
|         """ | ||||
|         ... | ||||
| 
 | ||||
|     @webmethod(route="/agents", method="GET") | ||||
|     async def list_agents(self) -> ListAgentsResponse: | ||||
|     async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse: | ||||
|         """List all agents. | ||||
| 
 | ||||
|         :returns: A ListAgentsResponse. | ||||
|         :param start_index: The index to start the pagination from. | ||||
|         :param limit: The number of agents to return. | ||||
|         :returns: A PaginatedResponse. | ||||
|         """ | ||||
|         ... | ||||
| 
 | ||||
|  | @ -565,11 +558,15 @@ class Agents(Protocol): | |||
|     async def list_agent_sessions( | ||||
|         self, | ||||
|         agent_id: str, | ||||
|     ) -> ListAgentSessionsResponse: | ||||
|         start_index: int | None = None, | ||||
|         limit: int | None = None, | ||||
|     ) -> PaginatedResponse: | ||||
|         """List all session(s) of a given agent. | ||||
| 
 | ||||
|         :param agent_id: The ID of the agent to list sessions for. | ||||
|         :returns: A ListAgentSessionsResponse. | ||||
|         :param start_index: The index to start the pagination from. | ||||
|         :param limit: The number of sessions to return. | ||||
|         :returns: A PaginatedResponse. | ||||
|         """ | ||||
|         ... | ||||
| 
 | ||||
|  |  | |||
|  | @ -73,7 +73,7 @@ class DiskDistributionRegistry(DistributionRegistry): | |||
| 
 | ||||
|     async def get_all(self) -> list[RoutableObjectWithProvider]: | ||||
|         start_key, end_key = _get_registry_key_range() | ||||
|         values = await self.kvstore.range(start_key, end_key) | ||||
|         values = await self.kvstore.values_in_range(start_key, end_key) | ||||
|         return _parse_registry_values(values) | ||||
| 
 | ||||
|     async def get(self, type: str, identifier: str) -> RoutableObjectWithProvider | None: | ||||
|  | @ -134,7 +134,7 @@ class CachedDiskDistributionRegistry(DiskDistributionRegistry): | |||
|                 return | ||||
| 
 | ||||
|             start_key, end_key = _get_registry_key_range() | ||||
|             values = await self.kvstore.range(start_key, end_key) | ||||
|             values = await self.kvstore.values_in_range(start_key, end_key) | ||||
|             objects = _parse_registry_values(values) | ||||
| 
 | ||||
|             async with self._locked_cache() as cache: | ||||
|  |  | |||
|  | @ -95,6 +95,7 @@ class ChatAgent(ShieldRunnerMixin): | |||
|         tool_groups_api: ToolGroups, | ||||
|         vector_io_api: VectorIO, | ||||
|         persistence_store: KVStore, | ||||
|         created_at: str, | ||||
|     ): | ||||
|         self.agent_id = agent_id | ||||
|         self.agent_config = agent_config | ||||
|  | @ -104,6 +105,7 @@ class ChatAgent(ShieldRunnerMixin): | |||
|         self.storage = AgentPersistence(agent_id, persistence_store) | ||||
|         self.tool_runtime_api = tool_runtime_api | ||||
|         self.tool_groups_api = tool_groups_api | ||||
|         self.created_at = created_at | ||||
| 
 | ||||
|         ShieldRunnerMixin.__init__( | ||||
|             self, | ||||
|  |  | |||
|  | @ -4,10 +4,10 @@ | |||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| import json | ||||
| import logging | ||||
| import uuid | ||||
| from collections.abc import AsyncGenerator | ||||
| from datetime import datetime, timezone | ||||
| 
 | ||||
| from llama_stack.apis.agents import ( | ||||
|     Agent, | ||||
|  | @ -20,14 +20,13 @@ from llama_stack.apis.agents import ( | |||
|     AgentTurnCreateRequest, | ||||
|     AgentTurnResumeRequest, | ||||
|     Document, | ||||
|     ListAgentSessionsResponse, | ||||
|     ListAgentsResponse, | ||||
|     OpenAIResponseInputMessage, | ||||
|     OpenAIResponseInputTool, | ||||
|     OpenAIResponseObject, | ||||
|     Session, | ||||
|     Turn, | ||||
| ) | ||||
| from llama_stack.apis.common.responses import PaginatedResponse | ||||
| from llama_stack.apis.inference import ( | ||||
|     Inference, | ||||
|     ToolConfig, | ||||
|  | @ -38,14 +37,15 @@ from llama_stack.apis.inference import ( | |||
| from llama_stack.apis.safety import Safety | ||||
| from llama_stack.apis.tools import ToolGroups, ToolRuntime | ||||
| from llama_stack.apis.vector_io import VectorIO | ||||
| from llama_stack.providers.utils.datasetio.pagination import paginate_records | ||||
| from llama_stack.providers.utils.kvstore import InmemoryKVStoreImpl, kvstore_impl | ||||
| 
 | ||||
| from .agent_instance import ChatAgent | ||||
| from .config import MetaReferenceAgentsImplConfig | ||||
| from .openai_responses import OpenAIResponsesImpl | ||||
| from .persistence import AgentInfo | ||||
| 
 | ||||
| logger = logging.getLogger() | ||||
| logger.setLevel(logging.INFO) | ||||
| 
 | ||||
| 
 | ||||
| class MetaReferenceAgentsImpl(Agents): | ||||
|  | @ -82,43 +82,47 @@ class MetaReferenceAgentsImpl(Agents): | |||
|         agent_config: AgentConfig, | ||||
|     ) -> AgentCreateResponse: | ||||
|         agent_id = str(uuid.uuid4()) | ||||
|         created_at = datetime.now(timezone.utc) | ||||
| 
 | ||||
|         agent_info = AgentInfo( | ||||
|             **agent_config.model_dump(), | ||||
|             created_at=created_at, | ||||
|         ) | ||||
| 
 | ||||
|         # Store the agent info | ||||
|         await self.persistence_store.set( | ||||
|             key=f"agent:{agent_id}", | ||||
|             value=agent_config.model_dump_json(), | ||||
|             value=agent_info.model_dump_json(), | ||||
|         ) | ||||
| 
 | ||||
|         return AgentCreateResponse( | ||||
|             agent_id=agent_id, | ||||
|         ) | ||||
| 
 | ||||
|     async def _get_agent_impl(self, agent_id: str) -> ChatAgent: | ||||
|         agent_config = await self.persistence_store.get( | ||||
|         agent_info_json = await self.persistence_store.get( | ||||
|             key=f"agent:{agent_id}", | ||||
|         ) | ||||
|         if not agent_config: | ||||
|             raise ValueError(f"Could not find agent config for {agent_id}") | ||||
|         if not agent_info_json: | ||||
|             raise ValueError(f"Could not find agent info for {agent_id}") | ||||
| 
 | ||||
|         try: | ||||
|             agent_config = json.loads(agent_config) | ||||
|         except json.JSONDecodeError as e: | ||||
|             raise ValueError(f"Could not JSON decode agent config for {agent_id}") from e | ||||
| 
 | ||||
|         try: | ||||
|             agent_config = AgentConfig(**agent_config) | ||||
|             agent_info = AgentInfo.model_validate_json(agent_info_json) | ||||
|         except Exception as e: | ||||
|             raise ValueError(f"Could not validate(?) agent config for {agent_id}") from e | ||||
|             raise ValueError(f"Could not validate agent info for {agent_id}") from e | ||||
| 
 | ||||
|         return ChatAgent( | ||||
|             agent_id=agent_id, | ||||
|             agent_config=agent_config, | ||||
|             agent_config=agent_info, | ||||
|             inference_api=self.inference_api, | ||||
|             safety_api=self.safety_api, | ||||
|             vector_io_api=self.vector_io_api, | ||||
|             tool_runtime_api=self.tool_runtime_api, | ||||
|             tool_groups_api=self.tool_groups_api, | ||||
|             persistence_store=( | ||||
|                 self.persistence_store if agent_config.enable_session_persistence else self.in_memory_store | ||||
|                 self.persistence_store if agent_info.enable_session_persistence else self.in_memory_store | ||||
|             ), | ||||
|             created_at=agent_info.created_at, | ||||
|         ) | ||||
| 
 | ||||
|     async def create_agent_session( | ||||
|  | @ -212,6 +216,7 @@ class MetaReferenceAgentsImpl(Agents): | |||
|         turn_ids: list[str] | None = None, | ||||
|     ) -> Session: | ||||
|         agent = await self._get_agent_impl(agent_id) | ||||
| 
 | ||||
|         session_info = await agent.storage.get_session_info(session_id) | ||||
|         if session_info is None: | ||||
|             raise ValueError(f"Session {session_id} not found") | ||||
|  | @ -226,24 +231,75 @@ class MetaReferenceAgentsImpl(Agents): | |||
|         ) | ||||
| 
 | ||||
|     async def delete_agents_session(self, agent_id: str, session_id: str) -> None: | ||||
|         await self.persistence_store.delete(f"session:{agent_id}:{session_id}") | ||||
|         agent = await self._get_agent_impl(agent_id) | ||||
|         session_info = await agent.storage.get_session_info(session_id) | ||||
|         if session_info is None: | ||||
|             raise ValueError(f"Session {session_id} not found") | ||||
| 
 | ||||
|         # Delete turns first, then the session | ||||
|         await agent.storage.delete_session_turns(session_id) | ||||
|         await agent.storage.delete_session(session_id) | ||||
| 
 | ||||
|     async def delete_agent(self, agent_id: str) -> None: | ||||
|         # First get all sessions for this agent | ||||
|         agent = await self._get_agent_impl(agent_id) | ||||
|         sessions = await agent.storage.list_sessions() | ||||
| 
 | ||||
|         # Delete all sessions | ||||
|         for session in sessions: | ||||
|             await self.delete_agents_session(agent_id, session.session_id) | ||||
| 
 | ||||
|         # Finally delete the agent itself | ||||
|         await self.persistence_store.delete(f"agent:{agent_id}") | ||||
| 
 | ||||
|     async def shutdown(self) -> None: | ||||
|         pass | ||||
|     async def list_agents(self, start_index: int | None = None, limit: int | None = None) -> PaginatedResponse: | ||||
|         agent_keys = await self.persistence_store.keys_in_range("agent:", "agent:\xff") | ||||
|         agent_list: list[Agent] = [] | ||||
|         for agent_key in agent_keys: | ||||
|             agent_id = agent_key.split(":")[1] | ||||
| 
 | ||||
|     async def list_agents(self) -> ListAgentsResponse: | ||||
|         pass | ||||
|             # Get the agent info using the key | ||||
|             agent_info_json = await self.persistence_store.get(agent_key) | ||||
|             if not agent_info_json: | ||||
|                 logger.error(f"Could not find agent info for key {agent_key}") | ||||
|                 continue | ||||
| 
 | ||||
|             try: | ||||
|                 agent_info = AgentInfo.model_validate_json(agent_info_json) | ||||
|                 agent_list.append( | ||||
|                     Agent( | ||||
|                         agent_id=agent_id, | ||||
|                         agent_config=agent_info, | ||||
|                         created_at=agent_info.created_at, | ||||
|                     ) | ||||
|                 ) | ||||
|             except Exception as e: | ||||
|                 logger.error(f"Error parsing agent info for {agent_id}: {e}") | ||||
|                 continue | ||||
| 
 | ||||
|         # Convert Agent objects to dictionaries | ||||
|         agent_dicts = [agent.model_dump() for agent in agent_list] | ||||
|         return paginate_records(agent_dicts, start_index, limit) | ||||
| 
 | ||||
|     async def get_agent(self, agent_id: str) -> Agent: | ||||
|         pass | ||||
|         chat_agent = await self._get_agent_impl(agent_id) | ||||
|         agent = Agent( | ||||
|             agent_id=agent_id, | ||||
|             agent_config=chat_agent.agent_config, | ||||
|             created_at=chat_agent.created_at, | ||||
|         ) | ||||
|         return agent | ||||
| 
 | ||||
|     async def list_agent_sessions( | ||||
|         self, | ||||
|         agent_id: str, | ||||
|     ) -> ListAgentSessionsResponse: | ||||
|         self, agent_id: str, start_index: int | None = None, limit: int | None = None | ||||
|     ) -> PaginatedResponse: | ||||
|         agent = await self._get_agent_impl(agent_id) | ||||
|         sessions = await agent.storage.list_sessions() | ||||
|         # Convert Session objects to dictionaries | ||||
|         session_dicts = [session.model_dump() for session in sessions] | ||||
|         return paginate_records(session_dicts, start_index, limit) | ||||
| 
 | ||||
|     async def shutdown(self) -> None: | ||||
|         pass | ||||
| 
 | ||||
|     # OpenAI responses | ||||
|  |  | |||
|  | @ -9,9 +9,7 @@ import logging | |||
| import uuid | ||||
| from datetime import datetime, timezone | ||||
| 
 | ||||
| from pydantic import BaseModel | ||||
| 
 | ||||
| from llama_stack.apis.agents import ToolExecutionStep, Turn | ||||
| from llama_stack.apis.agents import AgentConfig, Session, ToolExecutionStep, Turn | ||||
| from llama_stack.distribution.access_control import check_access | ||||
| from llama_stack.distribution.datatypes import AccessAttributes | ||||
| from llama_stack.distribution.request_headers import get_auth_attributes | ||||
|  | @ -20,15 +18,17 @@ from llama_stack.providers.utils.kvstore import KVStore | |||
| log = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| class AgentSessionInfo(BaseModel): | ||||
|     session_id: str | ||||
|     session_name: str | ||||
| class AgentSessionInfo(Session): | ||||
|     # TODO: is this used anywhere? | ||||
|     vector_db_id: str | None = None | ||||
|     started_at: datetime | ||||
|     access_attributes: AccessAttributes | None = None | ||||
| 
 | ||||
| 
 | ||||
| class AgentInfo(AgentConfig): | ||||
|     created_at: datetime | ||||
| 
 | ||||
| 
 | ||||
| class AgentPersistence: | ||||
|     def __init__(self, agent_id: str, kvstore: KVStore): | ||||
|         self.agent_id = agent_id | ||||
|  | @ -46,6 +46,7 @@ class AgentPersistence: | |||
|             session_name=name, | ||||
|             started_at=datetime.now(timezone.utc), | ||||
|             access_attributes=access_attributes, | ||||
|             turns=[], | ||||
|         ) | ||||
| 
 | ||||
|         await self.kvstore.set( | ||||
|  | @ -109,7 +110,7 @@ class AgentPersistence: | |||
|         if not await self.get_session_if_accessible(session_id): | ||||
|             raise ValueError(f"Session {session_id} not found or access denied") | ||||
| 
 | ||||
|         values = await self.kvstore.range( | ||||
|         values = await self.kvstore.values_in_range( | ||||
|             start_key=f"session:{self.agent_id}:{session_id}:", | ||||
|             end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", | ||||
|         ) | ||||
|  | @ -121,7 +122,6 @@ class AgentPersistence: | |||
|             except Exception as e: | ||||
|                 log.error(f"Error parsing turn: {e}") | ||||
|                 continue | ||||
|         turns.sort(key=lambda x: (x.completed_at or datetime.min)) | ||||
|         return turns | ||||
| 
 | ||||
|     async def get_session_turn(self, session_id: str, turn_id: str) -> Turn | None: | ||||
|  | @ -170,3 +170,43 @@ class AgentPersistence: | |||
|             key=f"num_infer_iters_in_turn:{self.agent_id}:{session_id}:{turn_id}", | ||||
|         ) | ||||
|         return int(value) if value else None | ||||
| 
 | ||||
|     async def list_sessions(self) -> list[Session]: | ||||
|         values = await self.kvstore.values_in_range( | ||||
|             start_key=f"session:{self.agent_id}:", | ||||
|             end_key=f"session:{self.agent_id}:\xff\xff\xff\xff", | ||||
|         ) | ||||
|         sessions = [] | ||||
|         for value in values: | ||||
|             try: | ||||
|                 session_info = Session(**json.loads(value)) | ||||
|                 sessions.append(session_info) | ||||
|             except Exception as e: | ||||
|                 log.error(f"Error parsing session info: {e}") | ||||
|                 continue | ||||
|         return sessions | ||||
| 
 | ||||
|     async def delete_session_turns(self, session_id: str) -> None: | ||||
|         """Delete all turns and their associated data for a session. | ||||
| 
 | ||||
|         Args: | ||||
|             session_id: The ID of the session whose turns should be deleted. | ||||
|         """ | ||||
|         turns = await self.get_session_turns(session_id) | ||||
|         for turn in turns: | ||||
|             await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}") | ||||
| 
 | ||||
|     async def delete_session(self, session_id: str) -> None: | ||||
|         """Delete a session and all its associated turns. | ||||
| 
 | ||||
|         Args: | ||||
|             session_id: The ID of the session to delete. | ||||
| 
 | ||||
|         Raises: | ||||
|             ValueError: If the session does not exist. | ||||
|         """ | ||||
|         session_info = await self.get_session_info(session_id) | ||||
|         if session_info is None: | ||||
|             raise ValueError(f"Session {session_id} not found") | ||||
| 
 | ||||
|         await self.kvstore.delete(key=f"session:{self.agent_id}:{session_id}") | ||||
|  |  | |||
|  | @ -64,7 +64,7 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): | |||
|         # Load existing datasets from kvstore | ||||
|         start_key = DATASETS_PREFIX | ||||
|         end_key = f"{DATASETS_PREFIX}\xff" | ||||
|         stored_datasets = await self.kvstore.range(start_key, end_key) | ||||
|         stored_datasets = await self.kvstore.values_in_range(start_key, end_key) | ||||
| 
 | ||||
|         for dataset in stored_datasets: | ||||
|             dataset = Dataset.model_validate_json(dataset) | ||||
|  |  | |||
|  | @ -58,7 +58,7 @@ class MetaReferenceEvalImpl( | |||
|         # Load existing benchmarks from kvstore | ||||
|         start_key = EVAL_TASKS_PREFIX | ||||
|         end_key = f"{EVAL_TASKS_PREFIX}\xff" | ||||
|         stored_benchmarks = await self.kvstore.range(start_key, end_key) | ||||
|         stored_benchmarks = await self.kvstore.values_in_range(start_key, end_key) | ||||
| 
 | ||||
|         for benchmark in stored_benchmarks: | ||||
|             benchmark = Benchmark.model_validate_json(benchmark) | ||||
|  |  | |||
|  | @ -125,7 +125,7 @@ class FaissVectorIOAdapter(VectorIO, VectorDBsProtocolPrivate): | |||
|         # Load existing banks from kvstore | ||||
|         start_key = VECTOR_DBS_PREFIX | ||||
|         end_key = f"{VECTOR_DBS_PREFIX}\xff" | ||||
|         stored_vector_dbs = await self.kvstore.range(start_key, end_key) | ||||
|         stored_vector_dbs = await self.kvstore.values_in_range(start_key, end_key) | ||||
| 
 | ||||
|         for vector_db_data in stored_vector_dbs: | ||||
|             vector_db = VectorDB.model_validate_json(vector_db_data) | ||||
|  |  | |||
|  | @ -42,7 +42,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate): | |||
|         # Load existing datasets from kvstore | ||||
|         start_key = DATASETS_PREFIX | ||||
|         end_key = f"{DATASETS_PREFIX}\xff" | ||||
|         stored_datasets = await self.kvstore.range(start_key, end_key) | ||||
|         stored_datasets = await self.kvstore.values_in_range(start_key, end_key) | ||||
| 
 | ||||
|         for dataset in stored_datasets: | ||||
|             dataset = Dataset.model_validate_json(dataset) | ||||
|  |  | |||
|  | @ -16,4 +16,6 @@ class KVStore(Protocol): | |||
| 
 | ||||
|     async def delete(self, key: str) -> None: ... | ||||
| 
 | ||||
|     async def range(self, start_key: str, end_key: str) -> list[str]: ... | ||||
|     async def values_in_range(self, start_key: str, end_key: str) -> list[str]: ... | ||||
| 
 | ||||
|     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: ... | ||||
|  |  | |||
|  | @ -26,9 +26,16 @@ class InmemoryKVStoreImpl(KVStore): | |||
|     async def set(self, key: str, value: str) -> None: | ||||
|         self._store[key] = value | ||||
| 
 | ||||
|     async def range(self, start_key: str, end_key: str) -> list[str]: | ||||
|     async def values_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         return [self._store[key] for key in self._store.keys() if key >= start_key and key < end_key] | ||||
| 
 | ||||
|     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         """Get all keys in the given range.""" | ||||
|         return [key for key in self._store.keys() if key >= start_key and key < end_key] | ||||
| 
 | ||||
|     async def delete(self, key: str) -> None: | ||||
|         del self._store[key] | ||||
| 
 | ||||
| 
 | ||||
| async def kvstore_impl(config: KVStoreConfig) -> KVStore: | ||||
|     if config.type == KVStoreType.redis.value: | ||||
|  |  | |||
|  | @ -57,7 +57,7 @@ class MongoDBKVStoreImpl(KVStore): | |||
|         key = self._namespaced_key(key) | ||||
|         await self.collection.delete_one({"key": key}) | ||||
| 
 | ||||
|     async def range(self, start_key: str, end_key: str) -> list[str]: | ||||
|     async def values_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         start_key = self._namespaced_key(start_key) | ||||
|         end_key = self._namespaced_key(end_key) | ||||
|         query = { | ||||
|  | @ -68,3 +68,10 @@ class MongoDBKVStoreImpl(KVStore): | |||
|         async for doc in cursor: | ||||
|             result.append(doc["value"]) | ||||
|         return result | ||||
| 
 | ||||
|     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         start_key = self._namespaced_key(start_key) | ||||
|         end_key = self._namespaced_key(end_key) | ||||
|         query = {"key": {"$gte": start_key, "$lt": end_key}} | ||||
|         cursor = self.collection.find(query, {"key": 1, "_id": 0}).sort("key", 1) | ||||
|         return [doc["key"] for doc in cursor] | ||||
|  |  | |||
|  | @ -85,7 +85,7 @@ class PostgresKVStoreImpl(KVStore): | |||
|             (key,), | ||||
|         ) | ||||
| 
 | ||||
|     async def range(self, start_key: str, end_key: str) -> list[str]: | ||||
|     async def values_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         start_key = self._namespaced_key(start_key) | ||||
|         end_key = self._namespaced_key(end_key) | ||||
| 
 | ||||
|  | @ -99,3 +99,13 @@ class PostgresKVStoreImpl(KVStore): | |||
|             (start_key, end_key), | ||||
|         ) | ||||
|         return [row[0] for row in self.cursor.fetchall()] | ||||
| 
 | ||||
|     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         start_key = self._namespaced_key(start_key) | ||||
|         end_key = self._namespaced_key(end_key) | ||||
| 
 | ||||
|         self.cursor.execute( | ||||
|             f"SELECT key FROM {self.config.table_name} WHERE key >= %s AND key < %s", | ||||
|             (start_key, end_key), | ||||
|         ) | ||||
|         return [row[0] for row in self.cursor.fetchall()] | ||||
|  |  | |||
|  | @ -42,7 +42,7 @@ class RedisKVStoreImpl(KVStore): | |||
|         key = self._namespaced_key(key) | ||||
|         await self.redis.delete(key) | ||||
| 
 | ||||
|     async def range(self, start_key: str, end_key: str) -> list[str]: | ||||
|     async def values_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         start_key = self._namespaced_key(start_key) | ||||
|         end_key = self._namespaced_key(end_key) | ||||
|         cursor = 0 | ||||
|  | @ -67,3 +67,10 @@ class RedisKVStoreImpl(KVStore): | |||
|             ] | ||||
| 
 | ||||
|         return [] | ||||
| 
 | ||||
|     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         """Get all keys in the given range.""" | ||||
|         matching_keys = await self.redis.zrangebylex(self.namespace, f"[{start_key}", f"[{end_key}") | ||||
|         if not matching_keys: | ||||
|             return [] | ||||
|         return [k.decode("utf-8") for k in matching_keys] | ||||
|  |  | |||
|  | @ -54,7 +54,7 @@ class SqliteKVStoreImpl(KVStore): | |||
|             await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) | ||||
|             await db.commit() | ||||
| 
 | ||||
|     async def range(self, start_key: str, end_key: str) -> list[str]: | ||||
|     async def values_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         async with aiosqlite.connect(self.db_path) as db: | ||||
|             async with db.execute( | ||||
|                 f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", | ||||
|  | @ -65,3 +65,13 @@ class SqliteKVStoreImpl(KVStore): | |||
|                     _, value, _ = row | ||||
|                     result.append(value) | ||||
|                 return result | ||||
| 
 | ||||
|     async def keys_in_range(self, start_key: str, end_key: str) -> list[str]: | ||||
|         """Get all keys in the given range.""" | ||||
|         async with aiosqlite.connect(self.db_path) as db: | ||||
|             cursor = await db.execute( | ||||
|                 f"SELECT key FROM {self.table_name} WHERE key >= ? AND key <= ?", | ||||
|                 (start_key, end_key), | ||||
|             ) | ||||
|             rows = await cursor.fetchall() | ||||
|             return [row[0] for row in rows] | ||||
|  |  | |||
							
								
								
									
										230
									
								
								tests/unit/providers/agent/test_meta_reference_agent.py
									
										
									
									
									
										Normal file
									
								
							
							
						
						
									
										230
									
								
								tests/unit/providers/agent/test_meta_reference_agent.py
									
										
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,230 @@ | |||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||||
| # All rights reserved. | ||||
| # | ||||
| # This source code is licensed under the terms described in the LICENSE file in | ||||
| # the root directory of this source tree. | ||||
| 
 | ||||
| from datetime import datetime | ||||
| from unittest.mock import AsyncMock | ||||
| 
 | ||||
| import pytest | ||||
| import pytest_asyncio | ||||
| 
 | ||||
| from llama_stack.apis.agents import ( | ||||
|     Agent, | ||||
|     AgentConfig, | ||||
|     AgentCreateResponse, | ||||
| ) | ||||
| from llama_stack.apis.common.responses import PaginatedResponse | ||||
| from llama_stack.apis.inference import Inference | ||||
| from llama_stack.apis.safety import Safety | ||||
| from llama_stack.apis.tools import ToolGroups, ToolRuntime | ||||
| from llama_stack.apis.vector_io import VectorIO | ||||
| from llama_stack.providers.inline.agents.meta_reference.agents import MetaReferenceAgentsImpl | ||||
| from llama_stack.providers.inline.agents.meta_reference.config import MetaReferenceAgentsImplConfig | ||||
| from llama_stack.providers.inline.agents.meta_reference.persistence import AgentInfo | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def mock_apis(): | ||||
|     return { | ||||
|         "inference_api": AsyncMock(spec=Inference), | ||||
|         "vector_io_api": AsyncMock(spec=VectorIO), | ||||
|         "safety_api": AsyncMock(spec=Safety), | ||||
|         "tool_runtime_api": AsyncMock(spec=ToolRuntime), | ||||
|         "tool_groups_api": AsyncMock(spec=ToolGroups), | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def config(tmp_path): | ||||
|     return MetaReferenceAgentsImplConfig( | ||||
|         persistence_store={ | ||||
|             "type": "sqlite", | ||||
|             "db_path": str(tmp_path / "test.db"), | ||||
|         }, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @pytest_asyncio.fixture | ||||
| async def agents_impl(config, mock_apis): | ||||
|     impl = MetaReferenceAgentsImpl( | ||||
|         config, | ||||
|         mock_apis["inference_api"], | ||||
|         mock_apis["vector_io_api"], | ||||
|         mock_apis["safety_api"], | ||||
|         mock_apis["tool_runtime_api"], | ||||
|         mock_apis["tool_groups_api"], | ||||
|     ) | ||||
|     await impl.initialize() | ||||
|     yield impl | ||||
|     await impl.shutdown() | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def sample_agent_config(): | ||||
|     return AgentConfig( | ||||
|         sampling_params={ | ||||
|             "strategy": {"type": "greedy"}, | ||||
|             "max_tokens": 0, | ||||
|             "repetition_penalty": 1.0, | ||||
|         }, | ||||
|         input_shields=["string"], | ||||
|         output_shields=["string"], | ||||
|         toolgroups=["string"], | ||||
|         client_tools=[ | ||||
|             { | ||||
|                 "name": "string", | ||||
|                 "description": "string", | ||||
|                 "parameters": [ | ||||
|                     { | ||||
|                         "name": "string", | ||||
|                         "parameter_type": "string", | ||||
|                         "description": "string", | ||||
|                         "required": True, | ||||
|                         "default": None, | ||||
|                     } | ||||
|                 ], | ||||
|                 "metadata": { | ||||
|                     "property1": None, | ||||
|                     "property2": None, | ||||
|                 }, | ||||
|             } | ||||
|         ], | ||||
|         tool_choice="auto", | ||||
|         tool_prompt_format="json", | ||||
|         tool_config={ | ||||
|             "tool_choice": "auto", | ||||
|             "tool_prompt_format": "json", | ||||
|             "system_message_behavior": "append", | ||||
|         }, | ||||
|         max_infer_iters=10, | ||||
|         model="string", | ||||
|         instructions="string", | ||||
|         enable_session_persistence=False, | ||||
|         response_format={ | ||||
|             "type": "json_schema", | ||||
|             "json_schema": { | ||||
|                 "property1": None, | ||||
|                 "property2": None, | ||||
|             }, | ||||
|         }, | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| async def test_create_agent(agents_impl, sample_agent_config): | ||||
|     response = await agents_impl.create_agent(sample_agent_config) | ||||
| 
 | ||||
|     assert isinstance(response, AgentCreateResponse) | ||||
|     assert response.agent_id is not None | ||||
| 
 | ||||
|     stored_agent = await agents_impl.persistence_store.get(f"agent:{response.agent_id}") | ||||
|     assert stored_agent is not None | ||||
|     agent_info = AgentInfo.model_validate_json(stored_agent) | ||||
|     assert agent_info.model == sample_agent_config.model | ||||
|     assert agent_info.created_at is not None | ||||
|     assert isinstance(agent_info.created_at, datetime) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| async def test_get_agent(agents_impl, sample_agent_config): | ||||
|     create_response = await agents_impl.create_agent(sample_agent_config) | ||||
|     agent_id = create_response.agent_id | ||||
| 
 | ||||
|     agent = await agents_impl.get_agent(agent_id) | ||||
| 
 | ||||
|     assert isinstance(agent, Agent) | ||||
|     assert agent.agent_id == agent_id | ||||
|     assert agent.agent_config.model == sample_agent_config.model | ||||
|     assert agent.created_at is not None | ||||
|     assert isinstance(agent.created_at, datetime) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| async def test_list_agents(agents_impl, sample_agent_config): | ||||
|     agent1_response = await agents_impl.create_agent(sample_agent_config) | ||||
|     agent2_response = await agents_impl.create_agent(sample_agent_config) | ||||
| 
 | ||||
|     response = await agents_impl.list_agents() | ||||
| 
 | ||||
|     assert isinstance(response, PaginatedResponse) | ||||
|     assert len(response.data) == 2 | ||||
|     agent_ids = {agent["agent_id"] for agent in response.data} | ||||
|     assert agent1_response.agent_id in agent_ids | ||||
|     assert agent2_response.agent_id in agent_ids | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| @pytest.mark.parametrize("enable_session_persistence", [True, False]) | ||||
| async def test_create_agent_session_persistence(agents_impl, sample_agent_config, enable_session_persistence): | ||||
|     # Create an agent with specified persistence setting | ||||
|     config = sample_agent_config.model_copy() | ||||
|     config.enable_session_persistence = enable_session_persistence | ||||
|     response = await agents_impl.create_agent(config) | ||||
|     agent_id = response.agent_id | ||||
| 
 | ||||
|     # Create a session | ||||
|     session_response = await agents_impl.create_agent_session(agent_id, "test_session") | ||||
|     assert session_response.session_id is not None | ||||
| 
 | ||||
|     # Verify the session was stored | ||||
|     session = await agents_impl.get_agents_session(agent_id, session_response.session_id) | ||||
|     assert session.session_name == "test_session" | ||||
|     assert session.session_id == session_response.session_id | ||||
|     assert session.started_at is not None | ||||
|     assert session.turns == [] | ||||
| 
 | ||||
|     # Delete the session | ||||
|     await agents_impl.delete_agents_session(agent_id, session_response.session_id) | ||||
| 
 | ||||
|     # Verify the session was deleted | ||||
|     with pytest.raises(ValueError): | ||||
|         await agents_impl.get_agents_session(agent_id, session_response.session_id) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| @pytest.mark.parametrize("enable_session_persistence", [True, False]) | ||||
| async def test_list_agent_sessions_persistence(agents_impl, sample_agent_config, enable_session_persistence): | ||||
|     # Create an agent with specified persistence setting | ||||
|     config = sample_agent_config.model_copy() | ||||
|     config.enable_session_persistence = enable_session_persistence | ||||
|     response = await agents_impl.create_agent(config) | ||||
|     agent_id = response.agent_id | ||||
| 
 | ||||
|     # Create multiple sessions | ||||
|     session1 = await agents_impl.create_agent_session(agent_id, "session1") | ||||
|     session2 = await agents_impl.create_agent_session(agent_id, "session2") | ||||
| 
 | ||||
|     # List sessions | ||||
|     sessions = await agents_impl.list_agent_sessions(agent_id) | ||||
|     assert len(sessions.data) == 2 | ||||
|     session_ids = {s["session_id"] for s in sessions.data} | ||||
|     assert session1.session_id in session_ids | ||||
|     assert session2.session_id in session_ids | ||||
| 
 | ||||
|     # Delete one session | ||||
|     await agents_impl.delete_agents_session(agent_id, session1.session_id) | ||||
| 
 | ||||
|     # Verify the session was deleted | ||||
|     with pytest.raises(ValueError): | ||||
|         await agents_impl.get_agents_session(agent_id, session1.session_id) | ||||
| 
 | ||||
|     # List sessions again | ||||
|     sessions = await agents_impl.list_agent_sessions(agent_id) | ||||
|     assert len(sessions.data) == 1 | ||||
|     assert session2.session_id in {s["session_id"] for s in sessions.data} | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.asyncio | ||||
| async def test_delete_agent(agents_impl, sample_agent_config): | ||||
|     # Create an agent | ||||
|     response = await agents_impl.create_agent(sample_agent_config) | ||||
|     agent_id = response.agent_id | ||||
| 
 | ||||
|     # Delete the agent | ||||
|     await agents_impl.delete_agent(agent_id) | ||||
| 
 | ||||
|     # Verify the agent was deleted | ||||
|     with pytest.raises(ValueError): | ||||
|         await agents_impl.get_agent(agent_id) | ||||
|  | @ -54,6 +54,7 @@ async def test_session_access_control(mock_get_auth_attributes, test_setup): | |||
|         session_name="Restricted Session", | ||||
|         started_at=datetime.now(), | ||||
|         access_attributes=AccessAttributes(roles=["admin"], teams=["security-team"]), | ||||
|         turns=[], | ||||
|     ) | ||||
| 
 | ||||
|     await agent_persistence.kvstore.set( | ||||
|  | @ -85,6 +86,7 @@ async def test_turn_access_control(mock_get_auth_attributes, test_setup): | |||
|         session_name="Restricted Session", | ||||
|         started_at=datetime.now(), | ||||
|         access_attributes=AccessAttributes(roles=["admin"]), | ||||
|         turns=[], | ||||
|     ) | ||||
| 
 | ||||
|     await agent_persistence.kvstore.set( | ||||
|  | @ -137,6 +139,7 @@ async def test_tool_call_and_infer_iters_access_control(mock_get_auth_attributes | |||
|         session_name="Restricted Session", | ||||
|         started_at=datetime.now(), | ||||
|         access_attributes=AccessAttributes(roles=["admin"]), | ||||
|         turns=[], | ||||
|     ) | ||||
| 
 | ||||
|     await agent_persistence.kvstore.set( | ||||
|  |  | |||
		Loading…
	
	Add table
		Add a link
		
	
		Reference in a new issue