test: fix testing

This commit is contained in:
Krrish Dholakia 2024-07-31 11:49:07 -07:00
parent a5ce084a2c
commit 6974b45c75
2 changed files with 11 additions and 7 deletions

View file

@ -1953,7 +1953,7 @@ async def test_upperbound_key_params(prisma_client):
key = await generate_key_fn(request) key = await generate_key_fn(request)
# print(result) # print(result)
except Exception as e: except Exception as e:
assert e.code == 400 assert e.code == str(400)
def test_get_bearer_token(): def test_get_bearer_token():

View file

@ -188,7 +188,12 @@ async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
from fastapi import HTTPException, Request from fastapi import HTTPException, Request
from starlette.datastructures import URL from starlette.datastructures import URL
from litellm.proxy._types import LiteLLM_TeamTable, ProxyException, UserAPIKeyAuth from litellm.proxy._types import (
LiteLLM_TeamTable,
LiteLLM_TeamTableCachedObj,
ProxyException,
UserAPIKeyAuth,
)
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.proxy_server import hash_token, user_api_key_cache from litellm.proxy.proxy_server import hash_token, user_api_key_cache
@ -202,7 +207,7 @@ async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
last_refreshed_at=time.time(), last_refreshed_at=time.time(),
) )
await asyncio.sleep(1) await asyncio.sleep(1)
team_obj = LiteLLM_TeamTable( team_obj = LiteLLM_TeamTableCachedObj(
team_id=_team_id, team_id=_team_id,
blocked=False, blocked=False,
last_refreshed_at=time.time(), last_refreshed_at=time.time(),
@ -227,7 +232,7 @@ async def test_team_disable_guardrails(mock_acompletion, client_no_auth):
await user_api_key_auth(request=request, api_key="Bearer " + user_key) await user_api_key_auth(request=request, api_key="Bearer " + user_key)
pytest.fail("Expected to raise 403 forbidden error.") pytest.fail("Expected to raise 403 forbidden error.")
except ProxyException as e: except ProxyException as e:
assert e.code == "401" assert e.code == str(403)
from litellm.tests.test_custom_callback_input import CompletionCustomHandler from litellm.tests.test_custom_callback_input import CompletionCustomHandler
@ -740,7 +745,7 @@ async def test_team_update_redis():
Tests if team update, updates the redis cache if set Tests if team update, updates the redis cache if set
""" """
from litellm.caching import DualCache, RedisCache from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import LiteLLM_TeamTable from litellm.proxy._types import LiteLLM_TeamTableCachedObj
from litellm.proxy.auth.auth_checks import _cache_team_object from litellm.proxy.auth.auth_checks import _cache_team_object
proxy_logging_obj: ProxyLogging = getattr( proxy_logging_obj: ProxyLogging = getattr(
@ -756,7 +761,7 @@ async def test_team_update_redis():
) as mock_client: ) as mock_client:
await _cache_team_object( await _cache_team_object(
team_id="1234", team_id="1234",
team_table=LiteLLM_TeamTable(), team_table=LiteLLM_TeamTableCachedObj(),
user_api_key_cache=DualCache(), user_api_key_cache=DualCache(),
proxy_logging_obj=proxy_logging_obj, proxy_logging_obj=proxy_logging_obj,
) )
@ -770,7 +775,6 @@ async def test_get_team_redis(client_no_auth):
Tests if get_team_object gets value from redis cache, if set Tests if get_team_object gets value from redis cache, if set
""" """
from litellm.caching import DualCache, RedisCache from litellm.caching import DualCache, RedisCache
from litellm.proxy._types import LiteLLM_TeamTable
from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object from litellm.proxy.auth.auth_checks import _cache_team_object, get_team_object
proxy_logging_obj: ProxyLogging = getattr( proxy_logging_obj: ProxyLogging = getattr(