# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
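"""Core implementation of a PostgreSQL-backed dataset provider for Llama Stack.

Datasets are registered against existing PostgreSQL tables: connection details and
the table name come from the dataset source URI (and/or the "table" metadata key),
rows are read through a per-dataset async connection pool, and registrations are
persisted in the provider's kvstore so they survive restarts.

Minimal usage sketch (illustrative only; the exact URI scheme is whatever
pg_tools.get_config_from_uri accepts, so the URI shown below is an assumption):

    impl = PostgreSQLDatasetIOImpl(config)
    await impl.initialize()
    # dataset_def.source.uri might look like "postgres://host:5432/db?table=eval_rows"
    await impl.register_dataset(dataset_def)
    page = await impl.iterrows(dataset_def.identifier, start_index=0, limit=100)
"""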

import logging
from typing import Any, Dict, List, Optional

from psycopg_pool import AsyncConnectionPool

from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.common.provider_utils import get_provider_type
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl

from .config import PostgreSQLDatasetIOConfig
from .pg_tools import (
    build_select_statement,
    check_table_and_schema,
    create_connection_pool,
    get_config_from_uri,
    get_row_count,
    get_table_name,
    rows_to_iterrows_response,
)

log = logging.getLogger(__name__)

DATASETS_PREFIX = "datasets:"


class PostgreSQLDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
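    """DatasetIO provider that serves dataset rows directly from PostgreSQL tables."""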

    def __init__(self, config: PostgreSQLDatasetIOConfig) -> None:
        self.config = config
        # Local registry for keeping track of datasets within the provider
        self.dataset_infos: Dict[str, Dataset] = {}
        self.conn_pools: Dict[str, AsyncConnectionPool] = {}
        self.row_counts: Dict[str, int] = {}
        self.kvstore: KVStore | None = None

    async def initialize(self) -> None:
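        """Load persisted dataset registrations from the kvstore and reopen their connection pools."""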
        self.kvstore = await kvstore_impl(self.config.kvstore)
        # Load existing datasets from kvstore
        start_key = DATASETS_PREFIX
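        # "\xff" sorts after any character expected in a dataset identifier, so
        # this range covers every key stored under the DATASETS_PREFIX namespace.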
        end_key = f"{DATASETS_PREFIX}\xff"
        stored_datasets = await self.kvstore.range(start_key, end_key)

        for ds in stored_datasets:
            dataset = Dataset.model_validate_json(ds)
            pg_config_info, _ = get_config_from_uri(dataset.source.uri, self.config)
            self.dataset_infos[dataset.identifier] = dataset
            try:
                conn_pool = await create_connection_pool(3, pg_config_info)
                self.conn_pools[dataset.identifier] = conn_pool
            except Exception as e:
                log.error(f"Failed to create connection pool for dataset {dataset.identifier} on initialization: {e}")

    async def shutdown(self) -> None: ...

    async def register_dataset(
        self,
        dataset_def: Dataset,
    ) -> None:
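        """Resolve the target table, validate it against the dataset schema, and persist the registration."""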
        provider_type = get_provider_type(self.__module__)
        if self.dataset_infos.get(dataset_def.identifier):
            log.error(f"Failed to register dataset {dataset_def.identifier}: a dataset with this name already exists")
            raise ValueError(f"Dataset {dataset_def.identifier} already exists")

        pg_connection_info, table = get_config_from_uri(dataset_def.source.uri, self.config)
        tbmd = dataset_def.metadata.get("table", None)
        if table is not None:
            # The URI setting overrides the metadata setting
            dataset_def.metadata["table"] = table  # record the table name for future use

        if tbmd and table and tbmd != table:
            log.warning(
                f"Table name mismatch for dataset {dataset_def.identifier}: metadata:{tbmd} != uri:{table}, using {table}"
            )
        elif get_table_name(dataset_def) is None:  # should have been set by now
            log.error(
                f"No table defined for dataset: {provider_type}::{dataset_def.provider_id}({dataset_def.identifier})"
            )
            raise ValueError(f"No table defined for dataset {dataset_def.identifier}")

        try:
            pool = await create_connection_pool(3, pg_connection_info)
            async with pool.connection() as conn:
                await check_table_and_schema(dataset_def, conn, provider_type)
        except ValueError:
            # These are already logged in check_table_and_schema
            raise
        except Exception as e:
            log.error(f"Failed to register dataset {dataset_def.identifier}: {e}")
            raise

        # Store the registration in the kvstore
        key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
        await self.kvstore.set(
            key=key,
            value=dataset_def.model_dump_json(),
        )
        self.dataset_infos[dataset_def.identifier] = dataset_def
        self.conn_pools[dataset_def.identifier] = pool

    async def unregister_dataset(self, dataset_id: str) -> None:
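        """Close the dataset's connection pool and drop its registration from the kvstore."""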
        if self.conn_pools[dataset_id] is not None:
            await self.conn_pools[dataset_id].close()
        del self.conn_pools[dataset_id]
        key = f"{DATASETS_PREFIX}{dataset_id}"
        await self.kvstore.delete(key=key)
        del self.dataset_infos[dataset_id]

    async def iterrows(
        self,
        dataset_id: str,
        start_index: Optional[int] = None,
        limit: Optional[int] = None,
    ) -> IterrowsResponse:
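        """Return one page of rows, along with the start index of the next page (None once exhausted)."""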
        if start_index is not None and start_index < 0:
            raise ValueError(f"start_index ({start_index}) must be a non-negative integer")

        dataset_def = self.dataset_infos[dataset_id]
        pool = self.conn_pools[dataset_id]
        if pool is None:  # Retry creating the connection pool
            try:
                pg_config_info, _ = get_config_from_uri(dataset_def.source.uri, self.config)
                pool = await create_connection_pool(3, pg_config_info)
                self.conn_pools[dataset_def.identifier] = pool
            except Exception as e:
                log.error(f"Failed to create connection pool for dataset {dataset_def.identifier}: {e}")
                raise

        try:
            async with pool.connection() as conn:
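                # Validate the pooled connection before use; note that this relies
                # on a private psycopg_pool helper.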
                await pool._check_connection(conn)
                stmnt = build_select_statement(dataset_def, conn, start_index=start_index, limit=limit)
                if self.row_counts.get(dataset_def.identifier) is None or start_index is None or start_index < 3:
                    # Refresh the row count only at (or near) the start of an iteration
                    self.row_counts[dataset_def.identifier] = await get_row_count(conn, get_table_name(dataset_def))
                async with conn.cursor() as cur:
                    await cur.execute(stmnt)
                    rows = await cur.fetchall()
                    data = await rows_to_iterrows_response(rows, conn, get_table_name(dataset_def))
        except Exception as e:
            log.error(f"Failed to read rows from dataset {dataset_def.identifier}: {e}")
            raise

        begin = 0 if start_index is None else start_index
        end = begin + len(data)

        return IterrowsResponse(
            data=data,
            next_start_index=end if end < self.row_counts[dataset_def.identifier] else None,
        )

    # TODO: Implement filtering

    async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
        # This interface is not implemented in the DatasetsResource class and is not
        # accessible via the client.
        raise NotImplementedError("Uploading to a PostgreSQL dataset is not supported yet")