mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-31 08:13:53 +00:00
fix unit tests
This commit is contained in:
parent
b937a49436
commit
01fac67e33
6 changed files with 37 additions and 40 deletions
|
|
@ -23,10 +23,7 @@ class RequestProviderDataContext(ContextManager):
|
|||
def __init__(
|
||||
self, provider_data: Optional[Dict[str, Any]] = None, auth_attributes: Optional[Dict[str, List[str]]] = None
|
||||
):
|
||||
# Initialize with either provider_data or create a new dict
|
||||
self.provider_data = provider_data or {}
|
||||
|
||||
# Add auth attributes under a special key if provided
|
||||
if auth_attributes:
|
||||
self.provider_data["__auth_attributes"] = auth_attributes
|
||||
|
||||
|
|
|
|||
|
|
@ -274,17 +274,14 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
if not hasattr(obj, "access_attributes") or not obj.access_attributes:
|
||||
return True
|
||||
|
||||
# Get user attributes from context
|
||||
# Get user attributes from request context
|
||||
user_attributes = get_auth_attributes()
|
||||
|
||||
# If no user attributes, deny access to objects with access control
|
||||
if not user_attributes:
|
||||
return False
|
||||
|
||||
# Convert AccessAttributes to dictionary for checking
|
||||
obj_attributes = obj.access_attributes.model_dump(exclude_none=True)
|
||||
|
||||
# If the model_dump is empty (all fields are None), allow access
|
||||
if not obj_attributes:
|
||||
return True
|
||||
|
||||
|
|
@ -292,14 +289,12 @@ class CommonRoutingTableImpl(RoutingTable):
|
|||
for attr_key, required_values in obj_attributes.items():
|
||||
user_values = user_attributes.get(attr_key, [])
|
||||
|
||||
# No values for this category in user attributes
|
||||
if not user_values:
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': missing required attribute category '{attr_key}'"
|
||||
)
|
||||
return False
|
||||
|
||||
# None of the values in this category match (need at least one match per category)
|
||||
if not any(val in user_values for val in required_values):
|
||||
logger.debug(
|
||||
f"Access denied to {obj.type} '{obj.identifier}': "
|
||||
|
|
|
|||
|
|
@ -166,18 +166,16 @@ class AuthenticationMiddleware:
|
|||
# Parse and validate the auth response
|
||||
try:
|
||||
response_data = response.json()
|
||||
|
||||
auth_response = AuthResponse(**response_data)
|
||||
|
||||
# Store attributes in request scope for access control
|
||||
if auth_response.access_attributes:
|
||||
user_attributes = auth_response.access_attributes.model_dump(exclude_none=True)
|
||||
scope["user_attributes"] = user_attributes
|
||||
else:
|
||||
logger.warning("Authentication response did not contain any attributes")
|
||||
scope["user_attributes"] = {}
|
||||
user_attributes = {}
|
||||
|
||||
# Log authentication success with attribute details
|
||||
scope["user_attributes"] = user_attributes
|
||||
logger.debug(f"Authentication successful: {len(user_attributes)} attributes")
|
||||
except Exception:
|
||||
logger.exception("Error parsing authentication response")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue