diff --git a/llama_stack/apis/agents/agents.py b/llama_stack/apis/agents/agents.py index d99a446f1..cf247b01b 100644 --- a/llama_stack/apis/agents/agents.py +++ b/llama_stack/apis/agents/agents.py @@ -281,6 +281,7 @@ class AgentConfigCommon(BaseModel): class AgentConfig(AgentConfigCommon): model: str instructions: str + enable_session_persistence: bool class AgentConfigOverridablePerTurn(AgentConfigCommon): diff --git a/llama_stack/distribution/control_plane/adapters/redis/__init__.py b/llama_stack/distribution/control_plane/adapters/redis/__init__.py deleted file mode 100644 index 0482718cc..000000000 --- a/llama_stack/distribution/control_plane/adapters/redis/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import RedisImplConfig - - -async def get_adapter_impl(config: RedisImplConfig, _deps): - from .redis import RedisControlPlaneAdapter - - impl = RedisControlPlaneAdapter(config) - await impl.initialize() - return impl diff --git a/llama_stack/distribution/control_plane/adapters/redis/config.py b/llama_stack/distribution/control_plane/adapters/redis/config.py deleted file mode 100644 index d786aceb1..000000000 --- a/llama_stack/distribution/control_plane/adapters/redis/config.py +++ /dev/null @@ -1,21 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Optional - -from llama_models.schema_utils import json_schema_type -from pydantic import BaseModel, Field - - -@json_schema_type -class RedisImplConfig(BaseModel): - url: str = Field( - description="The URL for the Redis server", - ) - namespace: Optional[str] = Field( - default=None, - description="All keys will be prefixed with this namespace", - ) diff --git a/llama_stack/distribution/control_plane/adapters/sqlite/__init__.py b/llama_stack/distribution/control_plane/adapters/sqlite/__init__.py deleted file mode 100644 index 330f15942..000000000 --- a/llama_stack/distribution/control_plane/adapters/sqlite/__init__.py +++ /dev/null @@ -1,15 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .config import SqliteControlPlaneConfig - - -async def get_provider_impl(config: SqliteControlPlaneConfig, _deps): - from .control_plane import SqliteControlPlane - - impl = SqliteControlPlane(config) - await impl.initialize() - return impl diff --git a/llama_stack/distribution/control_plane/registry.py b/llama_stack/distribution/control_plane/registry.py deleted file mode 100644 index 7465c4534..000000000 --- a/llama_stack/distribution/control_plane/registry.py +++ /dev/null @@ -1,29 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import List - -from llama_stack.distribution.datatypes import * # noqa: F403 - - -def available_providers() -> List[ProviderSpec]: - return [ - InlineProviderSpec( - api=Api.control_plane, - provider_id="sqlite", - pip_packages=["aiosqlite"], - module="llama_stack.providers.impls.sqlite.control_plane", - config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig", - ), - remote_provider_spec( - Api.control_plane, - AdapterSpec( - adapter_id="redis", - pip_packages=["redis"], - module="llama_stack.providers.adapters.control_plane.redis", - ), - ), - ] diff --git a/llama_stack/providers/impls/meta_reference/agents/agents.py b/llama_stack/providers/impls/meta_reference/agents/agents.py index 022c8c3d1..bb042baa3 100644 --- a/llama_stack/providers/impls/meta_reference/agents/agents.py +++ b/llama_stack/providers/impls/meta_reference/agents/agents.py @@ -16,7 +16,7 @@ from llama_stack.apis.safety import Safety from llama_stack.apis.agents import * # noqa: F403 from .agent_instance import ChatAgent -from .config import MetaReferenceImplConfig +from .config import MetaReferenceAgentsImplConfig from .tools.builtin import ( CodeInterpreterTool, PhotogenTool, @@ -33,10 +33,25 @@ logger.setLevel(logging.INFO) AGENT_INSTANCES_BY_ID = {} +class KVStore(Protocol): + def get(self, key: str) -> str: + ... + + def set(self, key: str, value: str) -> None: + ... + +def kvstore_impl(config: KVStoreConfig) -> KVStore: + if config.type == KVStoreType.redis: + from .kvstore_impls.redis import RedisKVStoreImpl + return RedisKVStoreImpl(config) + + return None + + class MetaReferenceAgentsImpl(Agents): def __init__( self, - config: MetaReferenceImplConfig, + config: MetaReferenceAgentsImplConfig, inference_api: Inference, memory_api: Memory, safety_api: Safety, @@ -45,6 +60,7 @@ class MetaReferenceAgentsImpl(Agents): self.inference_api = inference_api self.memory_api = memory_api self.safety_api = safety_api + self.kvstore = kvstore_impl(config.kvstore) async def initialize(self) -> None: pass diff --git a/llama_stack/providers/impls/meta_reference/agents/config.py b/llama_stack/providers/impls/meta_reference/agents/config.py index 17beb348e..f293f46c2 100644 --- a/llama_stack/providers/impls/meta_reference/agents/config.py +++ b/llama_stack/providers/impls/meta_reference/agents/config.py @@ -4,7 +4,8 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -from pydantic import BaseModel +from llama_stack.providers.utils.kvstore import KVStoreConfig -class MetaReferenceImplConfig(BaseModel): ... +class MetaReferenceAgentsImplConfig(BaseModel): + kv_store: KVStoreConfig diff --git a/llama_stack/distribution/control_plane/adapters/__init__.py b/llama_stack/providers/utils/kvstore/__init__.py similarity index 100% rename from llama_stack/distribution/control_plane/adapters/__init__.py rename to llama_stack/providers/utils/kvstore/__init__.py diff --git a/llama_stack/distribution/control_plane/api.py b/llama_stack/providers/utils/kvstore/api.py similarity index 56% rename from llama_stack/distribution/control_plane/api.py rename to llama_stack/providers/utils/kvstore/api.py index db79e91cd..99d666f42 100644 --- a/llama_stack/distribution/control_plane/api.py +++ b/llama_stack/providers/utils/kvstore/api.py @@ -7,29 +7,22 @@ from datetime import datetime from typing import Any, List, Optional, Protocol -from llama_models.schema_utils import json_schema_type, webmethod from pydantic import BaseModel -@json_schema_type -class ControlPlaneValue(BaseModel): +class KVStoreValue(BaseModel): key: str value: Any expiration: Optional[datetime] = None -@json_schema_type -class ControlPlane(Protocol): - @webmethod(route="/control_plane/set") +class KVStore(Protocol): async def set( self, key: str, value: Any, expiration: Optional[datetime] = None ) -> None: ... - @webmethod(route="/control_plane/get", method="GET") - async def get(self, key: str) -> Optional[ControlPlaneValue]: ... + async def get(self, key: str) -> Optional[KVStoreValue]: ... - @webmethod(route="/control_plane/delete") async def delete(self, key: str) -> None: ... - @webmethod(route="/control_plane/range", method="GET") - async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ... + async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: ... diff --git a/llama_stack/providers/utils/kvstore/config.py b/llama_stack/providers/utils/kvstore/config.py new file mode 100644 index 000000000..515640897 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/config.py @@ -0,0 +1,52 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from enum import Enum +from typing import Literal, Optional, Union + +from pydantic import BaseModel +from typing_extensions import Annotated + + +class KVStoreType(Enum): + redis = "redis" + sqlite = "sqlite" + pgvector = "pgvector" + + +class CommonConfig(BaseModel): + namespace: Optional[str] = Field( + default=None, + description="All keys will be prefixed with this namespace", + ) + + +class RedisKVStoreImplConfig(CommonConfig): + type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value + host: str = "localhost" + port: int = 6379 + + +class SqliteKVStoreImplConfig(CommonConfig): + type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value + db_path: str = Field( + description="File path for the sqlite database", + ) + + +class PGVectorKVStoreImplConfig(CommonConfig): + type: Literal[KVStoreType.pgvector.value] = KVStoreType.pgvector.value + host: str = "localhost" + port: int = 5432 + db: str = "llamastack" + user: str + password: Optional[str] = None + + +KVStoreConfig = Annotated[ + Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PGVectorKVStoreImplConfig], + Field(discriminator="type"), +] diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py new file mode 100644 index 000000000..a7fa0af75 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -0,0 +1,30 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .api import * # noqa: F403 +from .config import * # noqa: F403 + + +def kvstore_dependencies(): + return ["aiosqlite", "psycopg2-binary", "redis"] + + +async def kvstore_impl(config: KVStoreConfig) -> KVStore: + if config.type == KVStoreType.redis: + from .redis import RedisKVStoreImpl + + impl = RedisKVStoreImpl(config) + elif config.type == KVStoreType.sqlite: + from .sqlite import SqliteKVStoreImpl + + impl = SqliteKVStoreImpl(config) + elif config.type == KVStoreType.pgvector: + raise NotImplementedError() + else: + raise ValueError(f"Unknown kvstore type {config.type}") + + await impl.initialize() + return impl diff --git a/llama_stack/providers/utils/kvstore/redis/__init__.py b/llama_stack/providers/utils/kvstore/redis/__init__.py new file mode 100644 index 000000000..94693ca43 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/redis/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .redis import RedisKVStoreImpl # noqa: F401 diff --git a/llama_stack/distribution/control_plane/adapters/redis/redis.py b/llama_stack/providers/utils/kvstore/redis/redis.py similarity index 82% rename from llama_stack/distribution/control_plane/adapters/redis/redis.py rename to llama_stack/providers/utils/kvstore/redis/redis.py index d5c468b77..340530750 100644 --- a/llama_stack/distribution/control_plane/adapters/redis/redis.py +++ b/llama_stack/providers/utils/kvstore/redis/redis.py @@ -9,14 +9,12 @@ from typing import Any, List, Optional from redis.asyncio import Redis -from llama_stack.apis.control_plane import * # noqa: F403 +from ..api import * # noqa: F403 +from ..config import RedisKVStoreImplConfig -from .config import RedisImplConfig - - -class RedisControlPlaneAdapter(ControlPlane): - def __init__(self, config: RedisImplConfig): +class RedisKVStoreImpl(KVStore): + def __init__(self, config: RedisKVStoreImplConfig): self.config = config async def initialize(self) -> None: @@ -35,20 +33,20 @@ class RedisControlPlaneAdapter(ControlPlane): if expiration: await self.redis.expireat(key, expiration) - async def get(self, key: str) -> Optional[ControlPlaneValue]: + async def get(self, key: str) -> Optional[KVStoreValue]: key = self._namespaced_key(key) value = await self.redis.get(key) if value is None: return None ttl = await self.redis.ttl(key) expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None - return ControlPlaneValue(key=key, value=value, expiration=expiration) + return KVStoreValue(key=key, value=value, expiration=expiration) async def delete(self, key: str) -> None: key = self._namespaced_key(key) await self.redis.delete(key) - async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: + async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: start_key = self._namespaced_key(start_key) end_key = self._namespaced_key(end_key) diff --git a/llama_stack/providers/utils/kvstore/sqlite/__init__.py b/llama_stack/providers/utils/kvstore/sqlite/__init__.py new file mode 100644 index 000000000..03bc53c24 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/sqlite/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .sqlite import SqliteKVStoreImpl # noqa: F401 diff --git a/llama_stack/distribution/control_plane/adapters/sqlite/config.py b/llama_stack/providers/utils/kvstore/sqlite/config.py similarity index 100% rename from llama_stack/distribution/control_plane/adapters/sqlite/config.py rename to llama_stack/providers/utils/kvstore/sqlite/config.py diff --git a/llama_stack/distribution/control_plane/adapters/sqlite/control_plane.py b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py similarity index 86% rename from llama_stack/distribution/control_plane/adapters/sqlite/control_plane.py rename to llama_stack/providers/utils/kvstore/sqlite/sqlite.py index e2e655244..ae8374156 100644 --- a/llama_stack/distribution/control_plane/adapters/sqlite/control_plane.py +++ b/llama_stack/providers/utils/kvstore/sqlite/sqlite.py @@ -10,14 +10,12 @@ from typing import Any, List, Optional import aiosqlite -from llama_stack.apis.control_plane import * # noqa: F403 +from ..api import * # noqa: F403 +from ..config import SqliteKVStoreConfig -from .config import SqliteControlPlaneConfig - - -class SqliteControlPlane(ControlPlane): - def __init__(self, config: SqliteControlPlaneConfig): +class SqliteKVStoreImpl(KVStore): + def __init__(self, config: SqliteKVStoreConfig): self.db_path = config.db_path self.table_name = config.table_name @@ -44,7 +42,7 @@ class SqliteControlPlane(ControlPlane): ) await db.commit() - async def get(self, key: str) -> Optional[ControlPlaneValue]: + async def get(self, key: str) -> Optional[KVStoreValue]: async with aiosqlite.connect(self.db_path) as db: async with db.execute( f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) @@ -53,7 +51,7 @@ class SqliteControlPlane(ControlPlane): if row is None: return None value, expiration = row - return ControlPlaneValue( + return KVStoreValue( key=key, value=json.loads(value), expiration=expiration ) @@ -62,7 +60,7 @@ class SqliteControlPlane(ControlPlane): await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.commit() - async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: + async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: async with aiosqlite.connect(self.db_path) as db: async with db.execute( f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", @@ -72,7 +70,7 @@ class SqliteControlPlane(ControlPlane): async for row in cursor: key, value, expiration = row result.append( - ControlPlaneValue( + KVStoreValue( key=key, value=json.loads(value), expiration=expiration ) )