removed StackRunConfig, added policy, removed unnecessary hasattr calls, fixed first/last id, and updated tests

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-02 15:27:50 -04:00
parent 387a2a5de8
commit 21ae267feb
2 changed files with 44 additions and 36 deletions

View file

@ -21,12 +21,16 @@ from llama_stack.apis.conversations.conversations import (
Conversations, Conversations,
Metadata, Metadata,
) )
from llama_stack.core.datatypes import AccessRule, StackRunConfig from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl from llama_stack.providers.utils.sqlstore.sqlstore import (
SqliteSqlStoreConfig,
SqlStoreConfig,
sqlstore_impl,
)
logger = get_logger(name=__name__, category="openai::conversations") logger = get_logger(name=__name__, category="openai::conversations")
@ -34,10 +38,14 @@ logger = get_logger(name=__name__, category="openai::conversations")
class ConversationServiceConfig(BaseModel): class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service. """Configuration for the built-in conversation service.
:param run_config: Stack run configuration containing distribution info :param conversations_store: SQL store configuration for conversations (defaults to SQLite)
:param policy: Access control rules
""" """
run_config: StackRunConfig conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
)
policy: list[AccessRule] = []
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]): async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
@ -53,23 +61,15 @@ class ConversationServiceImpl(Conversations):
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]): def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config self.config = config
self.deps = deps self.deps = deps
self.policy: list[AccessRule] = [] self.policy = config.policy
conversations_store_config = config.run_config.conversations_store base_sql_store = sqlstore_impl(config.conversations_store)
if conversations_store_config is None:
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
)
else:
sql_store_config = conversations_store_config
base_sql_store = sqlstore_impl(sql_store_config)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy) self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None: async def initialize(self) -> None:
"""Initialize the store and create tables.""" """Initialize the store and create tables."""
if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"): if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True) os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
await self.sql_store.create_table( await self.sql_store.create_table(
"openai_conversations", "openai_conversations",
@ -91,8 +91,7 @@ class ConversationServiceImpl(Conversations):
items_json = [] items_json = []
for item in items or []: for item in items or []:
item_dict = item.model_dump() if hasattr(item, "model_dump") else item items_json.append(item.model_dump())
items_json.append(item_dict)
record_data = { record_data = {
"id": conversation_id, "id": conversation_id,
@ -170,7 +169,7 @@ class ConversationServiceImpl(Conversations):
item_id = f"item_{random_bytes.hex()}" item_id = f"item_{random_bytes.hex()}"
# Create a copy of the item with the generated ID and completed status # Create a copy of the item with the generated ID and completed status
item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item) item_dict = item.model_dump()
item_dict["id"] = item_id item_dict["id"] = item_id
if "status" not in item_dict: if "status" not in item_dict:
item_dict["status"] = "completed" item_dict["status"] = "completed"
@ -238,15 +237,9 @@ class ConversationServiceImpl(Conversations):
adapter.validate_python(item) if isinstance(item, dict) else item for item in items adapter.validate_python(item) if isinstance(item, dict) else item for item in items
] ]
# Get first and last IDs safely # Get first and last IDs from converted response items
first_id = None first_id = response_items[0].id if response_items else None
last_id = None last_id = response_items[-1].id if response_items else None
if items:
first_item = items[0]
last_item = items[-1]
first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)
return ConversationItemList( return ConversationItemList(
data=response_items, data=response_items,

View file

@ -20,7 +20,6 @@ from llama_stack.core.conversations.conversations import (
ConversationServiceConfig, ConversationServiceConfig,
ConversationServiceImpl, ConversationServiceImpl,
) )
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@ -29,13 +28,7 @@ async def service():
with tempfile.TemporaryDirectory() as tmpdir: with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations.db" db_path = Path(tmpdir) / "test_conversations.db"
config = ConversationServiceConfig( config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
run_config=StackRunConfig(
image_name="test",
providers={},
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)),
)
)
service = ConversationServiceImpl(config, {}) service = ConversationServiceImpl(config, {})
await service.initialize() await service.initialize()
yield service yield service
@ -115,3 +108,25 @@ async def test_openai_type_compatibility(service):
openai_item_adapter = TypeAdapter(OpenAIConversationItem) openai_item_adapter = TypeAdapter(OpenAIConversationItem)
openai_item_adapter.validate_python(item_dict) openai_item_adapter.validate_python(item_dict)
async def test_policy_configuration():
from llama_stack.core.access_control.datatypes import Action, Scope
from llama_stack.core.datatypes import AccessRule
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations_policy.db"
restrictive_policy = [
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
]
config = ConversationServiceConfig(
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
)
service = ConversationServiceImpl(config, {})
await service.initialize()
assert service.policy == restrictive_policy
assert len(service.policy) == 1
assert service.policy[0].forbid is not None