mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-24 08:47:26 +00:00
# What does this PR do?
<!-- Provide a short summary of what this PR does and why. Link to relevant issues if applicable. -->
- Added `set -e` to the beginning of the unit test script so that the script exits on failure and CI correctly fails when tests do not pass.
- Fixed all unit tests that were silently failing in CI.
- Fixed the Python 3.13 unit test CI job, which was failing silently.
<!-- If resolving an issue, uncomment and update the line below -->
<!-- Closes #[issue-number] -->
Closes #2877

## Test Plan
<!-- Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.* -->
- **Previously:** Unit tests passed in CI even though 11 tests failed -> CI-run 4683681501 (step:4:2097)
- **Made the fix. Now, ensuring CI fails as expected on test failures:** Unit tests fail in CI with 1 failed test -> CI-run 4684234247 (step:4:1506)
- This PR shows the CI passing and all unit tests passing.
200 lines
6.8 KiB
Python
200 lines
6.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
# All rights reserved.
|
|
#
|
|
# This source code is licensed under the terms described in the LICENSE file in
|
|
# the root directory of this source tree.
|
|
|
|
from unittest.mock import AsyncMock, patch
|
|
|
|
import httpx
|
|
import pytest
|
|
from fastapi import FastAPI
|
|
from fastapi.testclient import TestClient
|
|
|
|
from llama_stack.distribution.datatypes import AuthenticationConfig, AuthProviderType, GitHubTokenAuthConfig
|
|
from llama_stack.distribution.server.auth import AuthenticationMiddleware
|
|
|
|
|
|
class MockResponse:
    """Lightweight fake HTTP response for the mocked GitHub API client.

    Exposes only ``status_code``, ``json()`` and ``raise_for_status()``.

    Args:
        status_code: HTTP status code the fake response reports.
        json_data: Object returned verbatim by ``json()``.
    """

    def __init__(self, status_code, json_data):
        self.status_code = status_code
        self._json_data = json_data

    def json(self):
        """Return the canned payload passed to the constructor."""
        return self._json_data

    def raise_for_status(self):
        """Raise ``httpx.HTTPStatusError`` unless the status is a 2xx success.

        Any 2xx code counts as success (not only 200), so a fake 201/204
        response is not reported as a failure.

        Raises:
            httpx.HTTPStatusError: If ``status_code`` is outside 200-299.
        """
        if 200 <= self.status_code < 300:
            return

        # HTTPStatusError requires a request object, so build a placeholder one.
        mock_request = httpx.Request("GET", "https://api.github.com/user")
        raise httpx.HTTPStatusError(f"HTTP error: {self.status_code}", request=mock_request, response=self)
|
|
|
|
|
|
@pytest.fixture
def github_token_app():
    """Build a FastAPI app protected by GitHub-token authentication.

    The app installs ``AuthenticationMiddleware`` configured with a
    ``GitHubTokenAuthConfig`` that targets ``https://api.github.com`` and
    exposes a single ``GET /test`` route.
    """
    # GitHub token provider settings, including the claims mapping
    github_provider = GitHubTokenAuthConfig(
        type=AuthProviderType.GITHUB_TOKEN,
        github_api_base_url="https://api.github.com",
        claims_mapping={
            "login": "username",
            "id": "user_id",
            "organizations": "teams",
        },
    )
    auth_config = AuthenticationConfig(provider_config=github_provider, access_policy=[])

    app = FastAPI()
    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    return app
|
|
|
|
|
|
@pytest.fixture
def github_token_client(github_token_app):
    """Return a ``TestClient`` wrapping the ``github_token_app`` fixture."""
    client = TestClient(github_token_app)
    return client
|
|
|
|
|
|
def test_authenticated_endpoint_without_token(github_token_client):
    """A request without any token is rejected with 401 and a helpful message."""
    resp = github_token_client.get("/test")
    assert resp.status_code == 401

    # Both phrases must appear in the error message.
    message = resp.json()["error"]["message"]
    for expected in ("Authentication required", "GitHub access token"):
        assert expected in message
|
|
|
|
|
|
def test_authenticated_endpoint_with_invalid_bearer_format(github_token_client):
    """An Authorization header that is not in Bearer form is rejected with 401."""
    malformed_headers = {"Authorization": "InvalidFormat token123"}
    resp = github_token_client.get("/test", headers=malformed_headers)

    assert resp.status_code == 401
    message = resp.json()["error"]["message"]
    assert "Invalid Authorization header format" in message
|
|
|
|
|
|
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_valid_github_token(mock_client_class, github_token_client):
    """Test accessing protected endpoint with valid GitHub token"""
    # The patched AsyncClient is used as an async context manager; hand back our mock.
    mock_client = AsyncMock()
    mock_client_class.return_value.__aenter__.return_value = mock_client

    # Queue exactly one response: this test asserts a single call to the user
    # endpoint, so any extra queued response would be dead data. With only one
    # entry, an unexpected second call exhausts side_effect instead of passing.
    mock_client.get.side_effect = [
        MockResponse(
            200,
            {
                "login": "testuser",
                "id": 12345,
                "email": "test@example.com",
                "name": "Test User",
            },
        ),
    ]

    response = github_token_client.get("/test", headers={"Authorization": "Bearer github_token_123"})
    assert response.status_code == 200
    assert response.json()["message"] == "Authentication successful"

    # Verify the GitHub API was called correctly
    assert mock_client.get.call_count == 1
    user_call = mock_client.get.call_args
    assert user_call.args[0] == "https://api.github.com/user"

    # Check authorization header was passed
    assert user_call.kwargs["headers"]["Authorization"] == "Bearer github_token_123"
|
|
|
|
|
|
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_authenticated_endpoint_with_invalid_github_token(mock_client_class, github_token_client):
    """A token that the mocked GitHub API answers with 401 yields a 401 from the endpoint."""
    # Every GET on the mocked client returns a 401 "Bad credentials" response.
    github_api = AsyncMock()
    github_api.get.return_value = MockResponse(401, {"message": "Bad credentials"})
    mock_client_class.return_value.__aenter__.return_value = github_api

    resp = github_token_client.get("/test", headers={"Authorization": "Bearer invalid_token"})

    assert resp.status_code == 401
    expected_message = "GitHub token validation failed. Please check your token and try again."
    assert expected_message in resp.json()["error"]["message"]
|
|
|
|
|
|
@patch("llama_stack.distribution.server.auth_providers.httpx.AsyncClient")
def test_github_enterprise_support(mock_client_class):
    """Test GitHub Enterprise support with custom API base URL"""
    app = FastAPI()

    # Configure GitHub token auth with enterprise URL
    auth_config = AuthenticationConfig(
        provider_config=GitHubTokenAuthConfig(
            type=AuthProviderType.GITHUB_TOKEN,
            github_api_base_url="https://github.enterprise.com/api/v3",
        ),
        access_policy=[],
    )

    app.add_middleware(AuthenticationMiddleware, auth_config=auth_config, impls={})

    @app.get("/test")
    def test_endpoint():
        return {"message": "Authentication successful"}

    client = TestClient(app)

    # The patched AsyncClient is used as an async context manager; hand back our mock.
    mock_client = AsyncMock()
    mock_client_class.return_value.__aenter__.return_value = mock_client

    # Queue exactly one response: this test asserts a single call to the user
    # endpoint, so any extra queued response would be dead data. With only one
    # entry, an unexpected second call exhausts side_effect instead of passing.
    mock_client.get.side_effect = [
        MockResponse(
            200,
            {
                "login": "enterprise_user",
                "id": 99999,
                "email": "user@enterprise.com",
            },
        ),
    ]

    response = client.get("/test", headers={"Authorization": "Bearer enterprise_token"})
    assert response.status_code == 200

    # Verify the correct GitHub Enterprise URL was called
    assert mock_client.get.call_count == 1
    user_call = mock_client.get.call_args
    assert user_call.args[0] == "https://github.enterprise.com/api/v3/user"
|
|
|
|
|
|
def test_github_token_auth_error_message_format(github_token_client):
    """The missing-auth response is a 401 whose body has an ``error.message`` field."""
    resp = github_token_client.get("/test")
    assert resp.status_code == 401

    # Walk the payload one level at a time so a missing key fails on its own assert.
    body = resp.json()
    assert "error" in body
    error = body["error"]
    assert "message" in error

    message = error["message"]
    assert "Authentication required" in message
    assert "https://docs.github.com" in message  # Contains link to GitHub docs
|