llama-stack-mirror/llama_stack/providers/remote/datasetio/postgresql/postgresql.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional

from psycopg_pool import AsyncConnectionPool

from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.common.provider_utils import get_provider_type
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl

# from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from .config import PostgreSQLDatasetIOConfig
from .pg_tools import (
    build_select_statement,
    check_table_and_schema,
    create_connection_pool,
    get_config_from_uri,
    get_row_count,
    get_table_name,
    rows_to_iterrows_response,
)

log = logging.getLogger(__name__)

DATASETS_PREFIX = "datasets:"


class PostgreSQLDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
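    """DatasetIO provider that serves dataset rows from PostgreSQL tables.

    Registered dataset definitions are persisted in the configured kvstore, and
    each dataset gets its own psycopg async connection pool.
    """
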
    def __init__(self, config: PostgreSQLDatasetIOConfig) -> None:
        self.config = config
        # Local registry for keeping track of datasets within the provider
        self.dataset_infos: Dict[str, Dataset] = {}
        self.conn_pools: Dict[str, AsyncConnectionPool] = {}
        self.row_counts: Dict[str, int] = {}
        self.kvstore: KVStore | None = None

    async def initialize(self) -> None:
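        """Load previously registered datasets from the kvstore and rebuild their connection pools."""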
        self.kvstore = await kvstore_impl(self.config.kvstore)
        # Load existing datasets from kvstore
        start_key = DATASETS_PREFIX
        end_key = f"{DATASETS_PREFIX}\xff"
        stored_datasets = await self.kvstore.range(start_key, end_key)
        for ds in stored_datasets:
            dataset = Dataset.model_validate_json(ds)
            pg_config_info, _ = get_config_from_uri(dataset.source.uri, self.config)
            self.dataset_infos[dataset.identifier] = dataset
            try:
                conn_pool = await create_connection_pool(3, pg_config_info)
                self.conn_pools[dataset.identifier] = conn_pool
            except Exception as e:
                log.error(f"Failed to create connection pool for dataset {dataset.identifier} on initialization: {e}")

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset_def: Dataset,
    ) -> None:
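        """Validate that the dataset's table exists with a compatible schema, then persist the dataset definition."""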
        # Store in kvstore
        provider_type = get_provider_type(self.__module__)
        if self.dataset_infos.get(dataset_def.identifier):
            log.error(
                f"Failed to register dataset {dataset_def.identifier}. " + "Dataset with this name already exists"
            )
            raise ValueError(f"Dataset {dataset_def.identifier} already exists")
        pg_connection_info, table = get_config_from_uri(dataset_def.source.uri, self.config)
        tbmd = dataset_def.metadata.get("table", None)
        if table is not None:
            # URI setting overrides metadata setting
            dataset_def.metadata["table"] = table  # record the table name for later use
            if tbmd and table and tbmd != table:
                log.warning(
                    f"Table name mismatch for dataset {dataset_def.identifier}: metadata:{tbmd} != uri:{table}, using {table}"
                )
        elif get_table_name(dataset_def) is None:  # should have been set by now
            log.error(
                f"No table defined for dataset: {provider_type}::{dataset_def.provider_id}({dataset_def.identifier})"
            )
            raise ValueError(f"No table defined for dataset {dataset_def.identifier}")
        try:
            pool = await create_connection_pool(3, pg_connection_info)
            async with pool.connection() as conn:
                await check_table_and_schema(dataset_def, conn, provider_type)
        except ValueError:
            # these are already logged in check_table_and_schema
            raise
        except Exception as e:
            log.error(f"Failed to validate table for dataset {dataset_def.identifier}: {e}")
            raise
        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
        await self.kvstore.set(
            key=key,
            value=dataset_def.model_dump_json(),
        )
        self.dataset_infos[dataset_def.identifier] = dataset_def
        self.conn_pools[dataset_def.identifier] = pool

    async def unregister_dataset(self, dataset_id: str) -> None:
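        """Close the dataset's connection pool and remove the dataset from the kvstore and the local registry."""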
        pool = self.conn_pools.pop(dataset_id, None)
        if pool is not None:
            await pool.close()
        key = f"{DATASETS_PREFIX}{dataset_id}"
        await self.kvstore.delete(key=key)
        del self.dataset_infos[dataset_id]

    async def iterrows(
        self,
        dataset_id: str,
        start_index: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> IterrowsResponse:
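        """Fetch a page of rows from the dataset's table, starting at start_index and containing at most limit rows."""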
        if start_index is not None and start_index < 0:
            raise ValueError(f"start_index ({start_index}) must be a non-negative integer")
        dataset_def = self.dataset_infos[dataset_id]
        pool = self.conn_pools.get(dataset_id)
        if pool is None:  # retry creating the connection pool
            try:
                pg_config_info, _ = get_config_from_uri(dataset_def.source.uri, self.config)
                pool = await create_connection_pool(3, pg_config_info)
                self.conn_pools[dataset_def.identifier] = pool
            except Exception as e:
                log.error(f"Failed to create connection pool for dataset {dataset_def.identifier}: {e}")
                raise
        try:
            async with pool.connection() as conn:
                await pool._check_connection(conn)
                stmnt = build_select_statement(dataset_def, conn, start_index=start_index, limit=limit)
                if self.row_counts.get(dataset_def.identifier) is None or start_index is None or start_index < 3:
                    # get row count only once per iteration
                    self.row_counts[dataset_def.identifier] = await get_row_count(conn, get_table_name(dataset_def))
                async with conn.cursor() as cur:
                    await cur.execute(stmnt)
                    rows = await cur.fetchall()
                data = await rows_to_iterrows_response(rows, conn, get_table_name(dataset_def))
        except Exception as e:
            log.error(f"Failed to fetch rows for dataset {dataset_def.identifier}: {e}")
            raise
        begin = 0 if start_index is None else start_index
        end = begin + len(data)
        return IterrowsResponse(
            data=data,
            next_start_index=end if end < self.row_counts[dataset_def.identifier] else None,
        )

    # TODO: Implement filtering
    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        # This interface is not implemented in the DatasetsResource class and is not
        # accessible via the client.
        raise NotImplementedError("Uploading to postgresql dataset is not supported yet")