mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-15 09:42:51 +00:00
Merge branch 'main' into change-default-embedding-model
This commit is contained in:
commit
da35f2452e
15 changed files with 473 additions and 231 deletions
57
tests/unit/core/routers/test_vector_io.py
Normal file
57
tests/unit/core/routers/test_vector_io.py
Normal file
|
|
@ -0,0 +1,57 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from unittest.mock import AsyncMock, Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from llama_stack.apis.vector_io import OpenAICreateVectorStoreRequestWithExtraBody
|
||||
from llama_stack.core.routers.vector_io import VectorIORouter
|
||||
|
||||
|
||||
async def test_single_provider_auto_selection():
|
||||
# provider_id automatically selected during vector store create() when only one provider available
|
||||
mock_routing_table = Mock()
|
||||
mock_routing_table.impls_by_provider_id = {"inline::faiss": "mock_provider"}
|
||||
mock_routing_table.get_all_with_type = AsyncMock(
|
||||
return_value=[
|
||||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||
]
|
||||
)
|
||||
mock_routing_table.register_vector_db = AsyncMock(
|
||||
return_value=Mock(identifier="vs_123", provider_id="inline::faiss", provider_resource_id="vs_123")
|
||||
)
|
||||
mock_routing_table.get_provider_impl = AsyncMock(
|
||||
return_value=Mock(openai_create_vector_store=AsyncMock(return_value=Mock(id="vs_123")))
|
||||
)
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
|
||||
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
|
||||
)
|
||||
|
||||
result = await router.openai_create_vector_store(request)
|
||||
assert result.id == "vs_123"
|
||||
|
||||
|
||||
async def test_create_vector_stores_multiple_providers_missing_provider_id_error():
|
||||
# if multiple providers are available, vector store create will error without provider_id
|
||||
mock_routing_table = Mock()
|
||||
mock_routing_table.impls_by_provider_id = {
|
||||
"inline::faiss": "mock_provider_1",
|
||||
"inline::sqlite-vec": "mock_provider_2",
|
||||
}
|
||||
mock_routing_table.get_all_with_type = AsyncMock(
|
||||
return_value=[
|
||||
Mock(identifier="all-MiniLM-L6-v2", model_type="embedding", metadata={"embedding_dimension": 384})
|
||||
]
|
||||
)
|
||||
router = VectorIORouter(mock_routing_table)
|
||||
request = OpenAICreateVectorStoreRequestWithExtraBody.model_validate(
|
||||
{"name": "test_store", "embedding_model": "all-MiniLM-L6-v2"}
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Multiple vector_io providers available"):
|
||||
await router.openai_create_vector_store(request)
|
||||
|
|
@ -5,7 +5,8 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
import base64
|
||||
from unittest.mock import AsyncMock, patch
|
||||
import json
|
||||
from unittest.mock import AsyncMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from fastapi import FastAPI
|
||||
|
|
@ -374,7 +375,7 @@ async def mock_jwks_response(*args, **kwargs):
|
|||
|
||||
@pytest.fixture
|
||||
def jwt_token_valid():
|
||||
from jose import jwt
|
||||
import jwt
|
||||
|
||||
return jwt.encode(
|
||||
{
|
||||
|
|
@ -389,8 +390,30 @@ def jwt_token_valid():
|
|||
)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
|
||||
@pytest.fixture
|
||||
def mock_jwks_urlopen():
|
||||
"""Mock urllib.request.urlopen for PyJWKClient JWKS requests."""
|
||||
with patch("urllib.request.urlopen") as mock_urlopen:
|
||||
# Mock the JWKS response for PyJWKClient
|
||||
mock_response = Mock()
|
||||
mock_response.read.return_value = json.dumps(
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "1234567890",
|
||||
"kty": "oct",
|
||||
"alg": "HS256",
|
||||
"use": "sig",
|
||||
"k": base64.b64encode(b"foobarbaz").decode(),
|
||||
}
|
||||
]
|
||||
}
|
||||
).encode()
|
||||
mock_urlopen.return_value.__enter__.return_value = mock_response
|
||||
yield mock_urlopen
|
||||
|
||||
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid, mock_jwks_urlopen):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
|
@ -447,8 +470,7 @@ def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
|
|||
assert response.status_code == 401
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid, mock_jwks_urlopen):
|
||||
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue