Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-10-03 19:57:35 +00:00)
Removed StackRunConfig from the conversation service config, added a policy field, removed unnecessary hasattr calls, fixed first/last item ID computation, and updated the tests
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
parent 387a2a5de8
commit 21ae267feb
2 changed files with 44 additions and 36 deletions
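At a glance, callers no longer wrap a full StackRunConfig just to point the conversation service at its database; the service config now takes the SQL store and the access policy directly. A minimal sketch of the new construction, based on the updated test fixture below (the db path is illustrative):

    from llama_stack.core.conversations.conversations import ConversationServiceConfig
    from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig

    # SQL store for conversations plus explicit access-control rules
    config = ConversationServiceConfig(
        conversations_store=SqliteSqlStoreConfig(db_path="/tmp/conversations.db"),  # illustrative path
        policy=[],  # same empty policy as the updated test fixture
    )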
@@ -21,12 +21,16 @@ from llama_stack.apis.conversations.conversations import (
     Conversations,
     Metadata,
 )
-from llama_stack.core.datatypes import AccessRule, StackRunConfig
+from llama_stack.core.datatypes import AccessRule
 from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
 from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
-from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
+from llama_stack.providers.utils.sqlstore.sqlstore import (
+    SqliteSqlStoreConfig,
+    SqlStoreConfig,
+    sqlstore_impl,
+)
 
 logger = get_logger(name=__name__, category="openai::conversations")
 
@@ -34,10 +38,14 @@ logger = get_logger(name=__name__, category="openai::conversations")
 class ConversationServiceConfig(BaseModel):
     """Configuration for the built-in conversation service.
 
-    :param run_config: Stack run configuration containing distribution info
+    :param conversations_store: SQL store configuration for conversations (defaults to SQLite)
+    :param policy: Access control rules
     """
 
-    run_config: StackRunConfig
+    conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
+        db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
+    )
+    policy: list[AccessRule] = []
 
 
 async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
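Both new fields have defaults, so the config can be constructed with no arguments: the store falls back to a SQLite file under DISTRIBS_BASE_DIR and the policy to an empty list. A small sketch of what the defaults resolve to, assuming only what the field declarations above state:

    from llama_stack.core.conversations.conversations import ConversationServiceConfig
    from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR

    config = ConversationServiceConfig()
    # default SQLite store lives under the distributions base dir
    assert config.conversations_store.db_path == (DISTRIBS_BASE_DIR / "conversations.db").as_posix()
    # no access rules unless the caller supplies some
    assert config.policy == []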
@@ -53,23 +61,15 @@ class ConversationServiceImpl(Conversations):
     def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
         self.config = config
         self.deps = deps
-        self.policy: list[AccessRule] = []
+        self.policy = config.policy
 
-        conversations_store_config = config.run_config.conversations_store
-        if conversations_store_config is None:
-            sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
-                db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
-            )
-        else:
-            sql_store_config = conversations_store_config
-
-        base_sql_store = sqlstore_impl(sql_store_config)
+        base_sql_store = sqlstore_impl(config.conversations_store)
         self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
 
     async def initialize(self) -> None:
         """Initialize the store and create tables."""
-        if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"):
-            os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True)
+        if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
+            os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
 
         await self.sql_store.create_table(
             "openai_conversations",
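Taken together with initialize(), wiring the service up now looks roughly like this; a sketch modeled on the updated test fixture (the temporary-directory setup and the asyncio entry point are illustrative):

    import asyncio
    import tempfile
    from pathlib import Path

    from llama_stack.core.conversations.conversations import (
        ConversationServiceConfig,
        ConversationServiceImpl,
    )
    from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig


    async def main() -> None:
        with tempfile.TemporaryDirectory() as tmpdir:
            config = ConversationServiceConfig(
                conversations_store=SqliteSqlStoreConfig(db_path=str(Path(tmpdir) / "conversations.db")),
                policy=[],
            )
            service = ConversationServiceImpl(config, {})
            # creates the SQLite parent directory (the isinstance check above) and the openai_conversations table
            await service.initialize()


    asyncio.run(main())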
@@ -91,8 +91,7 @@ class ConversationServiceImpl(Conversations):
 
         items_json = []
         for item in items or []:
-            item_dict = item.model_dump() if hasattr(item, "model_dump") else item
-            items_json.append(item_dict)
+            items_json.append(item.model_dump())
 
         record_data = {
             "id": conversation_id,
@@ -170,7 +169,7 @@ class ConversationServiceImpl(Conversations):
         item_id = f"item_{random_bytes.hex()}"
 
         # Create a copy of the item with the generated ID and completed status
-        item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item)
+        item_dict = item.model_dump()
         item_dict["id"] = item_id
         if "status" not in item_dict:
             item_dict["status"] = "completed"
@@ -238,15 +237,9 @@ class ConversationServiceImpl(Conversations):
             adapter.validate_python(item) if isinstance(item, dict) else item for item in items
         ]
 
-        # Get first and last IDs safely
-        first_id = None
-        last_id = None
-        if items:
-            first_item = items[0]
-            last_item = items[-1]
-
-            first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
-            last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)
+        # Get first and last IDs from converted response items
+        first_id = response_items[0].id if response_items else None
+        last_id = response_items[-1].id if response_items else None
 
         return ConversationItemList(
             data=response_items,
@@ -20,7 +20,6 @@ from llama_stack.core.conversations.conversations import (
     ConversationServiceConfig,
     ConversationServiceImpl,
 )
-from llama_stack.core.datatypes import StackRunConfig
 from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
 
 
@@ -29,13 +28,7 @@ async def service():
     with tempfile.TemporaryDirectory() as tmpdir:
         db_path = Path(tmpdir) / "test_conversations.db"
 
-        config = ConversationServiceConfig(
-            run_config=StackRunConfig(
-                image_name="test",
-                providers={},
-                conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)),
-            )
-        )
+        config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
         service = ConversationServiceImpl(config, {})
         await service.initialize()
         yield service
@@ -115,3 +108,25 @@ async def test_openai_type_compatibility(service):
 
     openai_item_adapter = TypeAdapter(OpenAIConversationItem)
     openai_item_adapter.validate_python(item_dict)
+
+
+async def test_policy_configuration():
+    from llama_stack.core.access_control.datatypes import Action, Scope
+    from llama_stack.core.datatypes import AccessRule
+
+    with tempfile.TemporaryDirectory() as tmpdir:
+        db_path = Path(tmpdir) / "test_conversations_policy.db"
+
+        restrictive_policy = [
+            AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
+        ]
+
+        config = ConversationServiceConfig(
+            conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
+        )
+        service = ConversationServiceImpl(config, {})
+        await service.initialize()
+
+        assert service.policy == restrictive_policy
+        assert len(service.policy) == 1
+        assert service.policy[0].forbid is not None
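The new test also doubles as documentation for the policy field; condensed into a usage sketch, with the rule shape taken from the test above and the principal name purely illustrative:

    from llama_stack.core.access_control.datatypes import Action, Scope
    from llama_stack.core.conversations.conversations import ConversationServiceConfig
    from llama_stack.core.datatypes import AccessRule

    # forbid one (illustrative) principal from creating or reading any resource
    policy = [
        AccessRule(forbid=Scope(principal="some_user", actions=[Action.CREATE, Action.READ], resource="*"))
    ]
    # conversations_store is omitted here, so it falls back to the SQLite default
    config = ConversationServiceConfig(policy=policy)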