LiteLLM Minor Fixes and Improvements (09/13/2024) (#5689)

* refactor: clean up unused variables + fix pyright errors

* feat(health_check.py): Closes https://github.com/BerriAI/litellm/issues/5686
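
  The proxy-side piece of this change (see the `_run_background_health_check` hunk below) exits early when no model list is loaded and only sleeps when the interval is usable. A simplified sketch of the pattern, with `perform_health_check` stubbed in for the real helper in `litellm/proxy/health_check.py`:

  ```python
  import asyncio
  import copy
  from typing import List, Optional, Tuple


  async def perform_health_check(model_list: List[dict]) -> Tuple[list, list]:
      # Stub standing in for the real checker in litellm/proxy/health_check.py.
      return model_list, []


  async def run_background_health_check(
      llm_model_list: Optional[List[dict]],
      health_check_interval: Optional[float],
  ) -> None:
      # Deep-copy once so a config reload can't mutate the list mid-loop.
      _llm_model_list = copy.deepcopy(llm_model_list)
      if _llm_model_list is None:
          return  # nothing to monitor; exit instead of crashing on iteration
      while True:
          healthy, unhealthy = await perform_health_check(_llm_model_list)
          # Guard the sleep: an unset (None) interval would raise a TypeError.
          if isinstance(health_check_interval, float):
              await asyncio.sleep(health_check_interval)
  ```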

* fix(o1_reasoning.py): add stricter check for o1 reasoning model

* refactor(mistral/): make it easier to see mistral transformation logic

* fix(openai.py): fix openai o1 model param mapping

Fixes https://github.com/BerriAI/litellm/issues/5685
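
  For context: OpenAI's o1-series models reject the legacy `max_tokens` parameter and expect `max_completion_tokens` instead. A minimal sketch of the remap (illustrative only, not LiteLLM's actual implementation):

  ```python
  def map_o1_params(optional_params: dict) -> dict:
      """Remap legacy OpenAI params for o1 models (illustrative only)."""
      mapped = dict(optional_params)
      if "max_tokens" in mapped:
          # o1 models reject `max_tokens`; the equivalent knob is
          # `max_completion_tokens`.
          mapped["max_completion_tokens"] = mapped.pop("max_tokens")
      return mapped


  assert map_o1_params({"max_tokens": 256}) == {"max_completion_tokens": 256}
  ```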

* feat(main.py): infer finetuned gemini model from base model

Fixes https://github.com/BerriAI/litellm/issues/5678
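
  The vertex.md docs change below pairs with this: a fine-tuned Vertex AI model is exposed as a bare numeric endpoint ID, so the provider can't be read off the model name alone and is instead inferred from the configured base model. A hedged sketch of the rule (helper name is illustrative; `base_model` is the `model_info.base_model` value from the proxy config):

  ```python
  from typing import Optional


  def use_gemini_route(model: str, base_model: Optional[str]) -> bool:
      """Decide whether a Vertex AI model id should use the Gemini code path.

      Fine-tuned endpoints look like 'vertex_ai/1234567890123456789', so the
      decision falls back to the configured base model.
      """
      endpoint_id = model.split("/")[-1]
      return endpoint_id.isdigit() and base_model is not None and "gemini" in base_model


  assert use_gemini_route("vertex_ai/1234567890123456789", "vertex_ai/gemini-1.5-pro")
  assert not use_gemini_route("vertex_ai/gemini-1.5-pro", None)
  ```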

* docs(vertex.md): update docs to call finetuned gemini models

* feat(proxy_server.py): allow admin to hide proxy model aliases

Closes https://github.com/BerriAI/litellm/issues/5692

* docs(load_balancing.md): add docs on hiding alias models from proxy config
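
  The documented config shape looks roughly like this (a sketch following the new load_balancing.md section; treat the field names as illustrative):

  ```yaml
  model_list:
    - model_name: gpt-4o
      litellm_params:
        model: openai/gpt-4o

  router_settings:
    model_group_alias:
      "gpt-4-alias":        # alias exposed to callers
        model: "gpt-4o"     # actual model_name from `model_list`
        hidden: true        # exclude from /v1/models and /model/info
  ```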

* fix(base.py): don't raise NotImplementedError

* fix(user_api_key_auth.py): fix model max budget check
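
  A hedged sketch of the kind of per-model budget gate this touches, assuming a key-level `model_max_budget` map like `{"gpt-4": 100.0}` and tracked per-model spend (names illustrative, not LiteLLM's actual code):

  ```python
  def is_model_within_budget(
      model: str,
      model_spend: dict,
      model_max_budget: dict,
  ) -> bool:
      """True if this key still has budget left for `model`."""
      if model not in model_max_budget:
          return True  # no per-model cap configured
      return model_spend.get(model, 0.0) < model_max_budget[model]


  assert is_model_within_budget("gpt-4", {"gpt-4": 20.0}, {"gpt-4": 100.0})
  assert not is_model_within_budget("gpt-4", {"gpt-4": 120.0}, {"gpt-4": 100.0})
  ```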

* fix(router.py): fix elif

* fix(user_api_key_auth.py): don't set team_id to empty str

* fix(team_endpoints.py): fix response type

* test(test_completion.py): handle predibase error

* test(test_proxy_server.py): fix test

* fix(o1_transformation.py): fix max_completion_tokens mapping

* test(test_image_generation.py): mark flaky test

Authored by Krish Dholakia on 2024-09-14 10:02:55 -07:00; committed by GitHub.
parent 60c5d3ebec
commit 713d762411
35 changed files with 1020 additions and 539 deletions

litellm/proxy/proxy_server.py

@@ -125,16 +125,7 @@ from litellm.proxy._types import *
from litellm.proxy.analytics_endpoints.analytics_endpoints import (
router as analytics_router,
)
from litellm.proxy.auth.auth_checks import (
allowed_routes_check,
common_checks,
get_actual_routes,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
log_to_opentelemetry,
)
from litellm.proxy.auth.auth_checks import log_to_opentelemetry
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
from litellm.proxy.auth.handle_jwt import JWTHandler
from litellm.proxy.auth.litellm_license import LicenseCheck
@@ -260,6 +251,7 @@ from litellm.secret_managers.aws_secret_manager import (
load_aws_secret_manager,
)
from litellm.secret_managers.google_kms import load_google_kms
from litellm.secret_managers.main import get_secret
from litellm.types.llms.anthropic import (
AnthropicMessagesRequest,
AnthropicResponse,
@@ -484,7 +476,7 @@ general_settings: dict = {}
callback_settings: dict = {}
log_file = "api_log.json"
worker_config = None
master_key = None
master_key: Optional[str] = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
custom_db_client: Optional[DBClient] = None
@@ -874,7 +866,9 @@ def error_tracking():
def _set_spend_logs_payload(
payload: dict, prisma_client: PrismaClient, spend_logs_url: Optional[str] = None
payload: Union[dict, SpendLogsPayload],
prisma_client: PrismaClient,
spend_logs_url: Optional[str] = None,
):
if prisma_client is not None and spend_logs_url is not None:
if isinstance(payload["startTime"], datetime):
@@ -1341,6 +1335,9 @@ async def _run_background_health_check():
# make 1 deep copy of llm_model_list -> use this for all background health checks
_llm_model_list = copy.deepcopy(llm_model_list)
if _llm_model_list is None:
return
while True:
healthy_endpoints, unhealthy_endpoints = await perform_health_check(
model_list=_llm_model_list, details=health_check_details
@@ -1352,7 +1349,10 @@
health_check_results["healthy_count"] = len(healthy_endpoints)
health_check_results["unhealthy_count"] = len(unhealthy_endpoints)
await asyncio.sleep(health_check_interval)
if health_check_interval is not None and isinstance(
health_check_interval, float
):
await asyncio.sleep(health_check_interval)
class ProxyConfig:
@@ -1467,7 +1467,7 @@ class ProxyConfig:
break
for k, v in team_config.items():
if isinstance(v, str) and v.startswith("os.environ/"):
team_config[k] = litellm.get_secret(v)
team_config[k] = get_secret(v)
return team_config
def _init_cache(
@@ -1513,6 +1513,9 @@ class ProxyConfig:
config = get_file_contents_from_s3(
bucket_name=bucket_name, object_key=object_key
)
if config is None:
raise Exception("Unable to load config from given source.")
else:
# default to file
config = await self.get_config(config_file_path=config_file_path)
@@ -1528,9 +1531,7 @@
environment_variables = config.get("environment_variables", None)
if environment_variables:
for key, value in environment_variables.items():
os.environ[key] = str(
litellm.get_secret(secret_name=key, default_value=value)
)
os.environ[key] = str(get_secret(secret_name=key, default_value=value))
# check if litellm_license in general_settings
if "LITELLM_LICENSE" in environment_variables:
@@ -1566,8 +1567,8 @@
if (
cache_type == "redis" or cache_type == "redis-semantic"
) and len(cache_params.keys()) == 0:
cache_host = litellm.get_secret("REDIS_HOST", None)
cache_port = litellm.get_secret("REDIS_PORT", None)
cache_host = get_secret("REDIS_HOST", None)
cache_port = get_secret("REDIS_PORT", None)
cache_password = None
cache_params.update(
{
@@ -1577,8 +1578,8 @@
}
)
if litellm.get_secret("REDIS_PASSWORD", None) is not None:
cache_password = litellm.get_secret("REDIS_PASSWORD", None)
if get_secret("REDIS_PASSWORD", None) is not None:
cache_password = get_secret("REDIS_PASSWORD", None)
cache_params.update(
{
"password": cache_password,
@@ -1617,7 +1618,7 @@ class ProxyConfig:
# users can pass os.environ/ variables on the proxy - we should read them from the env
for key, value in cache_params.items():
if type(value) is str and value.startswith("os.environ/"):
cache_params[key] = litellm.get_secret(value)
cache_params[key] = get_secret(value)
## to pass a complete url, or set ssl=True, etc. just set it as `os.environ[REDIS_URL] = <your-redis-url>`, _redis.py checks for REDIS specific environment variables
self._init_cache(cache_params=cache_params)
@@ -1738,7 +1739,7 @@ class ProxyConfig:
if value is not None and isinstance(value, dict):
for _k, _v in value.items():
if isinstance(_v, str) and _v.startswith("os.environ/"):
value[_k] = litellm.get_secret(_v)
value[_k] = get_secret(_v)
litellm.upperbound_key_generate_params = (
LiteLLM_UpperboundKeyGenerateParams(**value)
)
@@ -1812,15 +1813,15 @@ class ProxyConfig:
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
verbose_proxy_logger.debug("GOING INTO LITELLM.GET_SECRET!")
database_url = litellm.get_secret(database_url)
database_url = get_secret(database_url)
verbose_proxy_logger.debug("RETRIEVED DB URL: %s", database_url)
### MASTER KEY ###
master_key = general_settings.get(
"master_key", litellm.get_secret("LITELLM_MASTER_KEY", None)
"master_key", get_secret("LITELLM_MASTER_KEY", None)
)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
master_key = get_secret(master_key) # type: ignore
if not isinstance(master_key, str):
raise Exception(
"Master key must be a string. Current type - {}".format(
@@ -1861,33 +1862,6 @@ class ProxyConfig:
await initialize_pass_through_endpoints(
pass_through_endpoints=general_settings["pass_through_endpoints"]
)
## dynamodb
database_type = general_settings.get("database_type", None)
if database_type is not None and (
database_type == "dynamo_db" or database_type == "dynamodb"
):
database_args = general_settings.get("database_args", None)
### LOAD FROM os.environ/ ###
for k, v in database_args.items():
if isinstance(v, str) and v.startswith("os.environ/"):
database_args[k] = litellm.get_secret(v)
if isinstance(k, str) and k == "aws_web_identity_token":
value = database_args[k]
verbose_proxy_logger.debug(
f"Loading AWS Web Identity Token from file: {value}"
)
if os.path.exists(value):
with open(value, "r") as file:
token_content = file.read()
database_args[k] = token_content
else:
verbose_proxy_logger.info(
f"DynamoDB Loading - {value} is not a valid file path"
)
verbose_proxy_logger.debug("database_args: %s", database_args)
custom_db_client = DBClient(
custom_db_args=database_args, custom_db_type=database_type
)
## ADMIN UI ACCESS ##
ui_access_mode = general_settings.get(
"ui_access_mode", "all"
@@ -1951,7 +1925,7 @@ class ProxyConfig:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
model["litellm_params"][k] = get_secret(v)
print(f"\033[32m {model.get('model_name', '')}\033[0m") # noqa
litellm_model_name = model["litellm_params"]["model"]
litellm_model_api_base = model["litellm_params"].get("api_base", None)
@@ -2005,7 +1979,10 @@
) # type:ignore
# Guardrail settings
guardrails_v2 = config.get("guardrails", None)
guardrails_v2: Optional[dict] = None
if config is not None:
guardrails_v2 = config.get("guardrails", None)
if guardrails_v2:
init_guardrails_v2(
all_guardrails=guardrails_v2, config_file_path=config_file_path
@@ -2074,7 +2051,7 @@ class ProxyConfig:
### LOAD FROM os.environ/ ###
for k, v in model["litellm_params"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
model["litellm_params"][k] = litellm.get_secret(v)
model["litellm_params"][k] = get_secret(v)
## check if they have model-id's ##
model_id = model.get("model_info", {}).get("id", None)
@@ -2234,7 +2211,8 @@ class ProxyConfig:
for k, v in environment_variables.items():
try:
decrypted_value = decrypt_value_helper(value=v)
os.environ[k] = decrypted_value
if decrypted_value is not None:
os.environ[k] = decrypted_value
except Exception as e:
verbose_proxy_logger.error(
"Error setting env variable: %s - %s", k, str(e)
@@ -2536,7 +2514,7 @@ async def async_assistants_data_generator(
)
# chunk = chunk.model_dump_json(exclude_none=True)
async for c in chunk:
async for c in chunk: # type: ignore
c = c.model_dump_json(exclude_none=True)
try:
yield f"data: {c}\n\n"
@@ -2745,17 +2723,22 @@ async def startup_event():
### LOAD MASTER KEY ###
# check if master key set in environment - load from there
master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
master_key = get_secret("LITELLM_MASTER_KEY", None) # type: ignore
# check if DATABASE_URL in environment - load from there
if prisma_client is None:
prisma_setup(database_url=litellm.get_secret("DATABASE_URL", None))
_db_url: Optional[str] = get_secret("DATABASE_URL", None) # type: ignore
prisma_setup(database_url=_db_url)
### LOAD CONFIG ###
worker_config = litellm.get_secret("WORKER_CONFIG")
worker_config: Optional[Union[str, dict]] = get_secret("WORKER_CONFIG") # type: ignore
verbose_proxy_logger.debug("worker_config: %s", worker_config)
# check if it's a valid file path
if os.path.isfile(worker_config):
if proxy_config.is_yaml(config_file_path=worker_config):
if worker_config is not None:
if (
isinstance(worker_config, str)
and os.path.isfile(worker_config)
and proxy_config.is_yaml(config_file_path=worker_config)
):
(
llm_router,
llm_model_list,
@@ -2763,21 +2746,23 @@ async def startup_event():
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
else:
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None and isinstance(
worker_config, str
):
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
elif isinstance(worker_config, dict):
await initialize(**worker_config)
elif os.environ.get("LITELLM_CONFIG_BUCKET_NAME") is not None:
(
llm_router,
llm_model_list,
general_settings,
) = await proxy_config.load_config(
router=llm_router, config_file_path=worker_config
)
else:
# if not, assume it's a json string
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
await initialize(**worker_config)
else:
# if not, assume it's a json string
worker_config = json.loads(worker_config)
if isinstance(worker_config, dict):
await initialize(**worker_config)
## CHECK PREMIUM USER
verbose_proxy_logger.debug(
@@ -2825,7 +2810,7 @@ async def startup_event():
if general_settings.get("litellm_jwtauth", None) is not None:
for k, v in general_settings["litellm_jwtauth"].items():
if isinstance(v, str) and v.startswith("os.environ/"):
general_settings["litellm_jwtauth"][k] = litellm.get_secret(v)
general_settings["litellm_jwtauth"][k] = get_secret(v)
litellm_jwtauth = LiteLLM_JWTAuth(**general_settings["litellm_jwtauth"])
else:
litellm_jwtauth = LiteLLM_JWTAuth()
@@ -2948,8 +2933,7 @@ async def startup_event():
### ADD NEW MODELS ###
store_model_in_db = (
litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
or store_model_in_db
get_secret("STORE_MODEL_IN_DB", store_model_in_db) or store_model_in_db
) # type: ignore
if store_model_in_db == True:
scheduler.add_job(
@@ -3498,7 +3482,7 @@ async def completion(
)
### CALL HOOKS ### - modify outgoing data
response = await proxy_logging_obj.post_call_success_hook(
data=data, user_api_key_dict=user_api_key_dict, response=response
data=data, user_api_key_dict=user_api_key_dict, response=response # type: ignore
)
fastapi_response.headers.update(
@@ -4000,7 +3984,7 @@ async def audio_speech(
request_data=data,
)
return StreamingResponse(
generate(response), media_type="audio/mpeg", headers=custom_headers
generate(response), media_type="audio/mpeg", headers=custom_headers # type: ignore
)
except Exception as e:
@@ -4288,6 +4272,7 @@ async def create_assistant(
API Reference docs - https://platform.openai.com/docs/api-reference/assistants/createAssistant
"""
global proxy_logging_obj
data = {} # ensure data always dict
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
@@ -7642,6 +7627,7 @@ async def model_group_info(
)
model_groups: List[ModelGroupInfo] = []
for model in all_models_str:
_model_group_info = llm_router.get_model_group_info(model_group=model)
@@ -8051,7 +8037,8 @@ async def google_login(request: Request):
with microsoft_sso:
return await microsoft_sso.get_login_redirect()
elif generic_client_id is not None:
from fastapi_sso.sso.generic import DiscoveryDocument, create_provider
from fastapi_sso.sso.base import DiscoveryDocument
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@@ -8616,6 +8603,8 @@ async def auth_callback(request: Request):
redirect_url += "sso/callback"
else:
redirect_url += "/sso/callback"
result = None
if google_client_id is not None:
from fastapi_sso.sso.google import GoogleSSO
@@ -8662,7 +8651,8 @@ async def auth_callback(request: Request):
result = await microsoft_sso.verify_and_process(request)
elif generic_client_id is not None:
# make generic sso provider
from fastapi_sso.sso.generic import DiscoveryDocument, OpenID, create_provider
from fastapi_sso.sso.base import DiscoveryDocument, OpenID
from fastapi_sso.sso.generic import create_provider
generic_client_secret = os.getenv("GENERIC_CLIENT_SECRET", None)
generic_scope = os.getenv("GENERIC_SCOPE", "openid email profile").split(" ")
@@ -8766,8 +8756,8 @@ async def auth_callback(request: Request):
verbose_proxy_logger.debug("generic result: %s", result)
# User is Authe'd in - generate key for the UI to access Proxy
user_email = getattr(result, "email", None)
user_id = getattr(result, "id", None)
user_email: Optional[str] = getattr(result, "email", None)
user_id: Optional[str] = getattr(result, "id", None) if result is not None else None
if user_email is not None and os.getenv("ALLOWED_EMAIL_DOMAINS") is not None:
email_domain = user_email.split("@")[1]
@@ -8783,12 +8773,12 @@ async def auth_callback(request: Request):
)
# generic client id
if generic_client_id is not None:
if generic_client_id is not None and result is not None:
user_id = getattr(result, "id", None)
user_email = getattr(result, "email", None)
user_role = getattr(result, generic_user_role_attribute_name, None)
user_role = getattr(result, generic_user_role_attribute_name, None) # type: ignore
if user_id is None:
if user_id is None and result is not None:
_first_name = getattr(result, "first_name", "") or ""
_last_name = getattr(result, "last_name", "") or ""
user_id = _first_name + _last_name
@@ -8811,54 +8801,45 @@ async def auth_callback(request: Request):
"spend": 0,
"team_id": "litellm-dashboard",
}
user_defined_values: SSOUserDefinedValues = {
"models": user_id_models,
"user_id": user_id,
"user_email": user_email,
"max_budget": max_internal_user_budget,
"user_role": None,
"budget_duration": internal_user_budget_duration,
}
user_defined_values: Optional[SSOUserDefinedValues] = None
if user_id is not None:
user_defined_values = SSOUserDefinedValues(
models=user_id_models,
user_id=user_id,
user_email=user_email,
max_budget=max_internal_user_budget,
user_role=None,
budget_duration=internal_user_budget_duration,
)
_user_id_from_sso = user_id
user_role = None
try:
user_role = None
if prisma_client is not None:
user_info = await prisma_client.get_data(user_id=user_id, table_name="user")
verbose_proxy_logger.debug(
f"user_info: {user_info}; litellm.default_user_params: {litellm.default_user_params}"
)
if user_info is not None:
user_defined_values = {
"models": getattr(user_info, "models", user_id_models),
"user_id": getattr(user_info, "user_id", user_id),
"user_email": getattr(user_info, "user_id", user_email),
"user_role": getattr(user_info, "user_role", None),
"max_budget": getattr(
user_info, "max_budget", max_internal_user_budget
),
"budget_duration": getattr(
user_info, "budget_duration", internal_user_budget_duration
),
}
user_role = getattr(user_info, "user_role", None)
if user_info is None:
## check if user-email in db ##
user_info = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": user_email}
)
## check if user-email in db ##
user_info = await prisma_client.db.litellm_usertable.find_first(
where={"user_email": user_email}
)
if user_info is not None:
user_defined_values = {
"models": getattr(user_info, "models", user_id_models),
"user_id": user_id,
"user_email": getattr(user_info, "user_id", user_email),
"user_role": getattr(user_info, "user_role", None),
"max_budget": getattr(
if user_info is not None and user_id is not None:
user_defined_values = SSOUserDefinedValues(
models=getattr(user_info, "models", user_id_models),
user_id=user_id,
user_email=getattr(user_info, "user_email", user_email),
user_role=getattr(user_info, "user_role", None),
max_budget=getattr(
user_info, "max_budget", max_internal_user_budget
),
"budget_duration": getattr(
budget_duration=getattr(
user_info, "budget_duration", internal_user_budget_duration
),
}
)
user_role = getattr(user_info, "user_role", None)
# update id
@@ -8886,6 +8867,11 @@ async def auth_callback(request: Request):
except Exception as e:
pass
if user_defined_values is None:
raise Exception(
"Unable to map user identity to known values. 'user_defined_values' is None. File an issue - https://github.com/BerriAI/litellm/issues"
)
is_internal_user = False
if (
user_defined_values["user_role"] is not None
@@ -8960,7 +8946,8 @@ async def auth_callback(request: Request):
master_key,
algorithm="HS256",
)
litellm_dashboard_ui += "?userID=" + user_id
if user_id is not None and isinstance(user_id, str):
litellm_dashboard_ui += "?userID=" + user_id
redirect_response = RedirectResponse(url=litellm_dashboard_ui, status_code=303)
redirect_response.set_cookie(key="token", value=jwt_token)
return redirect_response
@@ -9023,6 +9010,7 @@ async def new_invitation(
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
} # type: ignore
)
return response
except Exception as e:
if "Foreign key constraint failed on the field" in str(e):
raise HTTPException(
@@ -9031,7 +9019,7 @@ async def new_invitation(
"error": "User id does not exist in 'LiteLLM_UserTable'. Fix this by creating user via `/user/new`."
},
)
return response
raise HTTPException(status_code=500, detail={"error": str(e)})
@router.get(
@@ -9951,44 +9939,46 @@ async def get_routes():
"""
routes = []
for route in app.routes:
route_info = {
"path": getattr(route, "path", None),
"methods": getattr(route, "methods", None),
"name": getattr(route, "name", None),
"endpoint": (
getattr(route, "endpoint", None).__name__
if getattr(route, "endpoint", None)
else None
),
}
routes.append(route_info)
endpoint_route = getattr(route, "endpoint", None)
if endpoint_route is not None:
route_info = {
"path": getattr(route, "path", None),
"methods": getattr(route, "methods", None),
"name": getattr(route, "name", None),
"endpoint": (
endpoint_route.__name__
if getattr(route, "endpoint", None)
else None
),
}
routes.append(route_info)
return {"routes": routes}
#### TEST ENDPOINTS ####
@router.get(
"/token/generate",
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def token_generate():
"""
Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
"""
# Initialize AuthJWTSSO with your OpenID Provider configuration
from fastapi_sso import AuthJWTSSO
# @router.get(
# "/token/generate",
# dependencies=[Depends(user_api_key_auth)],
# include_in_schema=False,
# )
# async def token_generate():
# """
# Test endpoint. Admin-only access. Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc.
# """
# # Initialize AuthJWTSSO with your OpenID Provider configuration
# from fastapi_sso import AuthJWTSSO
auth_jwt_sso = AuthJWTSSO(
issuer=os.getenv("OPENID_BASE_URL"),
client_id=os.getenv("OPENID_CLIENT_ID"),
client_secret=os.getenv("OPENID_CLIENT_SECRET"),
scopes=["litellm_proxy_admin"],
)
# auth_jwt_sso = AuthJWTSSO(
# issuer=os.getenv("OPENID_BASE_URL"),
# client_id=os.getenv("OPENID_CLIENT_ID"),
# client_secret=os.getenv("OPENID_CLIENT_SECRET"),
# scopes=["litellm_proxy_admin"],
# )
token = auth_jwt_sso.create_access_token()
# token = auth_jwt_sso.create_access_token()
return {"token": token}
# return {"token": token}
@router.on_event("shutdown")
@@ -10013,7 +10003,8 @@ async def shutdown_event():
# flush langfuse logs on shutdow
from litellm.utils import langFuseLogger
langFuseLogger.Langfuse.flush()
if langFuseLogger is not None:
langFuseLogger.Langfuse.flush()
except:
# [DO NOT BLOCK shutdown events for this]
pass