initial cut at using kvstores for agent persistence

This commit is contained in:
Ashwin Bharambe 2024-09-21 21:16:26 -07:00
parent 61974e337f
commit 4eb0f30891
10 changed files with 153 additions and 120 deletions

View file

@ -276,6 +276,8 @@ class AgentConfigCommon(BaseModel):
default=ToolPromptFormat.json default=ToolPromptFormat.json
) )
max_infer_iters: int = 10
@json_schema_type @json_schema_type
class AgentConfig(AgentConfigCommon): class AgentConfig(AgentConfigCommon):

View file

@ -83,10 +83,12 @@ def prompt_for_discriminated_union(
if isinstance(typ, FieldInfo): if isinstance(typ, FieldInfo):
inner_type = typ.annotation inner_type = typ.annotation
discriminator = typ.discriminator discriminator = typ.discriminator
default_value = typ.default
else: else:
args = get_args(typ) args = get_args(typ)
inner_type = args[0] inner_type = args[0]
discriminator = args[1].discriminator discriminator = args[1].discriminator
default_value = args[1].default
union_types = get_args(inner_type) union_types = get_args(inner_type)
# Find the discriminator field in each union type # Find the discriminator field in each union type
@ -99,9 +101,14 @@ def prompt_for_discriminated_union(
type_map[value] = t type_map[value] = t
while True: while True:
discriminator_value = input( prompt = f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())})"
f"Enter `{discriminator}` for {field_name} (options: {', '.join(type_map.keys())}): " if default_value is not None:
) prompt += f" (default: {default_value})"
discriminator_value = input(f"{prompt}: ")
if discriminator_value == "" and default_value is not None:
discriminator_value = default_value
if discriminator_value in type_map: if discriminator_value in type_map:
chosen_type = type_map[discriminator_value] chosen_type = type_map[discriminator_value]
print(f"\nConfiguring {chosen_type.__name__}:") print(f"\nConfiguring {chosen_type.__name__}:")

View file

@ -25,10 +25,19 @@ from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403 from llama_stack.apis.safety import * # noqa: F403
from llama_stack.providers.utils.api import KVStore
from .rag.context_retriever import generate_rag_query from .rag.context_retriever import generate_rag_query
from .safety import SafetyException, ShieldRunnerMixin from .safety import SafetyException, ShieldRunnerMixin
from .tools.base import BaseTool from .tools.base import BaseTool
from .tools.builtin import interpret_content_as_attachment, SingleMessageBuiltinTool from .tools.builtin import (
CodeInterpreterTool,
interpret_content_as_attachment,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
def make_random_string(length: int = 8): def make_random_string(length: int = 8):
@ -40,23 +49,44 @@ def make_random_string(length: int = 8):
class ChatAgent(ShieldRunnerMixin): class ChatAgent(ShieldRunnerMixin):
def __init__( def __init__(
self, self,
agent_id: str,
agent_config: AgentConfig, agent_config: AgentConfig,
inference_api: Inference, inference_api: Inference,
memory_api: Memory, memory_api: Memory,
safety_api: Safety, safety_api: Safety,
builtin_tools: List[SingleMessageBuiltinTool], persistence_store: KVStore,
max_infer_iters: int = 10,
): ):
self.agent_id = agent_id
self.agent_config = agent_config self.agent_config = agent_config
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api self.memory_api = memory_api
self.safety_api = safety_api self.safety_api = safety_api
self.persistence_store = persistence_store
self.max_infer_iters = max_infer_iters
self.tools_dict = {t.get_name(): t for t in builtin_tools}
self.tempdir = tempfile.mkdtemp() self.tempdir = tempfile.mkdtemp()
self.sessions = {}
builtin_tools = []
for tool_defn in agent_config.tools:
if isinstance(tool_defn, WolframAlphaToolDefinition):
tool = WolframAlphaTool(tool_defn.api_key)
elif isinstance(tool_defn, SearchToolDefinition):
tool = SearchTool(tool_defn.engine, tool_defn.api_key)
elif isinstance(tool_defn, CodeInterpreterToolDefinition):
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=self.tempdir)
else:
continue
builtin_tools.append(
with_safety(
tool,
safety_api,
tool_defn.input_shields,
tool_defn.output_shields,
)
)
self.tools_dict = {t.get_name(): t for t in builtin_tools}
ShieldRunnerMixin.__init__( ShieldRunnerMixin.__init__(
self, self,
@ -80,7 +110,6 @@ class ChatAgent(ShieldRunnerMixin):
msg.context = None msg.context = None
messages.append(msg) messages.append(msg)
# messages.extend(turn.input_messages)
for step in turn.steps: for step in turn.steps:
if step.step_type == StepType.inference.value: if step.step_type == StepType.inference.value:
messages.append(step.model_response) messages.append(step.model_response)
@ -105,31 +134,64 @@ class ChatAgent(ShieldRunnerMixin):
# print_dialog(messages) # print_dialog(messages)
return messages return messages
def create_session(self, name: str) -> Session: async def create_session(self, name: str) -> str:
session_id = str(uuid.uuid4()) session_id = str(uuid.uuid4())
session = Session( session = dict(
session_id=session_id, session_id=session_id,
session_name=name, session_name=name,
turns=[], started_at=str(datetime.now()),
started_at=datetime.now(),
) )
self.sessions[session_id] = session await self.persistence_store.set(
return session key=f"session:{self.agent_id}:{session_id}",
value=json.dumps(session),
)
return session_id
async def _add_turn_to_session(self, session_id: str, turn: Turn):
await self.persistence_store.set(
key=f"session:{self.agent_id}:{session_id}:{turn.turn_id}",
value=json.dumps(turn.dict()),
)
async def _get_session_turns(self, session_id: str) -> List[Turn]:
values = await self.persistence_store.range(
start_key=f"session:{self.agent_id}:{session_id}:",
end_key=f"session:{self.agent_id}:{session_id}:\xff\xff\xff\xff",
)
turns = []
for value in values:
try:
turn = Turn(**json.loads(value))
turns.append(turn)
except Exception as e:
print(f"Error parsing turn: {e}")
continue
return turns
async def _get_session_info(self, session_id: str) -> Optional[Dict[str, str]]:
value = await self.persistence_store.get(
key=f"session:{self.agent_id}:{session_id}",
)
if not value:
return None
return json.loads(value)
async def create_and_execute_turn( async def create_and_execute_turn(
self, request: AgentTurnCreateRequest self, request: AgentTurnCreateRequest
) -> AsyncGenerator: ) -> AsyncGenerator:
assert ( session_info = await self._get_session_info(request.session_id)
request.session_id in self.sessions if session_info is None:
), f"Session {request.session_id} not found" raise ValueError(f"Session {request.session_id} not found")
session = self.sessions[request.session_id] turns = await self._get_session_turns(request.session_id)
messages = [] messages = []
if len(session.turns) == 0 and self.agent_config.instructions != "": if len(turns) == 0 and self.agent_config.instructions != "":
messages.append(SystemMessage(content=self.agent_config.instructions)) messages.append(SystemMessage(content=self.agent_config.instructions))
for i, turn in enumerate(session.turns): for i, turn in enumerate(turns):
messages.extend(self.turn_to_messages(turn)) messages.extend(self.turn_to_messages(turn))
messages.extend(request.messages) messages.extend(request.messages)
@ -186,7 +248,7 @@ class ChatAgent(ShieldRunnerMixin):
completed_at=datetime.now(), completed_at=datetime.now(),
steps=steps, steps=steps,
) )
session.turns.append(turn) await self._add_turn_to_session(request.session_id, turn)
chunk = AgentTurnResponseStreamChunk( chunk = AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent( event=AgentTurnResponseEvent(
@ -455,7 +517,7 @@ class ChatAgent(ShieldRunnerMixin):
) )
) )
if n_iter >= self.max_infer_iters: if n_iter >= self.agent_config.max_infer_iters:
cprint("Done with MAX iterations, exiting.") cprint("Done with MAX iterations, exiting.")
yield message yield message
break break

View file

@ -6,7 +6,6 @@
import logging import logging
import tempfile
import uuid import uuid
from typing import AsyncGenerator from typing import AsyncGenerator
@ -15,39 +14,15 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403
from llama_stack.providers.utils.kvstore import kvstore_impl
from .agent_instance import ChatAgent from .agent_instance import ChatAgent
from .config import MetaReferenceAgentsImplConfig from .config import MetaReferenceAgentsImplConfig
from .tools.builtin import (
CodeInterpreterTool,
PhotogenTool,
SearchTool,
WolframAlphaTool,
)
from .tools.safety import with_safety
logger = logging.getLogger() logger = logging.getLogger()
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
AGENT_INSTANCES_BY_ID = {}
class KVStore(Protocol):
def get(self, key: str) -> str:
...
def set(self, key: str, value: str) -> None:
...
def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis:
from .kvstore_impls.redis import RedisKVStoreImpl
return RedisKVStoreImpl(config)
return None
class MetaReferenceAgentsImpl(Agents): class MetaReferenceAgentsImpl(Agents):
def __init__( def __init__(
self, self,
@ -60,10 +35,9 @@ class MetaReferenceAgentsImpl(Agents):
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api self.memory_api = memory_api
self.safety_api = safety_api self.safety_api = safety_api
self.kvstore = kvstore_impl(config.kvstore)
async def initialize(self) -> None: async def initialize(self) -> None:
pass self.persistence_store = await kvstore_impl(self.config.persistence_store)
async def create_agent( async def create_agent(
self, self,
@ -71,38 +45,42 @@ class MetaReferenceAgentsImpl(Agents):
) -> AgentCreateResponse: ) -> AgentCreateResponse:
agent_id = str(uuid.uuid4()) agent_id = str(uuid.uuid4())
builtin_tools = [] await self.persistence_store.set(
for tool_defn in agent_config.tools: key=f"agent:{agent_id}",
if isinstance(tool_defn, WolframAlphaToolDefinition): value=json.dumps(agent_config.dict()),
tool = WolframAlphaTool(tool_defn.api_key) )
elif isinstance(tool_defn, SearchToolDefinition): return AgentCreateResponse(
tool = SearchTool(tool_defn.engine, tool_defn.api_key) agent_id=agent_id,
elif isinstance(tool_defn, CodeInterpreterToolDefinition): )
tool = CodeInterpreterTool()
elif isinstance(tool_defn, PhotogenToolDefinition):
tool = PhotogenTool(dump_dir=tempfile.mkdtemp())
else:
continue
builtin_tools.append( async def get_agent(self, agent_id: str) -> ChatAgent:
with_safety( agent_config = await self.persistence_store.get(
tool, key=f"agent:{agent_id}",
self.safety_api, )
tool_defn.input_shields, if not agent_config:
tool_defn.output_shields, raise ValueError(f"Could not find agent config for {agent_id}")
)
)
AGENT_INSTANCES_BY_ID[agent_id] = ChatAgent( try:
agent_config = json.loads(agent_config)
except json.JSONDecodeError as e:
raise ValueError(
f"Could not JSON decode agent config for {agent_id}"
) from e
try:
agent_config = AgentConfig(**agent_config)
except Exception as e:
raise ValueError(
f"Could not validate(?) agent config for {agent_id}"
) from e
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config, agent_config=agent_config,
inference_api=self.inference_api, inference_api=self.inference_api,
safety_api=self.safety_api, safety_api=self.safety_api,
memory_api=self.memory_api, memory_api=self.memory_api,
builtin_tools=builtin_tools, persistence_store=self.persistence_store,
)
return AgentCreateResponse(
agent_id=agent_id,
) )
async def create_agent_session( async def create_agent_session(
@ -110,12 +88,11 @@ class MetaReferenceAgentsImpl(Agents):
agent_id: str, agent_id: str,
session_name: str, session_name: str,
) -> AgentSessionCreateResponse: ) -> AgentSessionCreateResponse:
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found" agent = await self.get_agent(agent_id)
agent = AGENT_INSTANCES_BY_ID[agent_id]
session = agent.create_session(session_name) session_id = await agent.create_session(session_name)
return AgentSessionCreateResponse( return AgentSessionCreateResponse(
session_id=session.session_id, session_id=session_id,
) )
async def create_agent_turn( async def create_agent_turn(
@ -131,6 +108,8 @@ class MetaReferenceAgentsImpl(Agents):
attachments: Optional[List[Attachment]] = None, attachments: Optional[List[Attachment]] = None,
stream: Optional[bool] = False, stream: Optional[bool] = False,
) -> AsyncGenerator: ) -> AsyncGenerator:
agent = await self.get_agent(agent_id)
# wrapper request to make it easier to pass around (internal only, not exposed to API) # wrapper request to make it easier to pass around (internal only, not exposed to API)
request = AgentTurnCreateRequest( request = AgentTurnCreateRequest(
agent_id=agent_id, agent_id=agent_id,
@ -140,12 +119,5 @@ class MetaReferenceAgentsImpl(Agents):
stream=stream, stream=stream,
) )
agent_id = request.agent_id
assert agent_id in AGENT_INSTANCES_BY_ID, f"System {agent_id} not found"
agent = AGENT_INSTANCES_BY_ID[agent_id]
assert (
request.session_id in agent.sessions
), f"Session {request.session_id} not found"
async for event in agent.create_and_execute_turn(request): async for event in agent.create_and_execute_turn(request):
yield event yield event

View file

@ -8,4 +8,4 @@ from llama_stack.providers.utils.kvstore import KVStoreConfig
class MetaReferenceAgentsImplConfig(BaseModel): class MetaReferenceAgentsImplConfig(BaseModel):
kv_store: KVStoreConfig persistence_store: KVStoreConfig

View file

@ -3,3 +3,5 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from .kvstore import * # noqa: F401, F403

View file

@ -5,20 +5,20 @@
# the root directory of this source tree. # the root directory of this source tree.
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional, Protocol from typing import List, Optional, Protocol
from pydantic import BaseModel from pydantic import BaseModel
class KVStoreValue(BaseModel): class KVStoreValue(BaseModel):
key: str key: str
value: Any value: str
expiration: Optional[datetime] = None expiration: Optional[datetime] = None
class KVStore(Protocol): class KVStore(Protocol):
async def set( async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None self, key: str, value: str, expiration: Optional[datetime] = None
) -> None: ... ) -> None: ...
async def get(self, key: str) -> Optional[KVStoreValue]: ... async def get(self, key: str) -> Optional[KVStoreValue]: ...

View file

@ -14,7 +14,7 @@ from typing_extensions import Annotated
class KVStoreType(Enum): class KVStoreType(Enum):
redis = "redis" redis = "redis"
sqlite = "sqlite" sqlite = "sqlite"
pgvector = "pgvector" postgres = "postgres"
class CommonConfig(BaseModel): class CommonConfig(BaseModel):
@ -37,8 +37,8 @@ class SqliteKVStoreImplConfig(CommonConfig):
) )
class PGVectorKVStoreImplConfig(CommonConfig): class PostgresKVStoreImplConfig(CommonConfig):
type: Literal[KVStoreType.pgvector.value] = KVStoreType.pgvector.value type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
host: str = "localhost" host: str = "localhost"
port: int = 5432 port: int = 5432
db: str = "llamastack" db: str = "llamastack"
@ -47,6 +47,6 @@ class PGVectorKVStoreImplConfig(CommonConfig):
KVStoreConfig = Annotated[ KVStoreConfig = Annotated[
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PGVectorKVStoreImplConfig], Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig],
Field(discriminator="type"), Field(discriminator="type", default=KVStoreType.sqlite.value),
] ]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, List, Optional from typing import List, Optional
from redis.asyncio import Redis from redis.asyncio import Redis
@ -26,7 +26,7 @@ class RedisKVStoreImpl(KVStore):
return f"{self.config.namespace}:{key}" return f"{self.config.namespace}:{key}"
async def set( async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None self, key: str, value: str, expiration: Optional[datetime] = None
) -> None: ) -> None:
key = self._namespaced_key(key) key = self._namespaced_key(key)
await self.redis.set(key, value) await self.redis.set(key, value)
@ -50,11 +50,4 @@ class RedisKVStoreImpl(KVStore):
start_key = self._namespaced_key(start_key) start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key) end_key = self._namespaced_key(end_key)
keys = await self.redis.keys(f"{start_key}*") return await self.redis.zrangebylex(start_key, end_key)
result = []
for key in keys:
if key <= end_key:
value = await self.get(key)
if value:
result.append(value)
return result

View file

@ -4,9 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional from typing import List, Optional
import aiosqlite import aiosqlite
@ -33,12 +32,12 @@ class SqliteKVStoreImpl(KVStore):
await db.commit() await db.commit()
async def set( async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None self, key: str, value: str, expiration: Optional[datetime] = None
) -> None: ) -> None:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
await db.execute( await db.execute(
f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)",
(key, json.dumps(value), expiration), (key, value, expiration),
) )
await db.commit() await db.commit()
@ -51,9 +50,7 @@ class SqliteKVStoreImpl(KVStore):
if row is None: if row is None:
return None return None
value, expiration = row value, expiration = row
return KVStoreValue( return KVStoreValue(key=key, value=value, expiration=expiration)
key=key, value=json.loads(value), expiration=expiration
)
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
@ -70,8 +67,6 @@ class SqliteKVStoreImpl(KVStore):
async for row in cursor: async for row in cursor:
key, value, expiration = row key, value, expiration = row
result.append( result.append(
KVStoreValue( KVStoreValue(key=key, value=value, expiration=expiration)
key=key, value=json.loads(value), expiration=expiration
)
) )
return result return result