diff --git a/llama_stack/providers/utils/kvstore/kvstore.py b/llama_stack/providers/utils/kvstore/kvstore.py index a3cabc206..469f400d0 100644 --- a/llama_stack/providers/utils/kvstore/kvstore.py +++ b/llama_stack/providers/utils/kvstore/kvstore.py @@ -43,7 +43,9 @@ async def kvstore_impl(config: KVStoreConfig) -> KVStore: impl = SqliteKVStoreImpl(config) elif config.type == KVStoreType.postgres.value: - raise NotImplementedError() + from .postgres import PostgresKVStoreImpl + + impl = PostgresKVStoreImpl(config) else: raise ValueError(f"Unknown kvstore type {config.type}") diff --git a/llama_stack/providers/utils/kvstore/postgres/__init__.py b/llama_stack/providers/utils/kvstore/postgres/__init__.py new file mode 100644 index 000000000..efbf6299d --- /dev/null +++ b/llama_stack/providers/utils/kvstore/postgres/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .postgres import PostgresKVStoreImpl # noqa: F401 F403 diff --git a/llama_stack/providers/utils/kvstore/postgres/postgres.py b/llama_stack/providers/utils/kvstore/postgres/postgres.py new file mode 100644 index 000000000..9f5a17552 --- /dev/null +++ b/llama_stack/providers/utils/kvstore/postgres/postgres.py @@ -0,0 +1,104 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from datetime import datetime +from typing import List, Optional + +import psycopg2 +from psycopg2.extras import DictCursor + +from ..api import KVStore +from ..config import PostgresKVStoreConfig + + +class PostgresKVStoreImpl(KVStore): + def __init__(self, config: PostgresKVStoreConfig): + self.config = config + self.table_name = "kvstore" + self.conn = None + self.cursor = None + + async def initialize(self) -> None: + try: + self.conn = psycopg2.connect( + host=self.config.host, + port=self.config.port, + database=self.config.db, + user=self.config.user, + password=self.config.password, + ) + self.conn.autocommit = True + self.cursor = self.conn.cursor(cursor_factory=DictCursor) + + # Create table if it doesn't exist + self.cursor.execute( + f""" + CREATE TABLE IF NOT EXISTS {self.table_name} ( + key TEXT PRIMARY KEY, + value TEXT, + expiration TIMESTAMP + ) + """ + ) + except Exception as e: + import traceback + + traceback.print_exc() + raise RuntimeError("Could not connect to PostgreSQL database server") from e + + def _namespaced_key(self, key: str) -> str: + if not self.config.namespace: + return key + return f"{self.config.namespace}:{key}" + + async def set( + self, key: str, value: str, expiration: Optional[datetime] = None + ) -> None: + key = self._namespaced_key(key) + self.cursor.execute( + f""" + INSERT INTO {self.table_name} (key, value, expiration) + VALUES (%s, %s, %s) + ON CONFLICT (key) DO UPDATE + SET value = EXCLUDED.value, expiration = EXCLUDED.expiration + """, + (key, value, expiration), + ) + + async def get(self, key: str) -> Optional[str]: + key = self._namespaced_key(key) + self.cursor.execute( + f""" + SELECT value FROM {self.table_name} + WHERE key = %s + AND (expiration IS NULL OR expiration > NOW()) + """, + (key,), + ) + result = self.cursor.fetchone() + return result[0] if result else None + + async def delete(self, key: str) -> None: + key = self._namespaced_key(key) + self.cursor.execute( + f"DELETE FROM {self.table_name} WHERE key = %s", + (key,), + ) + + async def range(self, start_key: str, end_key: str) -> List[str]: + start_key = self._namespaced_key(start_key) + end_key = self._namespaced_key(end_key) + + self.cursor.execute( + f""" + SELECT value FROM {self.table_name} + WHERE key >= %s AND key < %s + AND (expiration IS NULL OR expiration > NOW()) + ORDER BY key + """, + (start_key, end_key), + ) + return [row[0] for row in self.cursor.fetchall()]