From 9e9f5d41d9715c600e55141b3356157bb5257ce0 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 10 May 2024 16:54:51 -0700
Subject: [PATCH] fix(proxy_server.py): check + get end-user obj even for
 master key calls

fixes issue where region-based routing wasn't working for end-users if
master key was given
---
 litellm/proxy/_super_secret_config.yaml | 50 ++++++--------------
 litellm/proxy/proxy_server.py           | 61 +++++++++++++++++++------
 2 files changed, 60 insertions(+), 51 deletions(-)

diff --git a/litellm/proxy/_super_secret_config.yaml b/litellm/proxy/_super_secret_config.yaml
index 4ea9846114..cd65708532 100644
--- a/litellm/proxy/_super_secret_config.yaml
+++ b/litellm/proxy/_super_secret_config.yaml
@@ -1,41 +1,19 @@
 model_list:
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key
-    model: openai/my-fake-model
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key-2
-    model: openai/my-fake-model-2
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key-3
-    model: openai/my-fake-model-3
-  model_name: fake-openai-endpoint
-- model_name: gpt-4
-  litellm_params:
-    model: gpt-3.5-turbo
-- litellm_params:
-    model: together_ai/codellama/CodeLlama-13b-Instruct-hf
-  model_name: CodeLlama-13b-Instruct
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: azure/gpt-35-turbo
+      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+      api_key: os.environ/AZURE_EUROPE_API_KEY
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: azure/chatgpt-v-2
+      api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
+      api_version: "2023-05-15"
+      api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+
 router_settings:
-  num_retries: 0
   enable_pre_call_checks: true
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
-router_settings:
-  routing_strategy: "latency-based-routing"
+
+general_settings:
+  master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
-litellm_settings:
-  success_callback: ["langfuse"]
-
-general_settings:
-  alerting: ["slack"]
-  alert_types: ["llm_exceptions", "daily_reports"]
-  alerting_args:
-    daily_report_frequency: 60 # every minute
-    report_check_interval: 5 # every 5s
\ No newline at end of file
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index af1be7f266..cef432f8a4 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -573,16 +573,45 @@ async def user_api_key_auth(
         ):
             return valid_token
 
+        ## Check END-USER OBJECT
+        request_data = await _read_request_body(request=request)
+        _end_user_object = None
+        end_user_params = {}
+        if "user" in request_data:
+            _end_user_object = await get_end_user_object(
+                end_user_id=request_data["user"],
+                prisma_client=prisma_client,
+                user_api_key_cache=user_api_key_cache,
+            )
+            if _end_user_object is not None:
+                end_user_params["allowed_model_region"] = (
+                    _end_user_object.allowed_model_region
+                )
+
         try:
-            is_master_key_valid = secrets.compare_digest(api_key, master_key)
+            is_master_key_valid = secrets.compare_digest(api_key, master_key)  # type: ignore
         except Exception as e:
             is_master_key_valid = False
 
+        ## VALIDATE MASTER KEY ##
+        try:
+            assert isinstance(master_key, str)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "Master key must be a valid string. Current type={}".format(
+                        type(master_key)
+                    )
+                },
+            )
+
         if is_master_key_valid:
             _user_api_key_obj = UserAPIKeyAuth(
                 api_key=master_key,
                 user_role="proxy_admin",
                 user_id=litellm_proxy_admin_name,
+                **end_user_params,
             )
             await user_api_key_cache.async_set_cache(
                 key=hash_token(master_key), value=_user_api_key_obj
@@ -637,10 +666,6 @@ async def user_api_key_auth(
             # 7. If token spend is under team budget
             # 8. If team spend is under team budget
 
-            request_data = await _read_request_body(
-                request=request
-            )  # request data, used across all checks. Making this easily available
-
             # Check 1. If token can call model
             _model_alias_map = {}
             if (
@@ -879,7 +904,7 @@ async def user_api_key_auth(
                             {"startTime": {"gt": twenty_eight_days_ago}},
                             {"model": current_model},
                         ]
-                    },
+                    },  # type: ignore
                 )
                 if (
                     len(model_spend) > 0
@@ -951,14 +976,6 @@ async def user_api_key_auth(
                     key=valid_token.team_id, value=_team_obj
                 )  # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py
 
-            _end_user_object = None
-            if "user" in request_data:
-                _end_user_object = await get_end_user_object(
-                    end_user_id=request_data["user"],
-                    prisma_client=prisma_client,
-                    user_api_key_cache=user_api_key_cache,
-                )
-
             global_proxy_spend = None
             if (
                 litellm.max_budget > 0 and prisma_client is not None
@@ -2412,6 +2429,12 @@ class ProxyConfig:
                 )
                 if master_key and master_key.startswith("os.environ/"):
                     master_key = litellm.get_secret(master_key)
+                    if not isinstance(master_key, str):
+                        raise Exception(
+                            "Master key must be a string. Current type - {}".format(
+                                type(master_key)
+                            )
+                        )
 
                 if master_key is not None and isinstance(master_key, str):
                     litellm_master_key_hash = hash_token(master_key)
@@ -3282,6 +3305,10 @@ async def startup_event():
     ### LOAD MASTER KEY ###
     # check if master key set in environment - load from there
     master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
+    if not isinstance(master_key, str):
+        raise Exception(
+            "Master key must be a string. Current type - {}".format(type(master_key))
+        )
     # check if DATABASE_URL in environment - load from there
     if prisma_client is None:
         prisma_setup(database_url=os.getenv("DATABASE_URL"))
@@ -3441,7 +3468,7 @@ async def startup_event():
         store_model_in_db = (
             litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
             or store_model_in_db
-        )
+        )  # type: ignore
         if store_model_in_db == True:
             scheduler.add_job(
                 proxy_config.add_deployment,
@@ -3733,6 +3760,7 @@ async def chat_completion(
                 "x-litellm-cache-key": cache_key,
                 "x-litellm-model-api-base": api_base,
                 "x-litellm-version": version,
+                "x-litellm-model-region": user_api_key_dict.allowed_model_region or "",
             }
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3749,6 +3777,9 @@ async def chat_completion(
         fastapi_response.headers["x-litellm-cache-key"] = cache_key
         fastapi_response.headers["x-litellm-model-api-base"] = api_base
         fastapi_response.headers["x-litellm-version"] = version
+        fastapi_response.headers["x-litellm-model-region"] = (
+            user_api_key_dict.allowed_model_region or ""
+        )
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(