llama-stack/llama_stack/providers/utils/kvstore/postgres/postgres.py
Dinesh Yeduguru dcd8cfe0f3
add postgres kvstoreimpl (#374)
* add postgres kvstoreimpl

* make table name configurable

* add validator for table name

* linter fix

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
2024-11-05 11:42:21 -08:00

103 lines
3.2 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import List, Optional
import psycopg2
from psycopg2.extras import DictCursor
from ..api import KVStore
from ..config import PostgresKVStoreConfig
class PostgresKVStoreImpl(KVStore):
def __init__(self, config: PostgresKVStoreConfig):
self.config = config
self.conn = None
self.cursor = None
async def initialize(self) -> None:
try:
self.conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
)
self.conn.autocommit = True
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
# Create table if it doesn't exist
self.cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.config.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to PostgreSQL database server") from e
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
INSERT INTO {self.config.table_name} (key, value, expiration)
VALUES (%s, %s, %s)
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
""",
(key, value, expiration),
)
async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key = %s
AND (expiration IS NULL OR expiration > NOW())
""",
(key,),
)
result = self.cursor.fetchone()
return result[0] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"DELETE FROM {self.config.table_name} WHERE key = %s",
(key,),
)
async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key >= %s AND key < %s
AND (expiration IS NULL OR expiration > NOW())
ORDER BY key
""",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]