llama-stack-mirror/llama_stack/providers/utils/kvstore/postgres/postgres.py
Dinesh Yeduguru 6395dadc2b
use logging instead of prints (#499)
# What does this PR do?

This PR moves all print statements to use logging. Things changed:
- Had to add `await start_trace("sse_generator")` to server.py to
actually get tracing working; otherwise no logs were visible.
- If no telemetry provider is provided in the run.yaml, we will write to
stdout
- By default, the logs are emitted as JSON, but we expose an option
to configure human-readable output instead.
2024-11-21 11:32:53 -08:00

105 lines
3.3 KiB
Python

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from datetime import datetime
from typing import List, Optional
import psycopg2
from psycopg2.extras import DictCursor
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = logging.getLogger(__name__)
class PostgresKVStoreImpl(KVStore):
    """Key-value store backed by a single PostgreSQL table.

    Each entry is a row of ``(key TEXT PRIMARY KEY, value TEXT,
    expiration TIMESTAMP)`` in ``config.table_name``. Keys are prefixed with
    ``config.namespace`` (when set) before they reach the database.

    NOTE(review): the methods are declared ``async`` but call psycopg2
    without awaiting anything, so each query presumably blocks the event
    loop for its duration -- confirm this is acceptable for callers.

    NOTE: ``config.table_name`` is interpolated directly into the SQL text
    (identifiers cannot be bound as query parameters), so it must come from
    trusted configuration. Keys and values are always passed as bound
    parameters.
    """

    def __init__(self, config: PostgresKVStoreConfig):
        """Store the configuration; no connection is opened until ``initialize``."""
        self.config = config
        self.conn = None
        self.cursor = None

    async def initialize(self) -> None:
        """Connect to PostgreSQL and make sure the backing table exists.

        The connection is put in autocommit mode and a ``DictCursor`` is
        created and kept for all later queries.

        Raises:
            RuntimeError: if connecting, creating the cursor, or creating the
                table fails. The original exception is chained as the cause.
        """
        # Work on locals and publish to ``self`` only after every step has
        # succeeded, so a failure never leaves a half-initialised store behind.
        conn = None
        try:
            conn = psycopg2.connect(
                host=self.config.host,
                port=self.config.port,
                database=self.config.db,
                user=self.config.user,
                password=self.config.password,
            )
            conn.autocommit = True
            cursor = conn.cursor(cursor_factory=DictCursor)

            # Create table if it doesn't exist
            cursor.execute(
                f"""
                CREATE TABLE IF NOT EXISTS {self.config.table_name} (
                    key TEXT PRIMARY KEY,
                    value TEXT,
                    expiration TIMESTAMP
                )
                """
            )
        except Exception as e:
            log.exception("Could not connect to PostgreSQL database server")
            if conn is not None:
                # The connection was opened but a later step failed: close it
                # so the server-side session is not leaked.
                try:
                    conn.close()
                except Exception:
                    log.debug(
                        "Ignoring error while closing PostgreSQL connection",
                        exc_info=True,
                    )
            raise RuntimeError("Could not connect to PostgreSQL database server") from e

        self.conn = conn
        self.cursor = cursor

    def _namespaced_key(self, key: str) -> str:
        """Return ``key`` prefixed with ``"<namespace>:"`` when a namespace is configured."""
        if not self.config.namespace:
            return key
        return f"{self.config.namespace}:{key}"

    async def set(
        self, key: str, value: str, expiration: Optional[datetime] = None
    ) -> None:
        """Insert ``value`` under ``key``, overwriting any existing row.

        Args:
            key: Un-namespaced key; the configured namespace is applied here.
            value: Text payload to store.
            expiration: Optional timestamp after which ``get``/``range`` stop
                returning the entry. ``None`` means the entry never expires.
        """
        key = self._namespaced_key(key)
        # Upsert: on a primary-key conflict both value and expiration are
        # replaced with the newly supplied ones.
        self.cursor.execute(
            f"""
            INSERT INTO {self.config.table_name} (key, value, expiration)
            VALUES (%s, %s, %s)
            ON CONFLICT (key) DO UPDATE
            SET value = EXCLUDED.value, expiration = EXCLUDED.expiration
            """,
            (key, value, expiration),
        )

    async def get(self, key: str) -> Optional[str]:
        """Return the value stored under ``key``.

        Returns ``None`` when no row matches or when the row's expiration is
        not later than the database's ``NOW()``.
        """
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"""
            SELECT value FROM {self.config.table_name}
            WHERE key = %s
            AND (expiration IS NULL OR expiration > NOW())
            """,
            (key,),
        )
        result = self.cursor.fetchone()
        return result[0] if result else None

    async def delete(self, key: str) -> None:
        """Delete the row stored under ``key`` (expired or not), if any."""
        key = self._namespaced_key(key)
        self.cursor.execute(
            f"DELETE FROM {self.config.table_name} WHERE key = %s",
            (key,),
        )

    async def range(self, start_key: str, end_key: str) -> List[str]:
        """Return unexpired values whose key lies in ``[start_key, end_key)``.

        Both bounds are namespaced before the query. Rows are compared and
        ordered by key inside PostgreSQL, and the values are returned in that
        key order.
        """
        start_key = self._namespaced_key(start_key)
        end_key = self._namespaced_key(end_key)
        self.cursor.execute(
            f"""
            SELECT value FROM {self.config.table_name}
            WHERE key >= %s AND key < %s
            AND (expiration IS NULL OR expiration > NOW())
            ORDER BY key
            """,
            (start_key, end_key),
        )
        return [row[0] for row in self.cursor.fetchall()]