mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
(QA / testing) - Add e2e tests for key model access auth checks (#8000)
* fix _model_matches_any_wildcard_pattern_in_list * test key model access checks * add key_model_access_denied to ProxyErrorTypes * update auth checks * test_model_access_update * test_team_model_access_patterns * fix _team_model_access_check * fix config used for otel testing * test fix test_call_with_invalid_model * fix model acces check tests * test_team_access_groups * test _model_matches_any_wildcard_pattern_in_list
This commit is contained in:
parent
833a268f4b
commit
d19614b8c0
6 changed files with 375 additions and 21 deletions
|
@ -1992,6 +1992,8 @@ class SpendCalculateRequest(LiteLLMPydanticObjectBase):
|
||||||
|
|
||||||
class ProxyErrorTypes(str, enum.Enum):
|
class ProxyErrorTypes(str, enum.Enum):
|
||||||
budget_exceeded = "budget_exceeded"
|
budget_exceeded = "budget_exceeded"
|
||||||
|
key_model_access_denied = "key_model_access_denied"
|
||||||
|
team_model_access_denied = "team_model_access_denied"
|
||||||
expired_key = "expired_key"
|
expired_key = "expired_key"
|
||||||
auth_error = "auth_error"
|
auth_error = "auth_error"
|
||||||
internal_server_error = "internal_server_error"
|
internal_server_error = "internal_server_error"
|
||||||
|
|
|
@ -14,12 +14,14 @@ import time
|
||||||
import traceback
|
import traceback
|
||||||
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
||||||
|
|
||||||
|
from fastapi import status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.caching.caching import DualCache
|
from litellm.caching.caching import DualCache
|
||||||
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
from litellm.caching.dual_cache import LimitedSizeOrderedDict
|
||||||
|
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
|
||||||
from litellm.proxy._types import (
|
from litellm.proxy._types import (
|
||||||
DB_CONNECTION_ERROR_TYPES,
|
DB_CONNECTION_ERROR_TYPES,
|
||||||
CallInfo,
|
CallInfo,
|
||||||
|
@ -31,6 +33,8 @@ from litellm.proxy._types import (
|
||||||
LiteLLM_UserTable,
|
LiteLLM_UserTable,
|
||||||
LiteLLMRoutes,
|
LiteLLMRoutes,
|
||||||
LitellmUserRoles,
|
LitellmUserRoles,
|
||||||
|
ProxyErrorTypes,
|
||||||
|
ProxyException,
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.route_checks import RouteChecks
|
from litellm.proxy.auth.route_checks import RouteChecks
|
||||||
|
@ -887,8 +891,11 @@ async def can_key_call_model(
|
||||||
all_model_access = True
|
all_model_access = True
|
||||||
|
|
||||||
if model is not None and model not in filtered_models and all_model_access is False:
|
if model is not None and model not in filtered_models and all_model_access is False:
|
||||||
raise ValueError(
|
raise ProxyException(
|
||||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
message=f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}",
|
||||||
|
type=ProxyErrorTypes.key_model_access_denied,
|
||||||
|
param="model",
|
||||||
|
code=status.HTTP_401_UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
valid_token.models = filtered_models
|
valid_token.models = filtered_models
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
|
@ -1064,11 +1071,7 @@ def _team_model_access_check(
|
||||||
and model not in team_object.models
|
and model not in team_object.models
|
||||||
):
|
):
|
||||||
# this means the team has access to all models on the proxy
|
# this means the team has access to all models on the proxy
|
||||||
if (
|
if "all-proxy-models" in team_object.models or "*" in team_object.models:
|
||||||
"all-proxy-models" in team_object.models
|
|
||||||
or "*" in team_object.models
|
|
||||||
or "openai/*" in team_object.models
|
|
||||||
):
|
|
||||||
# this means the team has access to all models on the proxy
|
# this means the team has access to all models on the proxy
|
||||||
pass
|
pass
|
||||||
# check if the team model is an access_group
|
# check if the team model is an access_group
|
||||||
|
@ -1086,8 +1089,11 @@ def _team_model_access_check(
|
||||||
):
|
):
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise ProxyException(
|
||||||
f"Team={team_object.team_id} not allowed to call model={model}. Allowed team models = {team_object.models}"
|
message=f"Team not allowed to access model. Team={team_object.team_id}, Model={model}. Allowed team models = {team_object.models}",
|
||||||
|
type=ProxyErrorTypes.team_model_access_denied,
|
||||||
|
param="model",
|
||||||
|
code=status.HTTP_401_UNAUTHORIZED,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1121,10 +1127,51 @@ def _model_matches_any_wildcard_pattern_in_list(
|
||||||
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
|
- model=`bedrock/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/us.*` returns True
|
||||||
- model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
|
- model=`bedrockzzzz/us.amazon.nova-micro-v1:0`, allowed_models=`bedrock/*` returns False
|
||||||
"""
|
"""
|
||||||
return any(
|
|
||||||
"*" in allowed_model_pattern
|
if any(
|
||||||
|
_is_wildcard_pattern(allowed_model_pattern)
|
||||||
and is_model_allowed_by_pattern(
|
and is_model_allowed_by_pattern(
|
||||||
model=model, allowed_model_pattern=allowed_model_pattern
|
model=model, allowed_model_pattern=allowed_model_pattern
|
||||||
)
|
)
|
||||||
for allowed_model_pattern in allowed_model_list
|
for allowed_model_pattern in allowed_model_list
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
if any(
|
||||||
|
_is_wildcard_pattern(allowed_model_pattern)
|
||||||
|
and _model_custom_llm_provider_matches_wildcard_pattern(
|
||||||
|
model=model, allowed_model_pattern=allowed_model_pattern
|
||||||
|
)
|
||||||
|
for allowed_model_pattern in allowed_model_list
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _model_custom_llm_provider_matches_wildcard_pattern(
|
||||||
|
model: str, allowed_model_pattern: str
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True for this scenario:
|
||||||
|
- `model=gpt-4o`
|
||||||
|
- `allowed_model_pattern=openai/*`
|
||||||
|
|
||||||
|
or
|
||||||
|
- `model=claude-3-5-sonnet-20240620`
|
||||||
|
- `allowed_model_pattern=anthropic/*`
|
||||||
|
"""
|
||||||
|
model, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||||
|
return is_model_allowed_by_pattern(
|
||||||
|
model=f"{custom_llm_provider}/{model}",
|
||||||
|
allowed_model_pattern=allowed_model_pattern,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _is_wildcard_pattern(allowed_model_pattern: str) -> bool:
|
||||||
|
"""
|
||||||
|
Returns True if the pattern is a wildcard pattern.
|
||||||
|
|
||||||
|
Checks if `*` is in the pattern.
|
||||||
|
"""
|
||||||
|
return "*" in allowed_model_pattern
|
||||||
|
|
|
@ -31,7 +31,15 @@ model_list:
|
||||||
api_key: fake-key
|
api_key: fake-key
|
||||||
model_info:
|
model_info:
|
||||||
supports_vision: True
|
supports_vision: True
|
||||||
|
- model_name: bedrock/*
|
||||||
|
litellm_params:
|
||||||
|
model: bedrock/*
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
- model_name: openai/*
|
||||||
|
litellm_params:
|
||||||
|
model: openai/*
|
||||||
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
|
|
|
@ -3,6 +3,10 @@ model_list:
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/*
|
model: openai/*
|
||||||
api_key: os.environ/OPENAI_API_KEY
|
api_key: os.environ/OPENAI_API_KEY
|
||||||
|
- model_name: bedrock/*
|
||||||
|
litellm_params:
|
||||||
|
model: bedrock/*
|
||||||
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
- model_name: text-embedding-ada-002
|
- model_name: text-embedding-ada-002
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: openai/text-embedding-ada-002
|
model: openai/text-embedding-ada-002
|
||||||
|
@ -14,6 +18,12 @@ model_list:
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
callbacks: ["langfuse"]
|
||||||
|
|
||||||
|
|
||||||
litellm_settings:
|
litellm_settings:
|
||||||
callbacks: ["prometheus"]
|
callbacks: ["prometheus"]
|
||||||
prometheus_initialize_budget_metrics: true
|
prometheus_initialize_budget_metrics: true
|
||||||
|
|
290
tests/otel_tests/test_e2e_model_access.py
Normal file
290
tests/otel_tests/test_e2e_model_access.py
Normal file
|
@ -0,0 +1,290 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
from httpx import AsyncClient
|
||||||
|
from typing import Any, Optional, List, Literal
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_key(
|
||||||
|
session, models: Optional[List[str]] = None, team_id: Optional[str] = None
|
||||||
|
):
|
||||||
|
"""Helper function to generate a key with specific model access"""
|
||||||
|
url = "http://0.0.0.0:4000/key/generate"
|
||||||
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
|
data = {}
|
||||||
|
if models is not None:
|
||||||
|
data["models"] = models
|
||||||
|
if team_id is not None:
|
||||||
|
data["team_id"] = team_id
|
||||||
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def generate_team(session, models: Optional[List[str]] = None):
|
||||||
|
"""Helper function to generate a team with specific model access"""
|
||||||
|
url = "http://0.0.0.0:4000/team/new"
|
||||||
|
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||||
|
data = {}
|
||||||
|
if models is not None:
|
||||||
|
data["models"] = models
|
||||||
|
async with session.post(url, headers=headers, json=data) as response:
|
||||||
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def mock_chat_completion(session, key: str, model: str):
|
||||||
|
"""Make a chat completion request using OpenAI SDK"""
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000/v1")
|
||||||
|
|
||||||
|
response = await client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=[{"role": "user", "content": f"Say hello! {uuid.uuid4()}"}],
|
||||||
|
extra_body={
|
||||||
|
"mock_response": "mock_response",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return response
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"key_models, test_model, expect_success",
|
||||||
|
[
|
||||||
|
(["openai/*"], "anthropic/claude-2", False), # Non-matching model
|
||||||
|
(["gpt-4"], "gpt-4", True), # Exact model match
|
||||||
|
(["bedrock/*"], "bedrock/anthropic.claude-3", True), # Bedrock wildcard
|
||||||
|
(["bedrock/anthropic.*"], "bedrock/anthropic.claude-3", True), # Pattern match
|
||||||
|
(["bedrock/anthropic.*"], "bedrock/amazon.titan", False), # Pattern non-match
|
||||||
|
(None, "gpt-4", True), # No model restrictions
|
||||||
|
([], "gpt-4", True), # Empty model list
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_access_patterns(key_models, test_model, expect_success):
|
||||||
|
"""
|
||||||
|
Test model access patterns for API keys:
|
||||||
|
1. Create key with specific model access pattern
|
||||||
|
2. Attempt to make completion with test model
|
||||||
|
3. Verify access is granted/denied as expected
|
||||||
|
"""
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Generate key with specified model access
|
||||||
|
key_gen = await generate_key(session=session, models=key_models)
|
||||||
|
key = key_gen["key"]
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await mock_chat_completion(
|
||||||
|
session=session,
|
||||||
|
key=key,
|
||||||
|
model=test_model,
|
||||||
|
)
|
||||||
|
if not expect_success:
|
||||||
|
pytest.fail(f"Expected request to fail for model {test_model}")
|
||||||
|
assert (
|
||||||
|
response is not None
|
||||||
|
), "Should get valid response when access is allowed"
|
||||||
|
except Exception as e:
|
||||||
|
if expect_success:
|
||||||
|
pytest.fail(f"Expected request to succeed but got error: {e}")
|
||||||
|
_error_body = e.body
|
||||||
|
|
||||||
|
# Assert error structure and values
|
||||||
|
assert _error_body["type"] == "key_model_access_denied"
|
||||||
|
assert _error_body["param"] == "model"
|
||||||
|
assert _error_body["code"] == "401"
|
||||||
|
assert "API Key not allowed to access model" in _error_body["message"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_model_access_update():
|
||||||
|
"""
|
||||||
|
Test updating model access for an existing key:
|
||||||
|
1. Create key with restricted model access
|
||||||
|
2. Verify access patterns
|
||||||
|
3. Update key with new model access
|
||||||
|
4. Verify new access patterns
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url="http://0.0.0.0:4000")
|
||||||
|
headers = {"Authorization": "Bearer sk-1234"}
|
||||||
|
|
||||||
|
# Create initial key with restricted access
|
||||||
|
response = await client.post(
|
||||||
|
"/key/generate", json={"models": ["openai/gpt-4"]}, headers=headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
key_data = response.json()
|
||||||
|
key = key_data["key"]
|
||||||
|
|
||||||
|
# Test initial access
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Should work with gpt-4
|
||||||
|
await mock_chat_completion(session=session, key=key, model="openai/gpt-4")
|
||||||
|
|
||||||
|
# Should fail with gpt-3.5-turbo
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await mock_chat_completion(
|
||||||
|
session=session, key=key, model="openai/gpt-3.5-turbo"
|
||||||
|
)
|
||||||
|
_validate_model_access_exception(
|
||||||
|
exc_info.value, expected_type="key_model_access_denied"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update key with new model access
|
||||||
|
response = await client.post(
|
||||||
|
"/key/update", json={"key": key, "models": ["openai/*"]}, headers=headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Test updated access
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Both models should now work
|
||||||
|
await mock_chat_completion(session=session, key=key, model="openai/gpt-4")
|
||||||
|
await mock_chat_completion(
|
||||||
|
session=session, key=key, model="openai/gpt-3.5-turbo"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-OpenAI model should still fail
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await mock_chat_completion(
|
||||||
|
session=session, key=key, model="anthropic/claude-2"
|
||||||
|
)
|
||||||
|
_validate_model_access_exception(
|
||||||
|
exc_info.value, expected_type="key_model_access_denied"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"team_models, test_model, expect_success",
|
||||||
|
[
|
||||||
|
(["openai/*"], "anthropic/claude-2", False), # Non-matching model
|
||||||
|
(["gpt-4"], "gpt-4", True), # Exact model match
|
||||||
|
(["bedrock/*"], "bedrock/anthropic.claude-3", True), # Bedrock wildcard
|
||||||
|
(["bedrock/anthropic.*"], "bedrock/anthropic.claude-3", True), # Pattern match
|
||||||
|
(["bedrock/anthropic.*"], "bedrock/amazon.titan", False), # Pattern non-match
|
||||||
|
(None, "gpt-4", True), # No model restrictions
|
||||||
|
([], "gpt-4", True), # Empty model list
|
||||||
|
],
|
||||||
|
)
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_model_access_patterns(team_models, test_model, expect_success):
|
||||||
|
"""
|
||||||
|
Test model access patterns for team-based API keys:
|
||||||
|
1. Create team with specific model access pattern
|
||||||
|
2. Generate key for that team
|
||||||
|
3. Attempt to make completion with test model
|
||||||
|
4. Verify access is granted/denied as expected
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url="http://0.0.0.0:4000")
|
||||||
|
headers = {"Authorization": "Bearer sk-1234"}
|
||||||
|
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
try:
|
||||||
|
team_gen = await generate_team(session=session, models=team_models)
|
||||||
|
print("created team", team_gen)
|
||||||
|
team_id = team_gen["team_id"]
|
||||||
|
key_gen = await generate_key(session=session, team_id=team_id)
|
||||||
|
print("created key", key_gen)
|
||||||
|
key = key_gen["key"]
|
||||||
|
response = await mock_chat_completion(
|
||||||
|
session=session,
|
||||||
|
key=key,
|
||||||
|
model=test_model,
|
||||||
|
)
|
||||||
|
if not expect_success:
|
||||||
|
pytest.fail(f"Expected request to fail for model {test_model}")
|
||||||
|
assert (
|
||||||
|
response is not None
|
||||||
|
), "Should get valid response when access is allowed"
|
||||||
|
except Exception as e:
|
||||||
|
if expect_success:
|
||||||
|
pytest.fail(f"Expected request to succeed but got error: {e}")
|
||||||
|
_validate_model_access_exception(
|
||||||
|
e, expected_type="team_model_access_denied"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_model_access_update():
|
||||||
|
"""
|
||||||
|
Test updating model access for a team:
|
||||||
|
1. Create team with restricted model access
|
||||||
|
2. Verify access patterns
|
||||||
|
3. Update team with new model access
|
||||||
|
4. Verify new access patterns
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url="http://0.0.0.0:4000")
|
||||||
|
headers = {"Authorization": "Bearer sk-1234"}
|
||||||
|
|
||||||
|
# Create initial team with restricted access
|
||||||
|
response = await client.post(
|
||||||
|
"/team/new",
|
||||||
|
json={"models": ["openai/gpt-4"], "name": "test-team"},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
team_data = response.json()
|
||||||
|
team_id = team_data["team_id"]
|
||||||
|
|
||||||
|
# Generate a key for this team
|
||||||
|
response = await client.post(
|
||||||
|
"/key/generate", json={"team_id": team_id}, headers=headers
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
key = response.json()["key"]
|
||||||
|
|
||||||
|
# Test initial access
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Should work with gpt-4
|
||||||
|
await mock_chat_completion(session=session, key=key, model="openai/gpt-4")
|
||||||
|
|
||||||
|
# Should fail with gpt-3.5-turbo
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await mock_chat_completion(
|
||||||
|
session=session, key=key, model="openai/gpt-3.5-turbo"
|
||||||
|
)
|
||||||
|
_validate_model_access_exception(
|
||||||
|
exc_info.value, expected_type="team_model_access_denied"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Update team with new model access
|
||||||
|
response = await client.post(
|
||||||
|
"/team/update",
|
||||||
|
json={"team_id": team_id, "models": ["openai/*"]},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert response.status_code == 200
|
||||||
|
|
||||||
|
# Test updated access
|
||||||
|
async with aiohttp.ClientSession() as session:
|
||||||
|
# Both models should now work
|
||||||
|
await mock_chat_completion(session=session, key=key, model="openai/gpt-4")
|
||||||
|
await mock_chat_completion(
|
||||||
|
session=session, key=key, model="openai/gpt-3.5-turbo"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Non-OpenAI model should still fail
|
||||||
|
with pytest.raises(Exception) as exc_info:
|
||||||
|
await mock_chat_completion(
|
||||||
|
session=session, key=key, model="anthropic/claude-2"
|
||||||
|
)
|
||||||
|
_validate_model_access_exception(
|
||||||
|
exc_info.value, expected_type="team_model_access_denied"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_model_access_exception(
|
||||||
|
e: Exception,
|
||||||
|
expected_type: Literal["key_model_access_denied", "team_model_access_denied"],
|
||||||
|
):
|
||||||
|
_error_body = e.body
|
||||||
|
|
||||||
|
# Assert error structure and values
|
||||||
|
assert _error_body["type"] == expected_type
|
||||||
|
assert _error_body["param"] == "model"
|
||||||
|
assert _error_body["code"] == "401"
|
||||||
|
if expected_type == "key_model_access_denied":
|
||||||
|
assert "API Key not allowed to access model" in _error_body["message"]
|
||||||
|
elif expected_type == "team_model_access_denied":
|
||||||
|
assert "Team not allowed to access model" in _error_body["message"]
|
|
@ -361,11 +361,9 @@ def test_call_with_invalid_model(prisma_client):
|
||||||
|
|
||||||
asyncio.run(test())
|
asyncio.run(test())
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
assert (
|
assert isinstance(e, ProxyException)
|
||||||
e.message
|
assert e.type == ProxyErrorTypes.key_model_access_denied
|
||||||
== "Authentication Error, API Key not allowed to access model. This token can only access models=['mistral']. Tried to access gemini-pro-vision"
|
assert e.param == "model"
|
||||||
)
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
def test_call_with_valid_model(prisma_client):
|
def test_call_with_valid_model(prisma_client):
|
||||||
|
@ -3200,10 +3198,9 @@ async def test_team_access_groups(prisma_client):
|
||||||
pytest.fail(f"This should have failed!. IT's an invalid model")
|
pytest.fail(f"This should have failed!. IT's an invalid model")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print("got exception", e)
|
print("got exception", e)
|
||||||
assert (
|
assert isinstance(e, ProxyException)
|
||||||
"not allowed to call model" in e.message
|
assert e.type == ProxyErrorTypes.team_model_access_denied
|
||||||
and "Allowed team models" in e.message
|
assert e.param == "model"
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio()
|
@pytest.mark.asyncio()
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue