LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)

* refactor: cleanup unused variables + fix pyright errors

* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686

* fix(o1_reasoning.py): add stricter check for o-1 reasoning model

* refactor(mistral/): make it easier to see mistral transformation logic

* fix(openai.py): fix openai o-1 model param mapping

Fixes https://github.com/BerriAI/litellm/issues/5685
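
For context, a minimal sketch of the kind of remapping this fix (and the later max_completion_token fix) is about. Illustrative only — the helper name is hypothetical and this is not LiteLLM's actual implementation:

```python
from typing import Any, Dict

def map_params_for_o1(params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical helper: remap/drop params o1-series models reject."""
    mapped = dict(params)
    # o1 models expect max_completion_tokens instead of max_tokens
    if "max_tokens" in mapped:
        mapped["max_completion_tokens"] = mapped.pop("max_tokens")
    # at launch, o1 models rejected the usual sampling params
    for unsupported in ("temperature", "top_p", "presence_penalty"):
        mapped.pop(unsupported, None)
    return mapped

# map_params_for_o1({"max_tokens": 100, "temperature": 0.2})
# -> {"max_completion_tokens": 100}
```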

* feat(main.py): infer finetuned gemini model from base model

Fixes https://github.com/BerriAI/litellm/issues/5678
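
A hedged usage sketch: fine-tuned Gemini models on Vertex AI are served from numeric endpoint ids, so the model family can't be read off the model name and is inferred from the base model instead. The model-string format below is an assumption — see the vertex.md update from this commit for the authoritative syntax:

```python
import litellm

# NOTE: "vertex_ai/gemini/<endpoint-id>" is an assumed format and the
# endpoint id is a placeholder; consult vertex.md for the real pattern.
response = litellm.completion(
    model="vertex_ai/gemini/1234567890123456789",  # fine-tuned endpoint id
    messages=[{"role": "user", "content": "hi"}],
)
```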

* docs(vertex.md): update docs to call finetuned gemini models

* feat(proxy_server.py): allow admin to hide proxy model aliases

Closes https://github.com/BerriAI/litellm/issues/5692
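
A sketch of the intended behavior, under an assumed config shape (the authoritative schema is in the load_balancing.md docs added below): an alias marked hidden still routes requests but is omitted from model listings.

```python
from litellm import Router

# Assumed shape: model_group_alias values may be dicts carrying a
# "hidden" flag; treat this as illustrative, not the exact schema.
router = Router(
    model_list=[
        {
            "model_name": "gpt-4o",
            "litellm_params": {"model": "openai/gpt-4o"},
        }
    ],
    model_group_alias={
        "my-public-alias": {"model": "gpt-4o", "hidden": True},
    },
)
# Requests to "my-public-alias" still resolve to gpt-4o, but the alias
# is left out of model listings exposed by the proxy.
```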

* docs(load_balancing.md): add docs on hiding alias models from proxy config

* fix(base.py): don't raise notimplemented error

* fix(user_api_key_auth.py): fix model max budget check

* fix(router.py): fix elif

* fix(user_api_key_auth.py): don't set team_id to empty str

* fix(team_endpoints.py): fix response type

* test(test_completion.py): handle predibase error

* test(test_proxy_server.py): fix test

* fix(o1_transformation.py): fix max_completion_token mapping

* test(test_image_generation.py): mark flaky test
commit 713d762411 (parent 60c5d3ebec)
Author: Krish Dholakia
Date: 2024-09-14 10:02:55 -07:00 (committed by GitHub)
35 changed files with 1020 additions and 539 deletions

Diff: litellm/proxy/auth/user_api_key_auth.py

@@ -109,11 +109,8 @@ async def user_api_key_auth(
),
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
-        allowed_routes_check,
        common_checks,
-        custom_db_client,
        general_settings,
-        get_actual_routes,
        jwt_handler,
litellm_proxy_admin_name,
llm_model_list,
@@ -125,6 +122,8 @@ async def user_api_key_auth(
user_custom_auth,
)
+    parent_otel_span: Optional[Span] = None
try:
route: str = get_request_route(request=request)
# get the request body
@@ -137,6 +136,7 @@ async def user_api_key_auth(
pass_through_endpoints: Optional[List[dict]] = general_settings.get(
"pass_through_endpoints", None
)
+        passed_in_key: Optional[str] = None
if isinstance(api_key, str):
passed_in_key = api_key
api_key = _get_bearer_token(api_key=api_key)
@@ -161,7 +161,6 @@ async def user_api_key_auth(
custom_litellm_key_header_name=custom_litellm_key_header_name,
)
-        parent_otel_span: Optional[Span] = None
if open_telemetry_logger is not None:
parent_otel_span = open_telemetry_logger.tracer.start_span(
name="Received Proxy Server Request",
@@ -189,7 +188,7 @@ async def user_api_key_auth(
######## Route Checks Before Reading DB / Cache for "token" ################
if (
-            route in LiteLLMRoutes.public_routes.value
+            route in LiteLLMRoutes.public_routes.value  # type: ignore
or route_in_additonal_public_routes(current_route=route)
):
# check if public endpoint
@@ -410,7 +409,7 @@ async def user_api_key_auth(
#### ELSE ####
## CHECK PASS-THROUGH ENDPOINTS ##
is_mapped_pass_through_route: bool = False
-        for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:
+        for mapped_route in LiteLLMRoutes.mapped_pass_through_routes.value:  # type: ignore
if route.startswith(mapped_route):
is_mapped_pass_through_route = True
if is_mapped_pass_through_route:
@@ -444,9 +443,9 @@ async def user_api_key_auth(
header_key = headers.get("litellm_user_api_key", "")
if (
isinstance(request.headers, dict)
-                and request.headers.get(key=header_key) is not None
+                and request.headers.get(key=header_key) is not None  # type: ignore
            ):
-                api_key = request.headers.get(key=header_key)
+                api_key = request.headers.get(key=header_key)  # type: ignore
if master_key is None:
if isinstance(api_key, str):
@@ -606,7 +605,7 @@ async def user_api_key_auth(
## IF it's not a master key
## Route should not be in master_key_only_routes
-                if route in LiteLLMRoutes.master_key_only_routes.value:
+                if route in LiteLLMRoutes.master_key_only_routes.value:  # type: ignore
raise Exception(
f"Tried to access route={route}, which is only for MASTER KEY"
)
@@ -669,8 +668,9 @@ async def user_api_key_auth(
"allowed_model_region"
)
+        user_obj: Optional[LiteLLM_UserTable] = None
+        valid_token_dict: dict = {}
        if valid_token is not None:
-            user_obj: Optional[LiteLLM_UserTable] = None
# Got Valid Token from Cache, DB
# Run checks for
# 1. If token can call model
@@ -686,6 +686,7 @@ async def user_api_key_auth(
# Check 1. If token can call model
_model_alias_map = {}
+            model: Optional[str] = None
if (
hasattr(valid_token, "team_model_aliases")
and valid_token.team_model_aliases is not None
@@ -698,6 +699,7 @@ async def user_api_key_auth(
_model_alias_map = {**valid_token.aliases}
litellm.model_alias_map = _model_alias_map
config = valid_token.config
if config != {}:
model_list = config.get("model_list", [])
llm_model_list = model_list
@@ -887,7 +889,10 @@ async def user_api_key_auth(
and max_budget_per_model.get(current_model, None) is not None
):
if (
model_spend[0]["model"] == current_model
"model" in model_spend[0]
and model_spend[0].get("model") == current_model
and "_sum" in model_spend[0]
and "spend" in model_spend[0]["_sum"]
and model_spend[0]["_sum"]["spend"]
>= max_budget_per_model[current_model]
):
@@ -927,16 +932,19 @@ async def user_api_key_auth(
)
# Check 8: Additional Common Checks across jwt + key auth
-                _team_obj = LiteLLM_TeamTable(
-                    team_id=valid_token.team_id,
-                    max_budget=valid_token.team_max_budget,
-                    spend=valid_token.team_spend,
-                    tpm_limit=valid_token.team_tpm_limit,
-                    rpm_limit=valid_token.team_rpm_limit,
-                    blocked=valid_token.team_blocked,
-                    models=valid_token.team_models,
-                    metadata=valid_token.team_metadata,
-                )
+                if valid_token.team_id is not None:
+                    _team_obj: Optional[LiteLLM_TeamTable] = LiteLLM_TeamTable(
+                        team_id=valid_token.team_id,
+                        max_budget=valid_token.team_max_budget,
+                        spend=valid_token.team_spend,
+                        tpm_limit=valid_token.team_tpm_limit,
+                        rpm_limit=valid_token.team_rpm_limit,
+                        blocked=valid_token.team_blocked,
+                        models=valid_token.team_models,
+                        metadata=valid_token.team_metadata,
+                    )
+                else:
+                    _team_obj = None
user_api_key_cache.set_cache(
key=valid_token.team_id, value=_team_obj
@@ -1045,7 +1053,7 @@ async def user_api_key_auth(
"/global/predict/spend/logs",
"/global/activity",
"/health/services",
-            ] + LiteLLMRoutes.info_routes.value
+            ] + LiteLLMRoutes.info_routes.value  # type: ignore
# check if the current route startswith any of the allowed routes
if (
route is not None
@@ -1106,7 +1114,7 @@ async def user_api_key_auth(
# Log this exception to OTEL
if open_telemetry_logger is not None:
-            await open_telemetry_logger.async_post_call_failure_hook(
+            await open_telemetry_logger.async_post_call_failure_hook(  # type: ignore
original_exception=e,
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
)
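
The max-budget hunk above hardens dict access on the spend aggregation row before comparing against the per-model budget. A standalone restatement of that pattern, with illustrative names and a Prisma-style group-by row assumed:

```python
def model_over_budget(model_spend: list, current_model: str,
                      max_budget_per_model: dict) -> bool:
    """Sketch of the guarded comparison; not LiteLLM's actual function."""
    if not model_spend or current_model not in max_budget_per_model:
        return False
    row = model_spend[0]  # e.g. {"model": "gpt-4", "_sum": {"spend": 12.5}}
    # guard every key before indexing, as the fix does, so a malformed
    # row cannot raise KeyError inside the auth path
    return (
        row.get("model") == current_model
        and "spend" in row.get("_sum", {})
        and row["_sum"]["spend"] >= max_budget_per_model[current_model]
    )
```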