From bbf0b59ae4b226a1a4cffcf812e3e09c278f1542 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Mon, 16 Sep 2024 21:59:07 -0700 Subject: [PATCH] add redis adapter + sqlite provider --- llama_stack/core/datatypes.py | 1 + llama_stack/core/distribution.py | 2 + .../adapters/control_plane/__init__.py | 5 ++ .../adapters/control_plane/redis/__init__.py | 15 ++++ .../adapters/control_plane/redis/config.py | 19 +++++ .../adapters/control_plane/redis/redis.py | 62 +++++++++++++++ .../providers/impls/sqlite/__init__.py | 5 ++ .../impls/sqlite/control_plane/__init__.py | 15 ++++ .../impls/sqlite/control_plane/config.py | 19 +++++ .../sqlite/control_plane/control_plane.py | 79 +++++++++++++++++++ .../providers/registry/control_plane.py | 29 +++++++ 11 files changed, 251 insertions(+) create mode 100644 llama_stack/providers/adapters/control_plane/__init__.py create mode 100644 llama_stack/providers/adapters/control_plane/redis/__init__.py create mode 100644 llama_stack/providers/adapters/control_plane/redis/config.py create mode 100644 llama_stack/providers/adapters/control_plane/redis/redis.py create mode 100644 llama_stack/providers/impls/sqlite/__init__.py create mode 100644 llama_stack/providers/impls/sqlite/control_plane/__init__.py create mode 100644 llama_stack/providers/impls/sqlite/control_plane/config.py create mode 100644 llama_stack/providers/impls/sqlite/control_plane/control_plane.py create mode 100644 llama_stack/providers/registry/control_plane.py diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index 06eb1cc49..4f388a3fe 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -20,6 +20,7 @@ class Api(Enum): agents = "agents" memory = "memory" telemetry = "telemetry" + control_plane = "control_plane" @json_schema_type diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py index 13c96c3a5..8bfa75d9c 100644 --- a/llama_stack/core/distribution.py +++ b/llama_stack/core/distribution.py @@ -9,6 +9,7 @@ import inspect from typing import Dict, List from llama_stack.apis.agents import Agents +from llama_stack.apis.control_plane import ControlPlane from llama_stack.apis.inference import Inference from llama_stack.apis.memory import Memory from llama_stack.apis.safety import Safety @@ -37,6 +38,7 @@ def api_endpoints() -> Dict[Api, List[ApiEndpoint]]: Api.agents: Agents, Api.memory: Memory, Api.telemetry: Telemetry, + Api.control_plane: ControlPlane, } for api, protocol in protocols.items(): diff --git a/llama_stack/providers/adapters/control_plane/__init__.py b/llama_stack/providers/adapters/control_plane/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/adapters/control_plane/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/adapters/control_plane/redis/__init__.py b/llama_stack/providers/adapters/control_plane/redis/__init__.py new file mode 100644 index 000000000..0482718cc --- /dev/null +++ b/llama_stack/providers/adapters/control_plane/redis/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import RedisImplConfig + + +async def get_adapter_impl(config: RedisImplConfig, _deps): + from .redis import RedisControlPlaneAdapter + + impl = RedisControlPlaneAdapter(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/adapters/control_plane/redis/config.py b/llama_stack/providers/adapters/control_plane/redis/config.py new file mode 100644 index 000000000..6238611e0 --- /dev/null +++ b/llama_stack/providers/adapters/control_plane/redis/config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class RedisImplConfig(BaseModel): + url: str = Field( + description="The URL for the Redis server", + ) + namespace: Optional[str] = Field( + default=None, + description="All keys will be prefixed with this namespace", + ) diff --git a/llama_stack/providers/adapters/control_plane/redis/redis.py b/llama_stack/providers/adapters/control_plane/redis/redis.py new file mode 100644 index 000000000..d5c468b77 --- /dev/null +++ b/llama_stack/providers/adapters/control_plane/redis/redis.py @@ -0,0 +1,62 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime, timedelta +from typing import Any, List, Optional + +from redis.asyncio import Redis + +from llama_stack.apis.control_plane import * # noqa: F403 + + +from .config import RedisImplConfig + + +class RedisControlPlaneAdapter(ControlPlane): + def __init__(self, config: RedisImplConfig): + self.config = config + + async def initialize(self) -> None: + self.redis = Redis.from_url(self.config.url) + + def _namespaced_key(self, key: str) -> str: + if not self.config.namespace: + return key + return f"{self.config.namespace}:{key}" + + async def set( + self, key: str, value: Any, expiration: Optional[datetime] = None + ) -> None: + key = self._namespaced_key(key) + await self.redis.set(key, value) + if expiration: + await self.redis.expireat(key, expiration) + + async def get(self, key: str) -> Optional[ControlPlaneValue]: + key = self._namespaced_key(key) + value = await self.redis.get(key) + if value is None: + return None + ttl = await self.redis.ttl(key) + expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None + return ControlPlaneValue(key=key, value=value, expiration=expiration) + + async def delete(self, key: str) -> None: + key = self._namespaced_key(key) + await self.redis.delete(key) + + async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + + keys = await self.redis.keys(f"{start_key}*") + result = [] + for key in keys: + if key <= end_key: + value = await self.get(key) + if value: + result.append(value) + return result diff --git a/llama_stack/providers/impls/sqlite/__init__.py b/llama_stack/providers/impls/sqlite/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/impls/sqlite/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/impls/sqlite/control_plane/__init__.py b/llama_stack/providers/impls/sqlite/control_plane/__init__.py new file mode 100644 index 000000000..330f15942 --- /dev/null +++ b/llama_stack/providers/impls/sqlite/control_plane/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .config import SqliteControlPlaneConfig + + +async def get_provider_impl(config: SqliteControlPlaneConfig, _deps): + from .control_plane import SqliteControlPlane + + impl = SqliteControlPlane(config) + await impl.initialize() + return impl diff --git a/llama_stack/providers/impls/sqlite/control_plane/config.py b/llama_stack/providers/impls/sqlite/control_plane/config.py new file mode 100644 index 000000000..a616c90d0 --- /dev/null +++ b/llama_stack/providers/impls/sqlite/control_plane/config.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_models.schema_utils import json_schema_type +from pydantic import BaseModel, Field + + +@json_schema_type +class SqliteControlPlaneConfig(BaseModel): + db_path: str = Field( + description="File path for the sqlite database", + ) + table_name: str = Field( + default="llamastack_control_plane", + description="Table into which all the keys will be placed", + ) diff --git a/llama_stack/providers/impls/sqlite/control_plane/control_plane.py b/llama_stack/providers/impls/sqlite/control_plane/control_plane.py new file mode 100644 index 000000000..e2e655244 --- /dev/null +++ b/llama_stack/providers/impls/sqlite/control_plane/control_plane.py @@ -0,0 +1,79 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import json +from datetime import datetime +from typing import Any, List, Optional + +import aiosqlite + +from llama_stack.apis.control_plane import * # noqa: F403 + + +from .config import SqliteControlPlaneConfig + + +class SqliteControlPlane(ControlPlane): + def __init__(self, config: SqliteControlPlaneConfig): + self.db_path = config.db_path + self.table_name = config.table_name + + async def initialize(self): + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP + ) + """ + ) + await db.commit() + + async def set( + self, key: str, value: Any, expiration: Optional[datetime] = None + ) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute( + f"INSERT OR REPLACE INTO {self.table_name} (key, value, expiration) VALUES (?, ?, ?)", + (key, json.dumps(value), expiration), + ) + await db.commit() + + async def get(self, key: str) -> Optional[ControlPlaneValue]: + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) + ) as cursor: + row = await cursor.fetchone() + if row is None: + return None + value, expiration = row + return ControlPlaneValue( + key=key, value=json.loads(value), expiration=expiration + ) + + async def delete(self, key: str) -> None: + async with aiosqlite.connect(self.db_path) as db: + await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) + await db.commit() + + async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: + async with aiosqlite.connect(self.db_path) as db: + async with db.execute( + f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", + (start_key, end_key), + ) as cursor: + result = [] + async for row in cursor: + key, value, expiration = row + result.append( + ControlPlaneValue( + key=key, value=json.loads(value), expiration=expiration + ) + ) + return result diff --git a/llama_stack/providers/registry/control_plane.py b/llama_stack/providers/registry/control_plane.py new file mode 100644 index 000000000..8e240b913 --- /dev/null +++ b/llama_stack/providers/registry/control_plane.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from typing import List + +from llama_stack.core.datatypes import * # noqa: F403 + + +def available_providers() -> List[ProviderSpec]: + return [ + InlineProviderSpec( + api=Api.control_plane, + provider_id="sqlite", + pip_packages=["aiosqlite"], + module="llama_stack.providers.impls.sqlite.control_plane", + config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig", + ), + remote_provider_spec( + Api.control_plane, + AdapterSpec( + adapter_id="redis", + pip_packages=["redis"], + module="llama_stack.providers.adapters.control_plane.redis", + ), + ), + ]