mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-06 12:37:33 +00:00
Agent persistence works
This commit is contained in:
parent
4eb0f30891
commit
59f1fe5af8
17 changed files with 136 additions and 90 deletions
|
@ -7,22 +7,15 @@
|
|||
from datetime import datetime
|
||||
from typing import List, Optional, Protocol
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class KVStoreValue(BaseModel):
|
||||
key: str
|
||||
value: str
|
||||
expiration: Optional[datetime] = None
|
||||
|
||||
|
||||
class KVStore(Protocol):
|
||||
# TODO: make the value type bytes instead of str
|
||||
async def set(
|
||||
self, key: str, value: str, expiration: Optional[datetime] = None
|
||||
) -> None: ...
|
||||
|
||||
async def get(self, key: str) -> Optional[KVStoreValue]: ...
|
||||
async def get(self, key: str) -> Optional[str]: ...
|
||||
|
||||
async def delete(self, key: str) -> None: ...
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]: ...
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]: ...
|
||||
|
|
|
@ -7,9 +7,11 @@
|
|||
from enum import Enum
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.distribution.utils.config_dirs import RUNTIME_BASE_DIR
|
||||
|
||||
|
||||
class KVStoreType(Enum):
|
||||
redis = "redis"
|
||||
|
@ -24,20 +26,21 @@ class CommonConfig(BaseModel):
|
|||
)
|
||||
|
||||
|
||||
class RedisKVStoreImplConfig(CommonConfig):
|
||||
class RedisKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.redis.value] = KVStoreType.redis.value
|
||||
host: str = "localhost"
|
||||
port: int = 6379
|
||||
|
||||
|
||||
class SqliteKVStoreImplConfig(CommonConfig):
|
||||
class SqliteKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.sqlite.value] = KVStoreType.sqlite.value
|
||||
db_path: str = Field(
|
||||
default=(RUNTIME_BASE_DIR / "kvstore.db").as_posix(),
|
||||
description="File path for the sqlite database",
|
||||
)
|
||||
|
||||
|
||||
class PostgresKVStoreImplConfig(CommonConfig):
|
||||
class PostgresKVStoreConfig(CommonConfig):
|
||||
type: Literal[KVStoreType.postgres.value] = KVStoreType.postgres.value
|
||||
host: str = "localhost"
|
||||
port: int = 5432
|
||||
|
@ -47,6 +50,6 @@ class PostgresKVStoreImplConfig(CommonConfig):
|
|||
|
||||
|
||||
KVStoreConfig = Annotated[
|
||||
Union[RedisKVStoreImplConfig, SqliteKVStoreImplConfig, PostgresKVStoreImplConfig],
|
||||
Union[RedisKVStoreConfig, SqliteKVStoreConfig, PostgresKVStoreConfig],
|
||||
Field(discriminator="type", default=KVStoreType.sqlite.value),
|
||||
]
|
||||
|
|
|
@ -12,16 +12,37 @@ def kvstore_dependencies():
|
|||
return ["aiosqlite", "psycopg2-binary", "redis"]
|
||||
|
||||
|
||||
class InmemoryKVStoreImpl(KVStore):
|
||||
def __init__(self):
|
||||
self._store = {}
|
||||
|
||||
async def initialize(self) -> None:
|
||||
pass
|
||||
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
return self._store.get(key)
|
||||
|
||||
async def set(self, key: str, value: str) -> None:
|
||||
self._store[key] = value
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
return [
|
||||
self._store[key]
|
||||
for key in self._store.keys()
|
||||
if key >= start_key and key < end_key
|
||||
]
|
||||
|
||||
|
||||
async def kvstore_impl(config: KVStoreConfig) -> KVStore:
|
||||
if config.type == KVStoreType.redis:
|
||||
if config.type == KVStoreType.redis.value:
|
||||
from .redis import RedisKVStoreImpl
|
||||
|
||||
impl = RedisKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.sqlite:
|
||||
elif config.type == KVStoreType.sqlite.value:
|
||||
from .sqlite import SqliteKVStoreImpl
|
||||
|
||||
impl = SqliteKVStoreImpl(config)
|
||||
elif config.type == KVStoreType.pgvector:
|
||||
elif config.type == KVStoreType.postgres.value:
|
||||
raise NotImplementedError()
|
||||
else:
|
||||
raise ValueError(f"Unknown kvstore type {config.type}")
|
||||
|
|
|
@ -4,17 +4,17 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from redis.asyncio import Redis
|
||||
|
||||
from ..api import * # noqa: F403
|
||||
from ..config import RedisKVStoreImplConfig
|
||||
from ..config import RedisKVStoreConfig
|
||||
|
||||
|
||||
class RedisKVStoreImpl(KVStore):
|
||||
def __init__(self, config: RedisKVStoreImplConfig):
|
||||
def __init__(self, config: RedisKVStoreConfig):
|
||||
self.config = config
|
||||
|
||||
async def initialize(self) -> None:
|
||||
|
@ -33,20 +33,19 @@ class RedisKVStoreImpl(KVStore):
|
|||
if expiration:
|
||||
await self.redis.expireat(key, expiration)
|
||||
|
||||
async def get(self, key: str) -> Optional[KVStoreValue]:
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
key = self._namespaced_key(key)
|
||||
value = await self.redis.get(key)
|
||||
if value is None:
|
||||
return None
|
||||
ttl = await self.redis.ttl(key)
|
||||
expiration = datetime.now() + timedelta(seconds=ttl) if ttl > 0 else None
|
||||
return KVStoreValue(key=key, value=value, expiration=expiration)
|
||||
return value
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
key = self._namespaced_key(key)
|
||||
await self.redis.delete(key)
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]:
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
start_key = self._namespaced_key(start_key)
|
||||
end_key = self._namespaced_key(end_key)
|
||||
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
|
@ -16,9 +18,10 @@ from ..config import SqliteKVStoreConfig
|
|||
class SqliteKVStoreImpl(KVStore):
|
||||
def __init__(self, config: SqliteKVStoreConfig):
|
||||
self.db_path = config.db_path
|
||||
self.table_name = config.table_name
|
||||
self.table_name = "kvstore"
|
||||
|
||||
async def initialize(self):
|
||||
os.makedirs(os.path.dirname(self.db_path), exist_ok=True)
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(
|
||||
f"""
|
||||
|
@ -41,7 +44,7 @@ class SqliteKVStoreImpl(KVStore):
|
|||
)
|
||||
await db.commit()
|
||||
|
||||
async def get(self, key: str) -> Optional[KVStoreValue]:
|
||||
async def get(self, key: str) -> Optional[str]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT value, expiration FROM {self.table_name} WHERE key = ?", (key,)
|
||||
|
@ -50,14 +53,14 @@ class SqliteKVStoreImpl(KVStore):
|
|||
if row is None:
|
||||
return None
|
||||
value, expiration = row
|
||||
return KVStoreValue(key=key, value=value, expiration=expiration)
|
||||
return value
|
||||
|
||||
async def delete(self, key: str) -> None:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
await db.execute(f"DELETE FROM {self.table_name} WHERE key = ?", (key,))
|
||||
await db.commit()
|
||||
|
||||
async def range(self, start_key: str, end_key: str) -> List[KVStoreValue]:
|
||||
async def range(self, start_key: str, end_key: str) -> List[str]:
|
||||
async with aiosqlite.connect(self.db_path) as db:
|
||||
async with db.execute(
|
||||
f"SELECT key, value, expiration FROM {self.table_name} WHERE key >= ? AND key <= ?",
|
||||
|
@ -65,8 +68,6 @@ class SqliteKVStoreImpl(KVStore):
|
|||
) as cursor:
|
||||
result = []
|
||||
async for row in cursor:
|
||||
key, value, expiration = row
|
||||
result.append(
|
||||
KVStoreValue(key=key, value=value, expiration=expiration)
|
||||
)
|
||||
_, value, _ = row
|
||||
result.append(value)
|
||||
return result
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue