diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index cf247b01b..d008331d5 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -276,6 +276,8 @@ class AgentConfigCommon(BaseModel): default=ToolPromptFormat.json ) + max_infer_iters: int = 10 + @json_schema_type class AgentConfig(AgentConfigCommon): diff --git a/llama_stack/distribution/utils/prompt_for_config.py b/llama_stack/distribution/utils/prompt_for_config.py index 63ee64fb0..54e9e9cc3 100644 --- a/llama_stack/distribution/utils/prompt_for_config.py +++ b/llama_stack/distribution/utils/prompt_for_config.py @@ -83,10 +83,12 @@ def prompt_for_discriminated_union( if isinstance(typ, FieldInfo): inner_type = typ.annotation discriminator = typ.discriminator + default_value = typ.default else: args = get_args(typ) inner_type = args[0] discriminator = args[1].discriminator + default_value = args[1].default union_types = get_args(inner_type) # Find the discriminator field in each union type @@ -99,9 +101,14 @@ def prompt_for_discriminated_union( type_map[value] = t while True: - discriminator_value = input( - f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): " - ) + prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})" + if default_value is not None: + prompt += f" (default: {default_value})" + + discriminator_value = input(f"{prompt}: ") + if discriminator_value == "" and default_value is not None: + discriminator_value = default_value + if discriminator_value in type_map: chosen_type = type_map[discriminator_value] print(f"\nConfiguring {chosen_type.__name__}:") diff --git a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py index 9b2957506..b7f56917e 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agent_instance.py +++ b/llama_stack/providers/impls/meta_reference/agents/agent_instance.py @@ -25,10 +25,19 @@ from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403 +from llama_stack.providers.utils.api import KVStore + from .rag.context_retriever import generate_rag_query from .safety import SafetyException, ShieldRunnerMixin from .tools.base import BaseTool -from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool +from .tools.builtin import ( + CodeInterpreterTool, + interpret_content_as_attachment, + PhotogenTool, + SearchTool, + WolframAlphaTool, +) +from .tools.safety import with_safety def make_random_string(length: int = 8): @@ -40,23 +49,44 @@ def make_random_string(length: int = 8): class ChatAgent(ShieldRunnerMixin): def __init__( self, + agent_id: str, agent_config: AgentConfig, inference_api: Inference, memory_api: Memory, safety_api: Safety, - builtin_tools: List[SingleMessageBuiltinTool], - max_infer_iters: int = 10, + persistence_store: KVStore, ): + self.agent_id = agent_id self.agent_config = agent_config self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api - - self.max_infer_iters = max_infer_iters - self.tools_dict = {t.get_name(): t for t in builtin_tools} + self.persistence_store = persistence_store self.tempdir = tempfile.mkdtemp() - self.sessions = {} + + builtin_tools = [] + for tool_defn in agent_config.tools: + if isinstance(tool_defn, WolframAlphaToolDefinition): + tool = WolframAlphaTool(tool_defn.api_key) + elif isinstance(tool_defn, SearchToolDefinition): + tool = SearchTool(tool_defn.engine, tool_defn.api_key) + elif isinstance(tool_defn, CodeInterpreterToolDefinition): + tool = CodeInterpreterTool() + elif isinstance(tool_defn, PhotogenToolDefinition): + tool = PhotogenTool(dump_dir=self.tempdir) + else: + continue + + builtin_tools.append( + with_safety( + tool, + safety_api, + tool_defn.input_shields, + tool_defn.output_shields, + ) + ) + self.tools_dict = {t.get_name(): t for t in builtin_tools} ShieldRunnerMixin.__init__( self, @@ -80,7 +110,6 @@ class ChatAgent(ShieldRunnerMixin): msg.context = None messages.append(msg) - # messages.extend(turn.input_messages) for step in turn.steps: if step.step_type == StepType.inference.value: messages.append(step.model_response) @@ -105,31 +134,64 @@ class ChatAgent(ShieldRunnerMixin): # print_dialog(messages) return messages - def create_session(self, name: str) -> Session: + async def create_session(self, name: str) -> str: session_id = str(uuid.uuid4()) - session = Session( + session = dict( session_id=session_id, session_name=name, - turns=[], - started_at=datetime.now(), + started_at=str(datetime.now()), ) - self.sessions[session_id] = session - return session + await self.persistence_store.set( + key=f"session:{self.agent_id}:{session_id}", + value=json.dumps(session), + ) + return session_id + + async def _add_turn_to_session(self, session_id: str, turn: Turn): + await self.persistence_store.set( + key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}", + value=json.dumps(turn.dict()), + ) + + async def _get_session_turns(self, session_id: str) -> List[Turn]: + values = await self.persistence_store.range( + start_key=f"session:{self.agent_id}:{session_id}:", + end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff", + ) + turns = [] + for value in values: + try: + turn = Turn(**json.loads(value)) + turns.append(turn) + except Exception as e: + print(f"Error parsing turn: {e}") + continue + + return turns + + async def _get_session_info(self, session_id: str) -> Optional[Dict[str, str]]: + value = await self.persistence_store.get( + key=f"session:{self.agent_id}:{session_id}", + ) + if not value: + return None + + return json.loads(value) async def create_and_execute_turn( self, request: AgentTurnCreateRequest ) -> AsyncGenerator: - assert ( - request.session_id in self.sessions - ), f"Session {request.session_id} not found" + session_info = await self._get_session_info(request.session_id) + if session_info is None: + raise ValueError(f"Session {request.session_id} not found") - session = self.sessions[request.session_id] + turns = await self._get_session_turns(request.session_id) messages = [] - if len(session.turns) == 0 and self.agent_config.instructions != "": + if len(turns) == 0 and self.agent_config.instructions != "": messages.append(SystemMessage(content=self.agent_config.instructions)) - for i, turn in enumerate(session.turns): + for i, turn in enumerate(turns): messages.extend(self.turn_to_messages(turn)) messages.extend(request.messages) @@ -186,7 +248,7 @@ class ChatAgent(ShieldRunnerMixin): completed_at=datetime.now(), steps=steps, ) - session.turns.append(turn) + await self._add_turn_to_session(request.session_id, turn) chunk = AgentTurnResponseStreamChunk( event=AgentTurnResponseEvent( @@ -455,7 +517,7 @@ class ChatAgent(ShieldRunnerMixin): ) ) - if n_iter >= self.max_infer_iters: + if n_iter >= self.agent_config.max_infer_iters: cprint("Done with MAX iterations, exiting.") yield message break diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index bb042baa3..a7ab548d2 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -6,7 +6,6 @@ import logging -import tempfile import uuid from typing import AsyncGenerator @@ -15,39 +14,15 @@ from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety from llama_stack.apis.agents import * # noqa: F403 +from llama_stack.providers.utils.kvstore import kvstore_impl + from .agent_instance import ChatAgent from .config import MetaReferenceAgentsImplConfig -from .tools.builtin import ( - CodeInterpreterTool, - PhotogenTool, - SearchTool, - WolframAlphaTool, -) -from .tools.safety import with_safety - logger = logging.getLogger() logger.setLevel(logging.INFO) -AGENT_INSTANCES_BY_ID = {} - - -class KVStore(Protocol): - def get(self, key: str) -> str: - ... - - def set(self, key: str, value: str) -> None: - ... - -def kvstore_impl(config: KVStoreConfig) -> KVStore: - if config.type == KVStoreType.redis: - from .kvstore_impls.redis import RedisKVStoreImpl - return RedisKVStoreImpl(config) - - return None - - class MetaReferenceAgentsImpl(Agents): def __init__( self, @@ -60,10 +35,9 @@ class MetaReferenceAgentsImpl(Agents): self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api - self.kvstore = kvstore_impl(config.kvstore) async def initialize(self) -> None: - pass + self.persistence_store = await kvstore_impl(self.config.persistence_store) async def create_agent( self, @@ -71,38 +45,42 @@ class MetaReferenceAgentsImpl(Agents): ) -> AgentCreateResponse: agent_id = str(uuid.uuid4()) - builtin_tools = [] - for tool_defn in agent_config.tools: - if isinstance(tool_defn, WolframAlphaToolDefinition): - tool = WolframAlphaTool(tool_defn.api_key) - elif isinstance(tool_defn, SearchToolDefinition): - tool = SearchTool(tool_defn.engine, tool_defn.api_key) - elif isinstance(tool_defn, CodeInterpreterToolDefinition): - tool = CodeInterpreterTool() - elif isinstance(tool_defn, PhotogenToolDefinition): - tool = PhotogenTool(dump_dir=tempfile.mkdtemp()) - else: - continue + await self.persistence_store.set( + key=f"agent:{agent_id}", + value=json.dumps(agent_config.dict()), + ) + return AgentCreateResponse( + agent_id=agent_id, + ) - builtin_tools.append( - with_safety( - tool, - self.safety_api, - tool_defn.input_shields, - tool_defn.output_shields, - ) - ) + async def get_agent(self, agent_id: str) -> ChatAgent: + agent_config = await self.persistence_store.get( + key=f"agent:{agent_id}", + ) + if not agent_config: + raise ValueError(f"Could not find agent config for {agent_id}") - AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent( + try: + agent_config = json.loads(agent_config) + except json.JSONDecodeError as e: + raise ValueError( + f"Could not JSON decode agent config for {agent_id}" + ) from e + + try: + agent_config = AgentConfig(**agent_config) + except Exception as e: + raise ValueError( + f"Could not validate(?) agent config for {agent_id}" + ) from e + + return ChatAgent( + agent_id=agent_id, agent_config=agent_config, inference_api=self.inference_api, safety_api=self.safety_api, memory_api=self.memory_api, - builtin_tools=builtin_tools, - ) - - return AgentCreateResponse( - agent_id=agent_id, + persistence_store=self.persistence_store, ) async def create_agent_session( @@ -110,12 +88,11 @@ class MetaReferenceAgentsImpl(Agents): agent_id: str, session_name: str, ) -> AgentSessionCreateResponse: - assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found" - agent = AGENT_INSTANCES_BY_ID[agent_id] + agent = await self.get_agent(agent_id) - session = agent.create_session(session_name) + session_id = await agent.create_session(session_name) return AgentSessionCreateResponse( - session_id=session.session_id, + session_id=session_id, ) async def create_agent_turn( @@ -131,6 +108,8 @@ class MetaReferenceAgentsImpl(Agents): attachments: Optional[List[Attachment]] = None, stream: Optional[bool] = False, ) -> AsyncGenerator: + agent = await self.get_agent(agent_id) + # wrapper request to make it easier to pass around (internal only, not exposed to API) request = AgentTurnCreateRequest( agent_id=agent_id, @@ -140,12 +119,5 @@ class MetaReferenceAgentsImpl(Agents): stream=stream, ) - agent_id = request.agent_id - assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found" - agent = AGENT_INSTANCES_BY_ID[agent_id] - - assert ( - request.session_id in agent.sessions - ), f"Session {request.session_id} not found" async for event in agent.create_and_execute_turn(request): yield event diff --git a/llama_stack/providers/impls/meta_reference/agents/config.py b/llama_stack/providers/impls/meta_reference/agents/config.py index f293f46c2..d4f15968e 100644 --- a/llama_stack/providers/impls/meta_reference/agents/config.py +++ b/llama_stack/providers/impls/meta_reference/agents/config.py @@ -8,4 +8,4 @@ from llama_stack.providers.utils.kvstore import KVStoreConfig class MetaReferenceAgentsImplConfig(BaseModel): - kv_store: KVStoreConfig + persistence_store: KVStoreConfig diff --git a/llama_stack/providers/utils/kvstore/__init__.py b/llama_stack/providers/utils/kvstore/__init__.py index 756f351d8..470a75d2d 100644 --- a/llama_stack/providers/utils/kvstore/__init__.py +++ b/llama_stack/providers/utils/kvstore/__init__.py @@ -3,3 +3,5 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. + +from .kvstore import * # noqa: F401, F403 diff --git a/llama_stack/providers/utils/kvstore/api.py b/llama_stack/providers/utils/kvstore/api.py index 99d666f42..5ddc8db5b 100644 --- a/llama_stack/providers/utils/kvstore/api.py +++ b/llama_stack/providers/utils/kvstore/api.py @@ -5,20 +5,20 @@ # the root directory of this source tree. from datetime import datetime -from typing import Any, List, Optional, Protocol +from typing import List, Optional, Protocol from pydantic import BaseModel class KVStoreValue(BaseModel): key: str - value: Any + value: str expiration: Optional[datetime] = None class KVStore(Protocol): async def set( - self, key: str, value: Any, expiration: Optional[datetime] = None + self, key: str, value: str, expiration: Optional[datetime] = None ) -> None: ... async def get(self, key: str) -> Optional[KVStoreValue]: ... diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py index 515640897..0a99c78cd 100644 --- a/llama_stack/providers/utils/kvstore/config.py +++ b/llama_stack/providers/utils/kvstore/config.py @@ -14,7 +14,7 @@ from typing_extensions import Annotated class KVStoreType(Enum): redis = "redis" sqlite = "sqlite" - pgvector = "pgvector" + postgres = "postgres" class CommonConfig(BaseModel): @@ -37,8 +37,8 @@ class SqliteKVStoreImplConfig(CommonConfig): ) -class PGVectorKVStoreImplConfig(CommonConfig): - type: Literal[KVStoreType.pgvector.value] = KVStoreType.pgvector.value +class PostgresKVStoreImplConfig(CommonConfig): + type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value host: str = "localhost" port: int = 5432 db: str = "llamastack" @@ -47,6 +47,6 @@ class PGVectorKVStoreImplConfig(CommonConfig): KVStoreConfig = Annotated[ - Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PGVectorKVStoreImplConfig], - Field(discriminator="type"), + Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig], + Field(discriminator="type", default=KVStoreType.sqlite.value), ] diff --git a/llama_stack/providers/utils/kvstore/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py index 340530750..977da06d3 100644 --- a/llama_stack/providers/utils/kvstore/redis/redis.py +++ b/llama_stack/providers/utils/kvstore/redis/redis.py @@ -5,7 +5,7 @@ # the root directory of this source tree. from datetime import datetime, timedelta -from typing import Any, List, Optional +from typing import List, Optional from redis.asyncio import Redis @@ -26,7 +26,7 @@ class RedisKVStoreImpl(KVStore): return f"{self.config.namespace}:{key}" async def set( - self, key: str, value: Any, expiration: Optional[datetime] = None + self, key: str, value: str, expiration: Optional[datetime] = None ) -> None: key = self._namespaced_key(key) await self.redis.set(key, value) @@ -50,11 +50,4 @@ class RedisKVStoreImpl(KVStore): start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) - keys = await self.redis.keys(f"{start_key}*") - result = [] - for key in keys: - if key <= end_key: - value = await self.get(key) - if value: - result.append(value) - return result + return await self.redis.zrangebylex(start_key, end_key) diff --git a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py index ae8374156..9d7fd4d86 100644 --- a/llama_stack/providers/utils/kvstore/sqlite/sqlite.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -4,9 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import json from datetime import datetime -from typing import Any, List, Optional +from typing import List, Optional import aiosqlite @@ -33,12 +32,12 @@ class SqliteKVStoreImpl(KVStore): await db.commit() async def set( - self, key: str, value: Any, expiration: Optional[datetime] = None + self, key: str, value: str, expiration: Optional[datetime] = None ) -> None: async with aiosqlite.connect(self.db_path) as db: await db.execute( f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", - (key, json.dumps(value), expiration), + (key, value, expiration), ) await db.commit() @@ -51,9 +50,7 @@ class SqliteKVStoreImpl(KVStore): if row is None: return None value, expiration = row - return KVStoreValue( - key=key, value=json.loads(value), expiration=expiration - ) + return KVStoreValue(key=key, value=value, expiration=expiration) async def delete(self, key: str) -> None: async with aiosqlite.connect(self.db_path) as db: @@ -70,8 +67,6 @@ class SqliteKVStoreImpl(KVStore): async for row in cursor: key, value, expiration = row result.append( - KVStoreValue( - key=key, value=json.loads(value), expiration=expiration - ) + KVStoreValue(key=key, value=value, expiration=expiration) ) return result