kvstore impls for redis / sqlite moved

This commit is contained in:
Ashwin Bharambe 2024-09-20 19:55:44 -07:00
parent c1ab66f1e6
commit 61974e337f
16 changed files with 137 additions and 114 deletions

View file

@ -281,6 +281,7 @@ class AgentConfigCommon(BaseModel):
class AgentConfig(AgentConfigCommon): class AgentConfig(AgentConfigCommon):
model: str model: str
instructions: str instructions: str
enable_session_persistence: bool
class AgentConfigOverridablePerTurn(AgentConfigCommon): class AgentConfigOverridablePerTurn(AgentConfigCommon):

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import RedisImplConfig
async def get_adapter_impl(config: RedisImplConfig, _deps):
from .redis import RedisControlPlaneAdapter
impl = RedisControlPlaneAdapter(config)
await impl.initialize()
return impl

View file

@ -1,21 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class RedisImplConfig(BaseModel):
url: str = Field(
description="The URL for the Redis server",
)
namespace: Optional[str] = Field(
default=None,
description="All keys will be prefixed with this namespace",
)

View file

@ -1,15 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import SqliteControlPlaneConfig
async def get_provider_impl(config: SqliteControlPlaneConfig, _deps):
from .control_plane import SqliteControlPlane
impl = SqliteControlPlane(config)
await impl.initialize()
return impl

View file

@ -1,29 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.control_plane,
provider_id="sqlite",
pip_packages=["aiosqlite"],
module="llama_stack.providers.impls.sqlite.control_plane",
config_class="llama_stack.providers.impls.sqlite.control_plane.SqliteControlPlaneConfig",
),
remote_provider_spec(
Api.control_plane,
AdapterSpec(
adapter_id="redis",
pip_packages=["redis"],
module="llama_stack.providers.adapters.control_plane.redis",
),
),
]

View file

@ -16,7 +16,7 @@ from llama_stack.apis.safety import Safety
from llama_stack.apis.agents import * # noqa: F403 from llama_stack.apis.agents import * # noqa: F403
from .agent_instance import ChatAgent from .agent_instance import ChatAgent
from .config import MetaReferenceImplConfig from .config import MetaReferenceAgentsImplConfig
from .tools.builtin import ( from .tools.builtin import (
CodeInterpreterTool, CodeInterpreterTool,
PhotogenTool, PhotogenTool,
@ -33,10 +33,25 @@ logger.setLevel(logging.INFO)
AGENT_INSTANCES_BY_ID = {} AGENT_INSTANCES_BY_ID = {}
class KVStore(Protocol):
def get(self, key: str) -> str:
...
def set(self, key: str, value: str) -> None:
...
def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis:
from .kvstore_impls.redis import RedisKVStoreImpl
return RedisKVStoreImpl(config)
return None
class MetaReferenceAgentsImpl(Agents): class MetaReferenceAgentsImpl(Agents):
def __init__( def __init__(
self, self,
config: MetaReferenceImplConfig, config: MetaReferenceAgentsImplConfig,
inference_api: Inference, inference_api: Inference,
memory_api: Memory, memory_api: Memory,
safety_api: Safety, safety_api: Safety,
@ -45,6 +60,7 @@ class MetaReferenceAgentsImpl(Agents):
self.inference_api = inference_api self.inference_api = inference_api
self.memory_api = memory_api self.memory_api = memory_api
self.safety_api = safety_api self.safety_api = safety_api
self.kvstore = kvstore_impl(config.kvstore)
async def initialize(self) -> None: async def initialize(self) -> None:
pass pass

View file

@ -4,7 +4,8 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel from llama_stack.providers.utils.kvstore import KVStoreConfig
class MetaReferenceImplConfig(BaseModel): ... class MetaReferenceAgentsImplConfig(BaseModel):
kv_store: KVStoreConfig

View file

@ -7,29 +7,22 @@
from datetime import datetime from datetime import datetime
from typing import Any, List, Optional, Protocol from typing import Any, List, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel from pydantic import BaseModel
@json_schema_type class KVStoreValue(BaseModel):
class ControlPlaneValue(BaseModel):
key: str key: str
value: Any value: Any
expiration: Optional[datetime] = None expiration: Optional[datetime] = None
@json_schema_type class KVStore(Protocol):
class ControlPlane(Protocol):
@webmethod(route="/control_plane/set")
async def set( async def set(
self, key: str, value: Any, expiration: Optional[datetime] = None self, key: str, value: Any, expiration: Optional[datetime] = None
) -> None: ... ) -> None: ...
@webmethod(route="/control_plane/get", method="GET") async def get(self, key: str) -> Optional[KVStoreValue]: ...
async def get(self, key: str) -> Optional[ControlPlaneValue]: ...
@webmethod(route="/control_plane/delete")
async def delete(self, key: str) -> None: ... async def delete(self, key: str) -> None: ...
@webmethod(route="/control_plane/range", method="GET") async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: ...
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: ...

View file

@ -0,0 +1,52 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Literal, Optional, Union
from pydantic import BaseModel
from typing_extensions import Annotated
class KVStoreType(Enum):
redis = "redis"
sqlite = "sqlite"
pgvector = "pgvector"
class CommonConfig(BaseModel):
namespace: Optional[str] = Field(
default=None,
description="All keys will be prefixed with this namespace",
)
class RedisKVStoreImplConfig(CommonConfig):
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value
host: str = "localhost"
port: int = 6379
class SqliteKVStoreImplConfig(CommonConfig):
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
db_path: str = Field(
description="File path for the sqlite database",
)
class PGVectorKVStoreImplConfig(CommonConfig):
type: Literal[KVStoreType.pgvector.value] = KVStoreType.pgvector.value
host: str = "localhost"
port: int = 5432
db: str = "llamastack"
user: str
password: Optional[str] = None
KVStoreConfig = Annotated[
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PGVectorKVStoreImplConfig],
Field(discriminator="type"),
]

View file

@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .api import * # noqa: F403
from .config import * # noqa: F403
def kvstore_dependencies():
return ["aiosqlite", "psycopg2-binary", "redis"]
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
if config.type == KVStoreType.redis:
from .redis import RedisKVStoreImpl
impl = RedisKVStoreImpl(config)
elif config.type == KVStoreType.sqlite:
from .sqlite import SqliteKVStoreImpl
impl = SqliteKVStoreImpl(config)
elif config.type == KVStoreType.pgvector:
raise NotImplementedError()
else:
raise ValueError(f"Unknown kvstore type {config.type}")
await impl.initialize()
return impl

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .redis import RedisKVStoreImpl # noqa: F401

View file

@ -9,14 +9,12 @@ from typing import Any, List, Optional
from redis.asyncio import Redis from redis.asyncio import Redis
from llama_stack.apis.control_plane import * # noqa: F403 from ..api import * # noqa: F403
from ..config import RedisKVStoreImplConfig
from .config import RedisImplConfig class RedisKVStoreImpl(KVStore):
def __init__(self, config: RedisKVStoreImplConfig):
class RedisControlPlaneAdapter(ControlPlane):
def __init__(self, config: RedisImplConfig):
self.config = config self.config = config
async def initialize(self) -> None: async def initialize(self) -> None:
@ -35,20 +33,20 @@ class RedisControlPlaneAdapter(ControlPlane):
if expiration: if expiration:
await self.redis.expireat(key, expiration) await self.redis.expireat(key, expiration)
async def get(self, key: str) -> Optional[ControlPlaneValue]: async def get(self, key: str) -> Optional[KVStoreValue]:
key = self._namespaced_key(key) key = self._namespaced_key(key)
value = await self.redis.get(key) value = await self.redis.get(key)
if value is None: if value is None:
return None return None
ttl = await self.redis.ttl(key) ttl = await self.redis.ttl(key)
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
return ControlPlaneValue(key=key, value=value, expiration=expiration) return KVStoreValue(key=key, value=value, expiration=expiration)
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:
key = self._namespaced_key(key) key = self._namespaced_key(key)
await self.redis.delete(key) await self.redis.delete(key)
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]:
start_key = self._namespaced_key(start_key) start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key) end_key = self._namespaced_key(end_key)

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .sqlite import SqliteKVStoreImpl # noqa: F401

View file

@ -10,14 +10,12 @@ from typing import Any, List, Optional
import aiosqlite import aiosqlite
from llama_stack.apis.control_plane import * # noqa: F403 from ..api import * # noqa: F403
from ..config import SqliteKVStoreConfig
from .config import SqliteControlPlaneConfig class SqliteKVStoreImpl(KVStore):
def __init__(self, config: SqliteKVStoreConfig):
class SqliteControlPlane(ControlPlane):
def __init__(self, config: SqliteControlPlaneConfig):
self.db_path = config.db_path self.db_path = config.db_path
self.table_name = config.table_name self.table_name = config.table_name
@ -44,7 +42,7 @@ class SqliteControlPlane(ControlPlane):
) )
await db.commit() await db.commit()
async def get(self, key: str) -> Optional[ControlPlaneValue]: async def get(self, key: str) -> Optional[KVStoreValue]:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
async with db.execute( async with db.execute(
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,) f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
@ -53,7 +51,7 @@ class SqliteControlPlane(ControlPlane):
if row is None: if row is None:
return None return None
value, expiration = row value, expiration = row
return ControlPlaneValue( return KVStoreValue(
key=key, value=json.loads(value), expiration=expiration key=key, value=json.loads(value), expiration=expiration
) )
@ -62,7 +60,7 @@ class SqliteControlPlane(ControlPlane):
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,)) await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
await db.commit() await db.commit()
async def range(self, start_key: str, end_key: str) -> List[ControlPlaneValue]: async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]:
async with aiosqlite.connect(self.db_path) as db: async with aiosqlite.connect(self.db_path) as db:
async with db.execute( async with db.execute(
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?", f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
@ -72,7 +70,7 @@ class SqliteControlPlane(ControlPlane):
async for row in cursor: async for row in cursor:
key, value, expiration = row key, value, expiration = row
result.append( result.append(
ControlPlaneValue( KVStoreValue(
key=key, value=json.loads(value), expiration=expiration key=key, value=json.loads(value), expiration=expiration
) )
) )