forked from phoenix/litellm-mirror
Merge branch 'main' into litellm_add_semantic_cache
This commit is contained in:
commit
7cb69c72c8
25 changed files with 1499 additions and 342 deletions
|
@ -636,6 +636,36 @@ async def user_api_key_auth(
|
|||
raise Exception(
|
||||
f"Only master key can be used to generate, delete, update or get info for new keys/users. Value of allow_user_auth={allow_user_auth}"
|
||||
)
|
||||
|
||||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||
# sso/login, ui/login, /key functions and /user functions
|
||||
# this will never be allowed to call /chat/completions
|
||||
token_team = getattr(valid_token, "team_id", None)
|
||||
if token_team is not None:
|
||||
if token_team == "litellm-dashboard":
|
||||
# this token is only used for managing the ui
|
||||
allowed_routes = [
|
||||
"/sso",
|
||||
"/login",
|
||||
"/key",
|
||||
"/spend",
|
||||
"/user",
|
||||
]
|
||||
# check if the current route startswith any of the allowed routes
|
||||
if (
|
||||
route is not None
|
||||
and isinstance(route, str)
|
||||
and any(
|
||||
route.startswith(allowed_route)
|
||||
for allowed_route in allowed_routes
|
||||
)
|
||||
):
|
||||
# Do something if the current route starts with any of the allowed routes
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
|
||||
)
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid Key Passed to LiteLLM Proxy")
|
||||
|
@ -758,9 +788,10 @@ async def _PROXY_track_cost_callback(
|
|||
verbose_proxy_logger.info(
|
||||
f"response_cost {response_cost}, for user_id {user_id}"
|
||||
)
|
||||
if user_api_key and (
|
||||
prisma_client is not None or custom_db_client is not None
|
||||
):
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_api_key {user_api_key}, prisma_client: {prisma_client}, custom_db_client: {custom_db_client}"
|
||||
)
|
||||
if user_api_key is not None:
|
||||
await update_database(
|
||||
token=user_api_key,
|
||||
response_cost=response_cost,
|
||||
|
@ -770,6 +801,8 @@ async def _PROXY_track_cost_callback(
|
|||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
else:
|
||||
raise Exception("User API key missing from custom callback.")
|
||||
else:
|
||||
if kwargs["stream"] != True or (
|
||||
kwargs["stream"] == True
|
||||
|
@ -1361,6 +1394,26 @@ class ProxyConfig:
|
|||
proxy_config = ProxyConfig()
|
||||
|
||||
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
if not match:
|
||||
raise ValueError("Invalid duration format")
|
||||
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
|
||||
if unit == "s":
|
||||
return value
|
||||
elif unit == "m":
|
||||
return value * 60
|
||||
elif unit == "h":
|
||||
return value * 3600
|
||||
elif unit == "d":
|
||||
return value * 86400
|
||||
else:
|
||||
raise ValueError("Unsupported duration unit")
|
||||
|
||||
|
||||
async def generate_key_helper_fn(
|
||||
duration: Optional[str],
|
||||
models: list,
|
||||
|
@ -1395,25 +1448,6 @@ async def generate_key_helper_fn(
|
|||
if token is None:
|
||||
token = f"sk-{secrets.token_urlsafe(16)}"
|
||||
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
if not match:
|
||||
raise ValueError("Invalid duration format")
|
||||
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
|
||||
if unit == "s":
|
||||
return value
|
||||
elif unit == "m":
|
||||
return value * 60
|
||||
elif unit == "h":
|
||||
return value * 3600
|
||||
elif unit == "d":
|
||||
return value * 86400
|
||||
else:
|
||||
raise ValueError("Unsupported duration unit")
|
||||
|
||||
if duration is None: # allow tokens that never expire
|
||||
expires = None
|
||||
else:
|
||||
|
@ -2630,6 +2664,36 @@ async def generate_key_fn(
|
|||
elif key == "metadata" and value == {}:
|
||||
setattr(data, key, litellm.default_key_generate_params.get(key, {}))
|
||||
|
||||
# check if user set upper bounds for key/generate params on config.yaml; cap any requested value that exceeds its bound
|
||||
if litellm.upperbound_key_generate_params is not None:
|
||||
for elem in data:
|
||||
# if key in litellm.upperbound_key_generate_params, use the min of value and litellm.upperbound_key_generate_params[key]
|
||||
key, value = elem
|
||||
if value is not None and key in litellm.upperbound_key_generate_params:
|
||||
# if value is float/int
|
||||
if key in [
|
||||
"max_budget",
|
||||
"max_parallel_requests",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
]:
|
||||
if value > litellm.upperbound_key_generate_params[key]:
|
||||
# directly compare floats/ints
|
||||
setattr(
|
||||
data, key, litellm.upperbound_key_generate_params[key]
|
||||
)
|
||||
elif key == "budget_duration":
|
||||
# budget durations are written as <number><unit>, e.g. 30s, 30m, 30h, 30d (seconds, minutes, hours, days)
|
||||
# compare the duration in seconds and max duration in seconds
|
||||
upperbound_budget_duration = _duration_in_seconds(
|
||||
duration=litellm.upperbound_key_generate_params[key]
|
||||
)
|
||||
user_set_budget_duration = _duration_in_seconds(duration=value)
|
||||
if user_set_budget_duration > upperbound_budget_duration:
|
||||
setattr(
|
||||
data, key, litellm.upperbound_key_generate_params[key]
|
||||
)
|
||||
|
||||
data_json = data.json() # type: ignore
|
||||
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue