removed StackRunConfig, added policy, removed unnecessary hasattr calls, fixed first/last id, and updated tests

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
Francisco Javier Arceo 2025-10-02 15:27:50 -04:00
parent 387a2a5de8
commit 21ae267feb
2 changed files with 44 additions and 36 deletions

View file

@ -21,12 +21,16 @@ from llama_stack.apis.conversations.conversations import (
Conversations,
Metadata,
)
from llama_stack.core.datatypes import AccessRule, StackRunConfig
from llama_stack.core.datatypes import AccessRule
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
from llama_stack.providers.utils.sqlstore.sqlstore import (
SqliteSqlStoreConfig,
SqlStoreConfig,
sqlstore_impl,
)
logger = get_logger(name=__name__, category="openai::conversations")
@ -34,10 +38,14 @@ logger = get_logger(name=__name__, category="openai::conversations")
class ConversationServiceConfig(BaseModel):
"""Configuration for the built-in conversation service.
:param run_config: Stack run configuration containing distribution info
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
:param policy: Access control rules
"""
run_config: StackRunConfig
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
)
policy: list[AccessRule] = []
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
@ -53,23 +61,15 @@ class ConversationServiceImpl(Conversations):
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
self.config = config
self.deps = deps
self.policy: list[AccessRule] = []
self.policy = config.policy
conversations_store_config = config.run_config.conversations_store
if conversations_store_config is None:
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
)
else:
sql_store_config = conversations_store_config
base_sql_store = sqlstore_impl(sql_store_config)
base_sql_store = sqlstore_impl(config.conversations_store)
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
async def initialize(self) -> None:
"""Initialize the store and create tables."""
if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"):
os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True)
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
await self.sql_store.create_table(
"openai_conversations",
@ -91,8 +91,7 @@ class ConversationServiceImpl(Conversations):
items_json = []
for item in items or []:
item_dict = item.model_dump() if hasattr(item, "model_dump") else item
items_json.append(item_dict)
items_json.append(item.model_dump())
record_data = {
"id": conversation_id,
@ -170,7 +169,7 @@ class ConversationServiceImpl(Conversations):
item_id = f"item_{random_bytes.hex()}"
# Create a copy of the item with the generated ID and completed status
item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item)
item_dict = item.model_dump()
item_dict["id"] = item_id
if "status" not in item_dict:
item_dict["status"] = "completed"
@ -238,15 +237,9 @@ class ConversationServiceImpl(Conversations):
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
]
# Get first and last IDs safely
first_id = None
last_id = None
if items:
first_item = items[0]
last_item = items[-1]
first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)
# Get first and last IDs from converted response items
first_id = response_items[0].id if response_items else None
last_id = response_items[-1].id if response_items else None
return ConversationItemList(
data=response_items,

View file

@ -20,7 +20,6 @@ from llama_stack.core.conversations.conversations import (
ConversationServiceConfig,
ConversationServiceImpl,
)
from llama_stack.core.datatypes import StackRunConfig
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
@ -29,13 +28,7 @@ async def service():
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations.db"
config = ConversationServiceConfig(
run_config=StackRunConfig(
image_name="test",
providers={},
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)),
)
)
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
service = ConversationServiceImpl(config, {})
await service.initialize()
yield service
@ -115,3 +108,25 @@ async def test_openai_type_compatibility(service):
openai_item_adapter = TypeAdapter(OpenAIConversationItem)
openai_item_adapter.validate_python(item_dict)
async def test_policy_configuration():
from llama_stack.core.access_control.datatypes import Action, Scope
from llama_stack.core.datatypes import AccessRule
with tempfile.TemporaryDirectory() as tmpdir:
db_path = Path(tmpdir) / "test_conversations_policy.db"
restrictive_policy = [
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
]
config = ConversationServiceConfig(
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
)
service = ConversationServiceImpl(config, {})
await service.initialize()
assert service.policy == restrictive_policy
assert len(service.policy) == 1
assert service.policy[0].forbid is not None