mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
removed StackRunConfig, added policy, removed unnecessary hasattr calls, fixed first/last id, and updated tests
Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
This commit is contained in:
parent
387a2a5de8
commit
21ae267feb
2 changed files with 44 additions and 36 deletions
|
@ -21,12 +21,16 @@ from llama_stack.apis.conversations.conversations import (
|
||||||
Conversations,
|
Conversations,
|
||||||
Metadata,
|
Metadata,
|
||||||
)
|
)
|
||||||
from llama_stack.core.datatypes import AccessRule, StackRunConfig
|
from llama_stack.core.datatypes import AccessRule
|
||||||
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType
|
||||||
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import PostgresSqlStoreConfig, SqliteSqlStoreConfig, sqlstore_impl
|
from llama_stack.providers.utils.sqlstore.sqlstore import (
|
||||||
|
SqliteSqlStoreConfig,
|
||||||
|
SqlStoreConfig,
|
||||||
|
sqlstore_impl,
|
||||||
|
)
|
||||||
|
|
||||||
logger = get_logger(name=__name__, category="openai::conversations")
|
logger = get_logger(name=__name__, category="openai::conversations")
|
||||||
|
|
||||||
|
@ -34,10 +38,14 @@ logger = get_logger(name=__name__, category="openai::conversations")
|
||||||
class ConversationServiceConfig(BaseModel):
|
class ConversationServiceConfig(BaseModel):
|
||||||
"""Configuration for the built-in conversation service.
|
"""Configuration for the built-in conversation service.
|
||||||
|
|
||||||
:param run_config: Stack run configuration containing distribution info
|
:param conversations_store: SQL store configuration for conversations (defaults to SQLite)
|
||||||
|
:param policy: Access control rules
|
||||||
"""
|
"""
|
||||||
|
|
||||||
run_config: StackRunConfig
|
conversations_store: SqlStoreConfig = SqliteSqlStoreConfig(
|
||||||
|
db_path=(DISTRIBS_BASE_DIR / "conversations.db").as_posix()
|
||||||
|
)
|
||||||
|
policy: list[AccessRule] = []
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
|
async def get_provider_impl(config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||||
|
@ -53,23 +61,15 @@ class ConversationServiceImpl(Conversations):
|
||||||
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
|
def __init__(self, config: ConversationServiceConfig, deps: dict[Any, Any]):
|
||||||
self.config = config
|
self.config = config
|
||||||
self.deps = deps
|
self.deps = deps
|
||||||
self.policy: list[AccessRule] = []
|
self.policy = config.policy
|
||||||
|
|
||||||
conversations_store_config = config.run_config.conversations_store
|
base_sql_store = sqlstore_impl(config.conversations_store)
|
||||||
if conversations_store_config is None:
|
|
||||||
sql_store_config: SqliteSqlStoreConfig | PostgresSqlStoreConfig = SqliteSqlStoreConfig(
|
|
||||||
db_path=(DISTRIBS_BASE_DIR / config.run_config.image_name / "conversations.db").as_posix()
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sql_store_config = conversations_store_config
|
|
||||||
|
|
||||||
base_sql_store = sqlstore_impl(sql_store_config)
|
|
||||||
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
self.sql_store = AuthorizedSqlStore(base_sql_store, self.policy)
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
"""Initialize the store and create tables."""
|
"""Initialize the store and create tables."""
|
||||||
if hasattr(self.sql_store.sql_store, "config") and hasattr(self.sql_store.sql_store.config, "db_path"):
|
if isinstance(self.config.conversations_store, SqliteSqlStoreConfig):
|
||||||
os.makedirs(os.path.dirname(self.sql_store.sql_store.config.db_path), exist_ok=True)
|
os.makedirs(os.path.dirname(self.config.conversations_store.db_path), exist_ok=True)
|
||||||
|
|
||||||
await self.sql_store.create_table(
|
await self.sql_store.create_table(
|
||||||
"openai_conversations",
|
"openai_conversations",
|
||||||
|
@ -91,8 +91,7 @@ class ConversationServiceImpl(Conversations):
|
||||||
|
|
||||||
items_json = []
|
items_json = []
|
||||||
for item in items or []:
|
for item in items or []:
|
||||||
item_dict = item.model_dump() if hasattr(item, "model_dump") else item
|
items_json.append(item.model_dump())
|
||||||
items_json.append(item_dict)
|
|
||||||
|
|
||||||
record_data = {
|
record_data = {
|
||||||
"id": conversation_id,
|
"id": conversation_id,
|
||||||
|
@ -170,7 +169,7 @@ class ConversationServiceImpl(Conversations):
|
||||||
item_id = f"item_{random_bytes.hex()}"
|
item_id = f"item_{random_bytes.hex()}"
|
||||||
|
|
||||||
# Create a copy of the item with the generated ID and completed status
|
# Create a copy of the item with the generated ID and completed status
|
||||||
item_dict = item.model_dump() if hasattr(item, "model_dump") else dict(item)
|
item_dict = item.model_dump()
|
||||||
item_dict["id"] = item_id
|
item_dict["id"] = item_id
|
||||||
if "status" not in item_dict:
|
if "status" not in item_dict:
|
||||||
item_dict["status"] = "completed"
|
item_dict["status"] = "completed"
|
||||||
|
@ -238,15 +237,9 @@ class ConversationServiceImpl(Conversations):
|
||||||
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
|
adapter.validate_python(item) if isinstance(item, dict) else item for item in items
|
||||||
]
|
]
|
||||||
|
|
||||||
# Get first and last IDs safely
|
# Get first and last IDs from converted response items
|
||||||
first_id = None
|
first_id = response_items[0].id if response_items else None
|
||||||
last_id = None
|
last_id = response_items[-1].id if response_items else None
|
||||||
if items:
|
|
||||||
first_item = items[0]
|
|
||||||
last_item = items[-1]
|
|
||||||
|
|
||||||
first_id = first_item.get("id") if isinstance(first_item, dict) else getattr(first_item, "id", None)
|
|
||||||
last_id = last_item.get("id") if isinstance(last_item, dict) else getattr(last_item, "id", None)
|
|
||||||
|
|
||||||
return ConversationItemList(
|
return ConversationItemList(
|
||||||
data=response_items,
|
data=response_items,
|
||||||
|
|
|
@ -20,7 +20,6 @@ from llama_stack.core.conversations.conversations import (
|
||||||
ConversationServiceConfig,
|
ConversationServiceConfig,
|
||||||
ConversationServiceImpl,
|
ConversationServiceImpl,
|
||||||
)
|
)
|
||||||
from llama_stack.core.datatypes import StackRunConfig
|
|
||||||
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig
|
||||||
|
|
||||||
|
|
||||||
|
@ -29,13 +28,7 @@ async def service():
|
||||||
with tempfile.TemporaryDirectory() as tmpdir:
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
db_path = Path(tmpdir) / "test_conversations.db"
|
db_path = Path(tmpdir) / "test_conversations.db"
|
||||||
|
|
||||||
config = ConversationServiceConfig(
|
config = ConversationServiceConfig(conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=[])
|
||||||
run_config=StackRunConfig(
|
|
||||||
image_name="test",
|
|
||||||
providers={},
|
|
||||||
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
service = ConversationServiceImpl(config, {})
|
service = ConversationServiceImpl(config, {})
|
||||||
await service.initialize()
|
await service.initialize()
|
||||||
yield service
|
yield service
|
||||||
|
@ -115,3 +108,25 @@ async def test_openai_type_compatibility(service):
|
||||||
|
|
||||||
openai_item_adapter = TypeAdapter(OpenAIConversationItem)
|
openai_item_adapter = TypeAdapter(OpenAIConversationItem)
|
||||||
openai_item_adapter.validate_python(item_dict)
|
openai_item_adapter.validate_python(item_dict)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_policy_configuration():
|
||||||
|
from llama_stack.core.access_control.datatypes import Action, Scope
|
||||||
|
from llama_stack.core.datatypes import AccessRule
|
||||||
|
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = Path(tmpdir) / "test_conversations_policy.db"
|
||||||
|
|
||||||
|
restrictive_policy = [
|
||||||
|
AccessRule(forbid=Scope(principal="test_user", actions=[Action.CREATE, Action.READ], resource="*"))
|
||||||
|
]
|
||||||
|
|
||||||
|
config = ConversationServiceConfig(
|
||||||
|
conversations_store=SqliteSqlStoreConfig(db_path=str(db_path)), policy=restrictive_policy
|
||||||
|
)
|
||||||
|
service = ConversationServiceImpl(config, {})
|
||||||
|
await service.initialize()
|
||||||
|
|
||||||
|
assert service.policy == restrictive_policy
|
||||||
|
assert len(service.policy) == 1
|
||||||
|
assert service.policy[0].forbid is not None
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue