mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-03 09:53:45 +00:00
Merge branch 'main' into elasticsearch-integration
This commit is contained in:
commit
7aaab870bd
87 changed files with 762 additions and 5192 deletions
|
|
@ -17,7 +17,6 @@ from unittest.mock import AsyncMock, patch
|
|||
import pytest
|
||||
|
||||
from llama_stack.core.library_client import LlamaStackAsLibraryClient
|
||||
from llama_stack.core.telemetry.telemetry import MetricEvent
|
||||
from llama_stack_api import (
|
||||
Api,
|
||||
OpenAIAssistantMessageParam,
|
||||
|
|
@ -27,10 +26,6 @@ from llama_stack_api import (
|
|||
)
|
||||
|
||||
|
||||
class OpenAIChatCompletionWithMetrics(OpenAIChatCompletion):
|
||||
metrics: list[MetricEvent] | None = None
|
||||
|
||||
|
||||
def test_unregistered_model_routing_with_provider_data(client_with_models):
|
||||
"""
|
||||
Test that a model can be routed using provider_id/model_id format
|
||||
|
|
@ -72,7 +67,7 @@ def test_unregistered_model_routing_with_provider_data(client_with_models):
|
|||
# The inference router's routing_table.impls_by_provider_id should have anthropic
|
||||
# Let's patch the anthropic provider's openai_chat_completion method
|
||||
# to avoid making real API calls
|
||||
mock_response = OpenAIChatCompletionWithMetrics(
|
||||
mock_response = OpenAIChatCompletion(
|
||||
id="chatcmpl-test-123",
|
||||
created=1234567890,
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
|
|
|
|||
|
|
@ -15,11 +15,10 @@ from opentelemetry.sdk.trace import TracerProvider
|
|||
from opentelemetry.sdk.trace.export import SimpleSpanProcessor
|
||||
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
||||
|
||||
import llama_stack.core.telemetry.telemetry as telemetry_module
|
||||
|
||||
from .base import BaseTelemetryCollector, MetricStub, SpanStub
|
||||
|
||||
|
||||
# TODO: Fix thi to work with Automatic Instrumentation
|
||||
class InMemoryTelemetryCollector(BaseTelemetryCollector):
|
||||
"""In-memory telemetry collector for library-client tests.
|
||||
|
||||
|
|
@ -75,13 +74,10 @@ class InMemoryTelemetryManager:
|
|||
meter_provider = MeterProvider(metric_readers=[metric_reader])
|
||||
metrics.set_meter_provider(meter_provider)
|
||||
|
||||
telemetry_module._TRACER_PROVIDER = tracer_provider
|
||||
|
||||
self.collector = InMemoryTelemetryCollector(span_exporter, metric_reader)
|
||||
self._tracer_provider = tracer_provider
|
||||
self._meter_provider = meter_provider
|
||||
|
||||
def shutdown(self) -> None:
|
||||
telemetry_module._TRACER_PROVIDER = None
|
||||
self._tracer_provider.shutdown()
|
||||
self._meter_provider.shutdown()
|
||||
|
|
|
|||
|
|
@ -15,6 +15,7 @@ from tests.integration.fixtures.common import instantiate_llama_stack_client
|
|||
from tests.integration.telemetry.collectors import InMemoryTelemetryManager, OtlpHttpTestCollector
|
||||
|
||||
|
||||
# TODO: Fix this to work with Automatic Instrumentation
|
||||
@pytest.fixture(scope="session")
|
||||
def telemetry_test_collector():
|
||||
stack_mode = os.environ.get("LLAMA_STACK_TEST_STACK_CONFIG_TYPE", "library_client")
|
||||
|
|
@ -48,6 +49,7 @@ def telemetry_test_collector():
|
|||
manager.shutdown()
|
||||
|
||||
|
||||
# TODO: Fix this to work with Automatic Instrumentation
|
||||
@pytest.fixture(scope="session")
|
||||
def llama_stack_client(telemetry_test_collector, request):
|
||||
"""Ensure telemetry collector is ready before initializing the stack client."""
|
||||
|
|
|
|||
|
|
@ -155,9 +155,6 @@ def old_config():
|
|||
provider_type: inline::meta-reference
|
||||
config: {{}}
|
||||
api_providers:
|
||||
telemetry:
|
||||
provider_type: noop
|
||||
config: {{}}
|
||||
"""
|
||||
)
|
||||
|
||||
|
|
@ -181,7 +178,7 @@ def test_parse_and_maybe_upgrade_config_up_to_date(up_to_date_config):
|
|||
def test_parse_and_maybe_upgrade_config_old_format(old_config):
|
||||
result = parse_and_maybe_upgrade_config(old_config)
|
||||
assert result.version == LLAMA_STACK_RUN_CONFIG_VERSION
|
||||
assert all(api in result.providers for api in ["inference", "safety", "memory", "telemetry"])
|
||||
assert all(api in result.providers for api in ["inference", "safety", "memory"])
|
||||
safety_provider = result.providers["safety"][0]
|
||||
assert safety_provider.provider_type == "inline::meta-reference"
|
||||
assert "llama_guard_shield" in safety_provider.config
|
||||
|
|
|
|||
|
|
@ -83,7 +83,7 @@ class TestProviderInitialization:
|
|||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[])
|
||||
assert provider is not None
|
||||
|
||||
async def test_initialization_without_safety_api(self, mock_persistence_config, mock_deps):
|
||||
|
|
@ -97,7 +97,7 @@ class TestProviderInitialization:
|
|||
new_callable=AsyncMock,
|
||||
):
|
||||
# Should not raise any exception
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[], telemetry_enabled=False)
|
||||
provider = await get_provider_impl(config, mock_deps, policy=[])
|
||||
assert provider is not None
|
||||
assert provider.safety_api is None
|
||||
|
||||
|
|
|
|||
|
|
@ -364,23 +364,6 @@ def test_invalid_auth_header_format_oauth2(oauth2_client):
|
|||
assert "Invalid Authorization header format" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_jwks_response(*args, **kwargs):
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "1234567890",
|
||||
"kty": "oct",
|
||||
"alg": "HS256",
|
||||
"use": "sig",
|
||||
"k": base64.b64encode(b"foobarbaz").decode(),
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_token_valid():
|
||||
import jwt
|
||||
|
|
@ -421,28 +404,60 @@ def mock_jwks_urlopen():
|
|||
yield mock_urlopen
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_jwks_urlopen_with_auth_required():
|
||||
"""Mock urllib.request.urlopen that requires Bearer token for JWKS requests."""
|
||||
with patch("urllib.request.urlopen") as mock_urlopen:
|
||||
|
||||
def side_effect(request, **kwargs):
|
||||
# Check if Authorization header is present
|
||||
auth_header = request.headers.get("Authorization") if hasattr(request, "headers") else None
|
||||
|
||||
if not auth_header or not auth_header.startswith("Bearer "):
|
||||
# Simulate 401 Unauthorized
|
||||
import urllib.error
|
||||
|
||||
raise urllib.error.HTTPError(
|
||||
url=request.full_url if hasattr(request, "full_url") else "",
|
||||
code=401,
|
||||
msg="Unauthorized",
|
||||
hdrs={},
|
||||
fp=None,
|
||||
)
|
||||
|
||||
# Mock the JWKS response for PyJWKClient
|
||||
mock_response = Mock()
|
||||
mock_response.read.return_value = json.dumps(
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "1234567890",
|
||||
"kty": "oct",
|
||||
"alg": "HS256",
|
||||
"use": "sig",
|
||||
"k": base64.b64encode(b"foobarbaz").decode(),
|
||||
}
|
||||
]
|
||||
}
|
||||
).encode()
|
||||
return mock_response
|
||||
|
||||
mock_urlopen.side_effect = side_effect
|
||||
yield mock_urlopen
|
||||
|
||||
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, suppress_auth_errors):
|
||||
def test_invalid_oauth2_authentication(oauth2_client, invalid_token, mock_jwks_urlopen, suppress_auth_errors):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"})
|
||||
assert response.status_code == 401
|
||||
assert "Invalid JWT token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_auth_jwks_response(*args, **kwargs):
|
||||
if "headers" not in kwargs or "Authorization" not in kwargs["headers"]:
|
||||
return MockResponse(401, {})
|
||||
authz = kwargs["headers"]["Authorization"]
|
||||
if authz != "Bearer my-jwks-token":
|
||||
return MockResponse(401, {})
|
||||
return await mock_jwks_response(args, kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
|
|
@ -472,8 +487,9 @@ def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
|
|||
return TestClient(oauth2_app_with_jwks_token)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid, suppress_auth_errors):
|
||||
def test_oauth2_with_jwks_token_expected(
|
||||
oauth2_client, jwt_token_valid, mock_jwks_urlopen_with_auth_required, suppress_auth_errors
|
||||
):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue