add postgres kvstoreimpl (#374)

* add postgres kvstoreimpl

* make table name configurable

* add validator for table name

* linter fix

---------

Co-authored-by: Dinesh Yeduguru <dineshyv@fb.com>
This commit is contained in:
Dinesh Yeduguru 2024-11-05 11:42:21 -08:00 committed by GitHub
parent 8de845a96d
commit dcd8cfe0f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 131 additions and 2 deletions

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from datetime import datetime
from typing import List, Optional
import psycopg2
from psycopg2.extras import DictCursor
from ..api import KVStore
from ..config import PostgresKVStoreConfig
class PostgresKVStoreImpl(KVStore):
def __init__(self, config: PostgresKVStoreConfig):
self.config = config
self.conn = None
self.cursor = None
async def initialize(self) -> None:
try:
self.conn = psycopg2.connect(
host=self.config.host,
port=self.config.port,
database=self.config.db,
user=self.config.user,
password=self.config.password,
)
self.conn.autocommit = True
self.cursor = self.conn.cursor(cursor_factory=DictCursor)
# Create table if it doesn't exist
self.cursor.execute(
f"""
CREATE TABLE IF NOT EXISTS {self.config.table_name} (
key TEXT PRIMARY KEY,
value TEXT,
expiration TIMESTAMP
)
"""
)
except Exception as e:
import traceback
traceback.print_exc()
raise RuntimeError("Could not connect to PostgreSQL database server") from e
def _namespaced_key(self, key: str) -> str:
if not self.config.namespace:
return key
return f"{self.config.namespace}:{key}"
async def set(
self, key: str, value: str, expiration: Optional[datetime] = None
) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
INSERT INTO {self.config.table_name} (key, value, expiration)
VALUES (%s, %s, %s)
ON CONFLICT (key) DO UPDATE
SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
""",
(key, value, expiration),
)
async def get(self, key: str) -> Optional[str]:
key = self._namespaced_key(key)
self.cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key = %s
AND (expiration IS NULL OR expiration > NOW())
""",
(key,),
)
result = self.cursor.fetchone()
return result[0] if result else None
async def delete(self, key: str) -> None:
key = self._namespaced_key(key)
self.cursor.execute(
f"DELETE FROM {self.config.table_name} WHERE key = %s",
(key,),
)
async def range(self, start_key: str, end_key: str) -> List[str]:
start_key = self._namespaced_key(start_key)
end_key = self._namespaced_key(end_key)
self.cursor.execute(
f"""
SELECT value FROM {self.config.table_name}
WHERE key >= %s AND key < %s
AND (expiration IS NULL OR expiration > NOW())
ORDER BY key
""",
(start_key, end_key),
)
return [row[0] for row in self.cursor.fetchall()]