Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)

fix(proxy_server.py): check + get end-user obj even for master key calls

Fixes an issue where region-based routing wasn't working for end-users when the master key was used to authenticate: the end-user object (and with it allowed_model_region) was only looked up on the virtual-key path, so master-key requests were never region-filtered.

commit 9e9f5d41d9 (parent 30d2df8940)
2 changed files with 60 additions and 51 deletions
Changed file 1 of 2, the proxy config (YAML):

@@ -1,41 +1,19 @@
 model_list:
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key
-    model: openai/my-fake-model
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key-2
-    model: openai/my-fake-model-2
-  model_name: fake-openai-endpoint
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key-3
-    model: openai/my-fake-model-3
-  model_name: fake-openai-endpoint
-- model_name: gpt-4
-  litellm_params:
-    model: gpt-3.5-turbo
-- litellm_params:
-    model: together_ai/codellama/CodeLlama-13b-Instruct-hf
-  model_name: CodeLlama-13b-Instruct
+- model_name: gpt-3.5-turbo
+  litellm_params:
+    model: azure/gpt-35-turbo
+    api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+    api_key: os.environ/AZURE_EUROPE_API_KEY
+- model_name: gpt-3.5-turbo
+  litellm_params:
+    model: azure/chatgpt-v-2
+    api_base: https://openai-gpt-4-test-v-1.openai.azure.com/
+    api_version: "2023-05-15"
+    api_key: os.environ/AZURE_API_KEY # The `os.environ/` prefix tells litellm to read this from the env. See https://docs.litellm.ai/docs/simple_proxy#load-api-keys-from-vault
+
 router_settings:
-  num_retries: 0
   enable_pre_call_checks: true
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
 
-router_settings:
-  routing_strategy: "latency-based-routing"
-
-litellm_settings:
-  success_callback: ["langfuse"]
+general_settings:
+  master_key: sk-1234 # [OPTIONAL] Use to enforce auth on proxy. See - https://docs.litellm.ai/docs/proxy/virtual_keys
 
-general_settings:
-  alerting: ["slack"]
-  alert_types: ["llm_exceptions", "daily_reports"]
-  alerting_args:
-    daily_report_frequency: 60 # every minute
-    report_check_interval: 5 # every 5s
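The new config gives the router two gpt-3.5-turbo deployments in different Azure regions, and enable_pre_call_checks: true is what lets litellm's router filter deployments against an end-user's allowed_model_region before dispatching. A minimal sketch of exercising this through the proxy; the base URL and the end-user id "paris-user-1" (assumed to have been created with an EU allowed_model_region) are illustrative, not part of the commit:

# Sketch: call the proxy as a region-pinned end-user. Assumes the proxy
# from the config above runs on localhost:4000 and that "paris-user-1"
# (hypothetical) was created with allowed_model_region="eu".
import openai

client = openai.OpenAI(
    base_url="http://0.0.0.0:4000",  # the litellm proxy
    api_key="sk-1234",               # master key from general_settings
)

response = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello!"}],
    user="paris-user-1",  # arrives on the proxy as request_data["user"]
)
print(response.choices[0].message.content)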
Changed file 2 of 2, proxy_server.py:

@@ -573,16 +573,45 @@ async def user_api_key_auth(
         ):
             return valid_token
 
+        ## Check END-USER OBJECT
+        request_data = await _read_request_body(request=request)
+        _end_user_object = None
+        end_user_params = {}
+        if "user" in request_data:
+            _end_user_object = await get_end_user_object(
+                end_user_id=request_data["user"],
+                prisma_client=prisma_client,
+                user_api_key_cache=user_api_key_cache,
+            )
+            if _end_user_object is not None:
+                end_user_params["allowed_model_region"] = (
+                    _end_user_object.allowed_model_region
+                )
+
         try:
-            is_master_key_valid = secrets.compare_digest(api_key, master_key)
+            is_master_key_valid = secrets.compare_digest(api_key, master_key)  # type: ignore
         except Exception as e:
             is_master_key_valid = False
 
+        ## VALIDATE MASTER KEY ##
+        try:
+            assert isinstance(master_key, str)
+        except Exception as e:
+            raise HTTPException(
+                status_code=500,
+                detail={
+                    "Master key must be a valid string. Current type={}".format(
+                        type(master_key)
+                    )
+                },
+            )
+
         if is_master_key_valid:
             _user_api_key_obj = UserAPIKeyAuth(
                 api_key=master_key,
                 user_role="proxy_admin",
                 user_id=litellm_proxy_admin_name,
+                **end_user_params,
             )
             await user_api_key_cache.async_set_cache(
                 key=hash_token(master_key), value=_user_api_key_obj
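This hunk is the heart of the fix. Previously the end-user lookup ran only on the virtual-key path (see the block deleted in the -951 hunk below), so a request authenticated with the master key built its UserAPIKeyAuth without allowed_model_region and region-based routing never engaged. The lookup is now hoisted above the master-key branch and its result is splatted into UserAPIKeyAuth via **end_user_params. For orientation, here is a hypothetical sketch of the cache-then-DB pattern that a helper like get_end_user_object plausibly uses; the cache-key format and Prisma table name are assumptions, not taken from this commit:

# Hypothetical sketch, NOT litellm's implementation: check the in-memory
# cache first, fall back to the DB, then write the result back.
async def get_end_user_object_sketch(end_user_id, prisma_client, user_api_key_cache):
    if prisma_client is None:
        return None
    cache_key = "end_user_id:{}".format(end_user_id)  # assumed key format
    cached = await user_api_key_cache.async_get_cache(key=cache_key)
    if cached is not None:
        return cached  # cache hit: no DB round-trip
    record = await prisma_client.db.litellm_endusertable.find_unique(  # assumed table name
        where={"user_id": end_user_id}
    )
    if record is None:
        return None
    await user_api_key_cache.async_set_cache(key=cache_key, value=record)
    return record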
@@ -637,10 +666,6 @@ async def user_api_key_auth(
         # 7. If token spend is under team budget
         # 8. If team spend is under team budget
 
-        request_data = await _read_request_body(
-            request=request
-        )  # request data, used across all checks. Making this easily available
-
         # Check 1. If token can call model
         _model_alias_map = {}
         if (
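Since request_data is now read once near the top of user_api_key_auth (the +573 hunk above), this later duplicate read is dropped. A minimal sketch of what a tolerant one-time body reader like _read_request_body plausibly does (assumed behavior, not the actual helper):

# Assumed behavior, not litellm's actual helper: parse the JSON body if
# present, return {} otherwise. Starlette caches request.body() internally,
# so calling this more than once would not re-consume the stream.
import json

from fastapi import Request


async def read_request_body_sketch(request: Request) -> dict:
    try:
        body = await request.body()
        return json.loads(body) if body else {}
    except json.JSONDecodeError:
        return {}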
@@ -879,7 +904,7 @@ async def user_api_key_auth(
                                 {"startTime": {"gt": twenty_eight_days_ago}},
                                 {"model": current_model},
                             ]
-                        },
+                        },  # type: ignore
                     )
                     if (
                         len(model_spend) > 0
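The trailing # type: ignore here, like the one added to store_model_in_db in the -3441 hunk below, only quiets the type checker on an expression it cannot narrow; runtime behavior is unchanged.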
@@ -951,14 +976,6 @@ async def user_api_key_auth(
                     key=valid_token.team_id, value=_team_obj
                 )  # save team table in cache - used for tpm/rpm limiting - tpm_rpm_limiter.py
 
-                _end_user_object = None
-                if "user" in request_data:
-                    _end_user_object = await get_end_user_object(
-                        end_user_id=request_data["user"],
-                        prisma_client=prisma_client,
-                        user_api_key_cache=user_api_key_cache,
-                    )
-
                 global_proxy_spend = None
                 if (
                     litellm.max_budget > 0 and prisma_client is not None
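This deletion is the other half of the move: the virtual-key path no longer performs its own end-user lookup, because the hoisted block in the +573 hunk already ran for every request, master key or not. Behavior for virtual keys is unchanged; master-key calls now simply share the same _end_user_object.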
@@ -2412,6 +2429,12 @@ class ProxyConfig:
             )
             if master_key and master_key.startswith("os.environ/"):
                 master_key = litellm.get_secret(master_key)
+                if not isinstance(master_key, str):
+                    raise Exception(
+                        "Master key must be a string. Current type - {}".format(
+                            type(master_key)
+                        )
+                    )
 
             if master_key is not None and isinstance(master_key, str):
                 litellm_master_key_hash = hash_token(master_key)
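ProxyConfig now fails fast when an os.environ/-prefixed master key does not resolve to a string (typically because the variable is unset). Without the guard, a None master key surfaces later and less legibly, e.g. secrets.compare_digest(api_key, None) raises TypeError at request time. An illustrative repro, substituting os.environ.get for litellm.get_secret (which can also consult a secrets manager):

# Illustrative only: the failure mode the new isinstance guard catches.
import os

master_key = "os.environ/LITELLM_MASTER_KEY"
if master_key.startswith("os.environ/"):
    # returns None when the variable is unset
    master_key = os.environ.get(master_key.replace("os.environ/", ""))

if not isinstance(master_key, str):
    raise Exception(
        "Master key must be a string. Current type - {}".format(type(master_key))
    )

The next hunk adds the same guard to startup_event, covering master keys loaded directly from LITELLM_MASTER_KEY.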
@@ -3282,6 +3305,10 @@ async def startup_event():
     ### LOAD MASTER KEY ###
     # check if master key set in environment - load from there
     master_key = litellm.get_secret("LITELLM_MASTER_KEY", None)
+    if not isinstance(master_key, str):
+        raise Exception(
+            "Master key must be a string. Current type - {}".format(type(master_key))
+        )
     # check if DATABASE_URL in environment - load from there
     if prisma_client is None:
         prisma_setup(database_url=os.getenv("DATABASE_URL"))
@@ -3441,7 +3468,7 @@ async def startup_event():
     store_model_in_db = (
         litellm.get_secret("STORE_MODEL_IN_DB", store_model_in_db)
         or store_model_in_db
-    )
+    )  # type: ignore
     if store_model_in_db == True:
         scheduler.add_job(
             proxy_config.add_deployment,
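The final two hunks surface the resolved region to callers: chat_completion now emits an x-litellm-model-region response header on the streaming path (the headers dict in the next hunk) and on the regular path (fastapi_response.headers in the last hunk), falling back to an empty string when the key carries no region. A client-side example follows the last hunk.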
@@ -3733,6 +3760,7 @@ async def chat_completion(
                 "x-litellm-cache-key": cache_key,
                 "x-litellm-model-api-base": api_base,
                 "x-litellm-version": version,
+                "x-litellm-model-region": user_api_key_dict.allowed_model_region or "",
             }
             selected_data_generator = select_data_generator(
                 response=response,
@@ -3749,6 +3777,9 @@ async def chat_completion(
         fastapi_response.headers["x-litellm-cache-key"] = cache_key
         fastapi_response.headers["x-litellm-model-api-base"] = api_base
         fastapi_response.headers["x-litellm-version"] = version
+        fastapi_response.headers["x-litellm-model-region"] = (
+            user_api_key_dict.allowed_model_region or ""
+        )
 
         ### CALL HOOKS ### - modify outgoing data
         response = await proxy_logging_obj.post_call_success_hook(
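A quick way to observe the new header end to end; the endpoint path and ids are illustrative, and the proxy from the config above is assumed to be running locally:

# Sketch: read the new response header from a client. "paris-user-1" is
# the same hypothetical region-pinned end-user as in the earlier example.
import requests

resp = requests.post(
    "http://0.0.0.0:4000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},  # master key: the case this commit fixes
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "Hello!"}],
        "user": "paris-user-1",
    },
)
print(resp.headers.get("x-litellm-model-region"))  # e.g. "eu"; "" when no region is set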