Merge branch 'main' into use-openai-for-databricks

This commit is contained in:
Matthew Farrellee 2025-09-20 06:16:54 -04:00
commit c8623607f5
31 changed files with 665 additions and 691 deletions

View file

@ -1,6 +1,5 @@
adapter:
adapter_type: kaze
pip_packages: ["tests/external/llama-stack-provider-kaze"]
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
module: llama_stack_provider_kaze
adapter_type: kaze
pip_packages: ["tests/external/llama-stack-provider-kaze"]
config_class: llama_stack_provider_kaze.config.KazeProviderConfig
module: llama_stack_provider_kaze
optional_api_dependencies: []

View file

@ -6,7 +6,7 @@
from typing import Protocol
from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
from llama_stack.schema_utils import webmethod
@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
api=Api.weather,
provider_type="remote::kaze",
config_class="llama_stack_provider_kaze.KazeProviderConfig",
adapter=AdapterSpec(
adapter_type="kaze",
module="llama_stack_provider_kaze",
pip_packages=["llama_stack_provider_kaze"],
config_class="llama_stack_provider_kaze.KazeProviderConfig",
),
adapter_type="kaze",
module="llama_stack_provider_kaze",
pip_packages=["llama_stack_provider_kaze"],
),
]

View file

@ -57,7 +57,7 @@ def authorized_store(backend_config):
config = config_func()
base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 2
# Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
mock_get_authenticated_user.return_value = regular_user
# Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
# Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
# Should see:
# - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
),
]
# Create a new authorized store with the owner-only policy
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
# Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally:

View file

@ -66,10 +66,9 @@ def base_config(tmp_path):
def provider_spec_yaml():
"""Common provider spec YAML for testing."""
return """
adapter:
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
api_dependencies:
- safety
"""
@ -182,9 +181,9 @@ class TestProviderRegistry:
assert Api.inference in registry
assert "remote::test_provider" in registry[Api.inference]
provider = registry[Api.inference]["remote::test_provider"]
assert provider.adapter.adapter_type == "test_provider"
assert provider.adapter.module == "test_provider"
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
assert provider.adapter_type == "test_provider"
assert provider.module == "test_provider"
assert provider.config_class == "test_provider.config.TestProviderConfig"
assert Api.safety in provider.api_dependencies
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@ -246,8 +245,7 @@ class TestProviderRegistry:
"""Test handling of malformed remote provider spec (missing required fields)."""
remote_dir, _ = api_directories
malformed_spec = """
adapter:
adapter_type: test_provider
adapter_type: test_provider
# Missing required fields
api_dependencies:
- safety
@ -270,7 +268,7 @@ pip_packages:
with open(inline_dir / "malformed.yaml", "w") as f:
f.write(malformed_spec)
with pytest.raises(KeyError) as exc_info:
with pytest.raises(ValidationError) as exc_info:
get_provider_registry(base_config)
assert "config_class" in str(exc_info.value)

View file

@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests")

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control
await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document"
# User should only see their document
mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1
assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None
assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table(
table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy)
sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set()
for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name,
)
)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table(
table="user_data",