mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Merge pull request #2208 from BerriAI/litellm_enforce_team_limits
Litellm enforce team limits
This commit is contained in:
commit
365e7ed5b9
8 changed files with 302 additions and 47 deletions
|
@ -458,8 +458,19 @@ class LiteLLM_VerificationToken(LiteLLMBase):
|
||||||
protected_namespaces = ()
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
|
"""
|
||||||
|
Combined view of litellm verification token + litellm team table (select values)
|
||||||
|
"""
|
||||||
|
|
||||||
|
team_spend: Optional[float] = None
|
||||||
|
team_tpm_limit: Optional[int] = None
|
||||||
|
team_rpm_limit: Optional[int] = None
|
||||||
|
team_max_budget: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(
|
class UserAPIKeyAuth(
|
||||||
LiteLLM_VerificationToken
|
LiteLLM_VerificationTokenView
|
||||||
): # the expected response object for user api key auth
|
): # the expected response object for user api key auth
|
||||||
"""
|
"""
|
||||||
Return the row in the db
|
Return the row in the db
|
||||||
|
|
|
@ -38,7 +38,6 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
current = cache.get_cache(
|
current = cache.get_cache(
|
||||||
key=request_count_api_key
|
key=request_count_api_key
|
||||||
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
) # {"current_requests": 1, "current_tpm": 1, "current_rpm": 10}
|
||||||
# print(f"current: {current}")
|
|
||||||
if current is None:
|
if current is None:
|
||||||
new_val = {
|
new_val = {
|
||||||
"current_requests": 1,
|
"current_requests": 1,
|
||||||
|
@ -73,8 +72,12 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
self.print_verbose(f"Inside Max Parallel Request Pre-Call Hook")
|
||||||
api_key = user_api_key_dict.api_key
|
api_key = user_api_key_dict.api_key
|
||||||
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
max_parallel_requests = user_api_key_dict.max_parallel_requests or sys.maxsize
|
||||||
tpm_limit = user_api_key_dict.tpm_limit or sys.maxsize
|
tpm_limit = getattr(user_api_key_dict, "tpm_limit", sys.maxsize)
|
||||||
rpm_limit = user_api_key_dict.rpm_limit or sys.maxsize
|
if tpm_limit is None:
|
||||||
|
tpm_limit = sys.maxsize
|
||||||
|
rpm_limit = getattr(user_api_key_dict, "rpm_limit", sys.maxsize)
|
||||||
|
if rpm_limit is None:
|
||||||
|
rpm_limit = sys.maxsize
|
||||||
|
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
return
|
return
|
||||||
|
@ -131,17 +134,46 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
|
_user_id_rate_limits = user_api_key_dict.user_id_rate_limits
|
||||||
|
|
||||||
# get user tpm/rpm limits
|
# get user tpm/rpm limits
|
||||||
if _user_id_rate_limits is None or _user_id_rate_limits == {}:
|
if _user_id_rate_limits is not None and isinstance(_user_id_rate_limits, dict):
|
||||||
return
|
user_tpm_limit = _user_id_rate_limits.get("tpm_limit", None)
|
||||||
user_tpm_limit = _user_id_rate_limits.get("tpm_limit")
|
user_rpm_limit = _user_id_rate_limits.get("rpm_limit", None)
|
||||||
user_rpm_limit = _user_id_rate_limits.get("rpm_limit")
|
if user_tpm_limit is None:
|
||||||
if user_tpm_limit is None:
|
user_tpm_limit = sys.maxsize
|
||||||
user_tpm_limit = sys.maxsize
|
if user_rpm_limit is None:
|
||||||
if user_rpm_limit is None:
|
user_rpm_limit = sys.maxsize
|
||||||
user_rpm_limit = sys.maxsize
|
|
||||||
|
# now do the same tpm/rpm checks
|
||||||
|
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
||||||
|
|
||||||
|
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||||
|
await self.check_key_in_limits(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=cache,
|
||||||
|
data=data,
|
||||||
|
call_type=call_type,
|
||||||
|
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
||||||
|
request_count_api_key=request_count_api_key,
|
||||||
|
tpm_limit=user_tpm_limit,
|
||||||
|
rpm_limit=user_rpm_limit,
|
||||||
|
)
|
||||||
|
|
||||||
|
# TEAM RATE LIMITS
|
||||||
|
## get team tpm/rpm limits
|
||||||
|
team_id = user_api_key_dict.team_id
|
||||||
|
team_tpm_limit = getattr(user_api_key_dict, "team_tpm_limit", sys.maxsize)
|
||||||
|
if team_tpm_limit is None:
|
||||||
|
team_tpm_limit = sys.maxsize
|
||||||
|
team_rpm_limit = getattr(user_api_key_dict, "team_rpm_limit", sys.maxsize)
|
||||||
|
if team_rpm_limit is None:
|
||||||
|
team_rpm_limit = sys.maxsize
|
||||||
|
|
||||||
|
if team_tpm_limit is None:
|
||||||
|
team_tpm_limit = sys.maxsize
|
||||||
|
if team_rpm_limit is None:
|
||||||
|
team_rpm_limit = sys.maxsize
|
||||||
|
|
||||||
# now do the same tpm/rpm checks
|
# now do the same tpm/rpm checks
|
||||||
request_count_api_key = f"{user_id}::{precise_minute}::request_count"
|
request_count_api_key = f"{team_id}::{precise_minute}::request_count"
|
||||||
|
|
||||||
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
# print(f"Checking if {request_count_api_key} is allowed to make request for minute {precise_minute}")
|
||||||
await self.check_key_in_limits(
|
await self.check_key_in_limits(
|
||||||
|
@ -151,8 +183,8 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
call_type=call_type,
|
call_type=call_type,
|
||||||
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
max_parallel_requests=sys.maxsize, # TODO: Support max parallel requests for a user
|
||||||
request_count_api_key=request_count_api_key,
|
request_count_api_key=request_count_api_key,
|
||||||
tpm_limit=user_tpm_limit,
|
tpm_limit=team_tpm_limit,
|
||||||
rpm_limit=user_rpm_limit,
|
rpm_limit=team_rpm_limit,
|
||||||
)
|
)
|
||||||
return
|
return
|
||||||
|
|
||||||
|
@ -163,6 +195,9 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
user_api_key_user_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
"user_api_key_user_id", None
|
"user_api_key_user_id", None
|
||||||
)
|
)
|
||||||
|
user_api_key_team_id = kwargs["litellm_params"]["metadata"].get(
|
||||||
|
"user_api_key_team_id", None
|
||||||
|
)
|
||||||
|
|
||||||
if user_api_key is None:
|
if user_api_key is None:
|
||||||
return
|
return
|
||||||
|
@ -212,7 +247,41 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
# ------------
|
# ------------
|
||||||
# Update usage - User
|
# Update usage - User
|
||||||
# ------------
|
# ------------
|
||||||
if user_api_key_user_id is None:
|
if user_api_key_user_id is not None:
|
||||||
|
total_tokens = 0
|
||||||
|
|
||||||
|
if isinstance(response_obj, ModelResponse):
|
||||||
|
total_tokens = response_obj.usage.total_tokens
|
||||||
|
|
||||||
|
request_count_api_key = (
|
||||||
|
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
||||||
|
)
|
||||||
|
|
||||||
|
current = self.user_api_key_cache.get_cache(
|
||||||
|
key=request_count_api_key
|
||||||
|
) or {
|
||||||
|
"current_requests": 1,
|
||||||
|
"current_tpm": total_tokens,
|
||||||
|
"current_rpm": 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
new_val = {
|
||||||
|
"current_requests": max(current["current_requests"] - 1, 0),
|
||||||
|
"current_tpm": current["current_tpm"] + total_tokens,
|
||||||
|
"current_rpm": current["current_rpm"] + 1,
|
||||||
|
}
|
||||||
|
|
||||||
|
self.print_verbose(
|
||||||
|
f"updated_value in success call: {new_val}, precise_minute: {precise_minute}"
|
||||||
|
)
|
||||||
|
self.user_api_key_cache.set_cache(
|
||||||
|
request_count_api_key, new_val, ttl=60
|
||||||
|
) # store in cache for 1 min.
|
||||||
|
|
||||||
|
# ------------
|
||||||
|
# Update usage - Team
|
||||||
|
# ------------
|
||||||
|
if user_api_key_team_id is None:
|
||||||
return
|
return
|
||||||
|
|
||||||
total_tokens = 0
|
total_tokens = 0
|
||||||
|
@ -221,7 +290,7 @@ class _PROXY_MaxParallelRequestsHandler(CustomLogger):
|
||||||
total_tokens = response_obj.usage.total_tokens
|
total_tokens = response_obj.usage.total_tokens
|
||||||
|
|
||||||
request_count_api_key = (
|
request_count_api_key = (
|
||||||
f"{user_api_key_user_id}::{precise_minute}::request_count"
|
f"{user_api_key_team_id}::{precise_minute}::request_count"
|
||||||
)
|
)
|
||||||
|
|
||||||
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
|
current = self.user_api_key_cache.get_cache(key=request_count_api_key) or {
|
||||||
|
|
|
@ -356,7 +356,7 @@ async def user_api_key_auth(
|
||||||
verbose_proxy_logger.debug(f"api key: {api_key}")
|
verbose_proxy_logger.debug(f"api key: {api_key}")
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
valid_token = await prisma_client.get_data(
|
valid_token = await prisma_client.get_data(
|
||||||
token=api_key,
|
token=api_key, table_name="combined_view"
|
||||||
)
|
)
|
||||||
|
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
|
@ -381,7 +381,8 @@ async def user_api_key_auth(
|
||||||
# 4. If token is expired
|
# 4. If token is expired
|
||||||
# 5. If token spend is under Budget for the token
|
# 5. If token spend is under Budget for the token
|
||||||
# 6. If token spend per model is under budget per model
|
# 6. If token spend per model is under budget per model
|
||||||
|
# 7. If token spend is under team budget
|
||||||
|
# 8. If team spend is under team budget
|
||||||
request_data = await _read_request_body(
|
request_data = await _read_request_body(
|
||||||
request=request
|
request=request
|
||||||
) # request data, used across all checks. Making this easily available
|
) # request data, used across all checks. Making this easily available
|
||||||
|
@ -610,6 +611,47 @@ async def user_api_key_auth(
|
||||||
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
|
f"ExceededModelBudget: Current spend for model: {current_model_spend}; Max Budget for Model: {current_model_budget}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Check 6. Token spend is under Team budget
|
||||||
|
if (
|
||||||
|
valid_token.spend is not None
|
||||||
|
and hasattr(valid_token, "team_max_budget")
|
||||||
|
and valid_token.team_max_budget is not None
|
||||||
|
):
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.budget_alerts(
|
||||||
|
user_max_budget=valid_token.team_max_budget,
|
||||||
|
user_current_spend=valid_token.spend,
|
||||||
|
type="token_budget",
|
||||||
|
user_info=valid_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_token.spend >= valid_token.team_max_budget:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededTokenBudget: Current spend for token: {valid_token.spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check 7. Team spend is under Team budget
|
||||||
|
if (
|
||||||
|
hasattr(valid_token, "team_spend")
|
||||||
|
and valid_token.team_spend is not None
|
||||||
|
and hasattr(valid_token, "team_max_budget")
|
||||||
|
and valid_token.team_max_budget is not None
|
||||||
|
):
|
||||||
|
asyncio.create_task(
|
||||||
|
proxy_logging_obj.budget_alerts(
|
||||||
|
user_max_budget=valid_token.team_max_budget,
|
||||||
|
user_current_spend=valid_token.team_spend,
|
||||||
|
type="token_budget",
|
||||||
|
user_info=valid_token,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if valid_token.team_spend >= valid_token.team_max_budget:
|
||||||
|
raise Exception(
|
||||||
|
f"ExceededTokenBudget: Current Team Spend: {valid_token.team_spend}; Max Budget for Team: {valid_token.team_max_budget}"
|
||||||
|
)
|
||||||
|
|
||||||
# Token passed all checks
|
# Token passed all checks
|
||||||
api_key = valid_token.token
|
api_key = valid_token.token
|
||||||
|
|
||||||
|
@ -1870,14 +1912,6 @@ async def generate_key_helper_fn(
|
||||||
saved_token["expires"], datetime
|
saved_token["expires"], datetime
|
||||||
):
|
):
|
||||||
saved_token["expires"] = saved_token["expires"].isoformat()
|
saved_token["expires"] = saved_token["expires"].isoformat()
|
||||||
if key_data["token"] is not None and isinstance(key_data["token"], str):
|
|
||||||
hashed_token = hash_token(key_data["token"])
|
|
||||||
saved_token["token"] = hashed_token
|
|
||||||
user_api_key_cache.set_cache(
|
|
||||||
key=hashed_token,
|
|
||||||
value=LiteLLM_VerificationToken(**saved_token), # type: ignore
|
|
||||||
ttl=600,
|
|
||||||
)
|
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
## CREATE USER (If necessary)
|
## CREATE USER (If necessary)
|
||||||
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
|
verbose_proxy_logger.debug(f"prisma_client: Creating User={user_data}")
|
||||||
|
@ -2263,6 +2297,11 @@ async def startup_event():
|
||||||
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
duration=None, models=[], aliases={}, config={}, spend=0, token=master_key
|
||||||
)
|
)
|
||||||
|
|
||||||
|
### CHECK IF VIEW EXISTS ###
|
||||||
|
if prisma_client is not None:
|
||||||
|
create_view_response = await prisma_client.check_view_exists()
|
||||||
|
print(f"create_view_response: {create_view_response}") # noqa
|
||||||
|
|
||||||
### START BUDGET SCHEDULER ###
|
### START BUDGET SCHEDULER ###
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
scheduler = AsyncIOScheduler()
|
scheduler = AsyncIOScheduler()
|
||||||
|
|
|
@ -5,6 +5,7 @@ from litellm.proxy._types import (
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
DynamoDBArgs,
|
DynamoDBArgs,
|
||||||
LiteLLM_VerificationToken,
|
LiteLLM_VerificationToken,
|
||||||
|
LiteLLM_VerificationTokenView,
|
||||||
LiteLLM_SpendLogs,
|
LiteLLM_SpendLogs,
|
||||||
)
|
)
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
|
@ -479,6 +480,46 @@ class PrismaClient:
|
||||||
db_data[k] = json.dumps(v)
|
db_data[k] = json.dumps(v)
|
||||||
return db_data
|
return db_data
|
||||||
|
|
||||||
|
@backoff.on_exception(
|
||||||
|
backoff.expo,
|
||||||
|
Exception, # base exception to catch for the backoff
|
||||||
|
max_tries=3, # maximum number of retries
|
||||||
|
max_time=10, # maximum total time to retry for
|
||||||
|
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||||
|
)
|
||||||
|
async def check_view_exists(self):
|
||||||
|
"""
|
||||||
|
Checks if the LiteLLM_VerificationTokenView exists in the user's db.
|
||||||
|
|
||||||
|
This is used for getting the token + team data in user_api_key_auth
|
||||||
|
|
||||||
|
If the view doesn't exist, one will be created.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Try to select one row from the view
|
||||||
|
await self.db.execute_raw(
|
||||||
|
"""SELECT 1 FROM "LiteLLM_VerificationTokenView" LIMIT 1"""
|
||||||
|
)
|
||||||
|
return "LiteLLM_VerificationTokenView Exists!"
|
||||||
|
except Exception as e:
|
||||||
|
# If an error occurs, the view does not exist, so create it
|
||||||
|
value = await self.health_check()
|
||||||
|
await self.db.execute_raw(
|
||||||
|
"""
|
||||||
|
CREATE VIEW "LiteLLM_VerificationTokenView" AS
|
||||||
|
SELECT
|
||||||
|
v.*,
|
||||||
|
t.spend AS team_spend,
|
||||||
|
t.max_budget AS team_max_budget,
|
||||||
|
t.tpm_limit AS team_tpm_limit,
|
||||||
|
t.rpm_limit AS team_rpm_limit
|
||||||
|
FROM "LiteLLM_VerificationToken" v
|
||||||
|
LEFT JOIN "LiteLLM_TeamTable" t ON v.team_id = t.team_id;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
return "LiteLLM_VerificationTokenView Created!"
|
||||||
|
|
||||||
@backoff.on_exception(
|
@backoff.on_exception(
|
||||||
backoff.expo,
|
backoff.expo,
|
||||||
Exception, # base exception to catch for the backoff
|
Exception, # base exception to catch for the backoff
|
||||||
|
@ -535,7 +576,15 @@ class PrismaClient:
|
||||||
team_id_list: Optional[list] = None,
|
team_id_list: Optional[list] = None,
|
||||||
key_val: Optional[dict] = None,
|
key_val: Optional[dict] = None,
|
||||||
table_name: Optional[
|
table_name: Optional[
|
||||||
Literal["user", "key", "config", "spend", "team", "user_notification"]
|
Literal[
|
||||||
|
"user",
|
||||||
|
"key",
|
||||||
|
"config",
|
||||||
|
"spend",
|
||||||
|
"team",
|
||||||
|
"user_notification",
|
||||||
|
"combined_view",
|
||||||
|
]
|
||||||
] = None,
|
] = None,
|
||||||
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
query_type: Literal["find_unique", "find_all"] = "find_unique",
|
||||||
expires: Optional[datetime] = None,
|
expires: Optional[datetime] = None,
|
||||||
|
@ -543,7 +592,9 @@ class PrismaClient:
|
||||||
):
|
):
|
||||||
try:
|
try:
|
||||||
response: Any = None
|
response: Any = None
|
||||||
if token is not None or (table_name is not None and table_name == "key"):
|
if (token is not None and table_name is None) or (
|
||||||
|
table_name is not None and table_name == "key"
|
||||||
|
):
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
if token is not None:
|
if token is not None:
|
||||||
if isinstance(token, str):
|
if isinstance(token, str):
|
||||||
|
@ -723,6 +774,38 @@ class PrismaClient:
|
||||||
elif query_type == "find_all":
|
elif query_type == "find_all":
|
||||||
response = await self.db.litellm_usernotifications.find_many() # type: ignore
|
response = await self.db.litellm_usernotifications.find_many() # type: ignore
|
||||||
return response
|
return response
|
||||||
|
elif table_name == "combined_view":
|
||||||
|
# check if plain text or hash
|
||||||
|
if token is not None:
|
||||||
|
if isinstance(token, str):
|
||||||
|
hashed_token = token
|
||||||
|
if token.startswith("sk-"):
|
||||||
|
hashed_token = self.hash_token(token=token)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||||
|
)
|
||||||
|
if query_type == "find_unique":
|
||||||
|
if token is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail={"error": f"No token passed in. Token={token}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
sql_query = f"""
|
||||||
|
SELECT *
|
||||||
|
FROM "LiteLLM_VerificationTokenView"
|
||||||
|
WHERE token = '{token}'
|
||||||
|
"""
|
||||||
|
|
||||||
|
response = await self.db.query_first(query=sql_query)
|
||||||
|
if response is not None:
|
||||||
|
response = LiteLLM_VerificationTokenView(**response)
|
||||||
|
# for prisma we need to cast the expires time to str
|
||||||
|
if response.expires is not None and isinstance(
|
||||||
|
response.expires, datetime
|
||||||
|
):
|
||||||
|
response.expires = response.expires.isoformat()
|
||||||
|
return response
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
print_verbose(f"LiteLLM Prisma Client Exception: {e}")
|
||||||
import traceback
|
import traceback
|
||||||
|
|
|
@ -483,9 +483,12 @@ def test_redis_cache_completion_stream():
|
||||||
max_tokens=40,
|
max_tokens=40,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
caching=True,
|
||||||
)
|
)
|
||||||
response_1_content = ""
|
response_1_content = ""
|
||||||
|
response_1_id = None
|
||||||
for chunk in response1:
|
for chunk in response1:
|
||||||
|
response_1_id = chunk.id
|
||||||
print(chunk)
|
print(chunk)
|
||||||
response_1_content += chunk.choices[0].delta.content or ""
|
response_1_content += chunk.choices[0].delta.content or ""
|
||||||
print(response_1_content)
|
print(response_1_content)
|
||||||
|
@ -497,16 +500,22 @@ def test_redis_cache_completion_stream():
|
||||||
max_tokens=40,
|
max_tokens=40,
|
||||||
temperature=0.2,
|
temperature=0.2,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
caching=True,
|
||||||
)
|
)
|
||||||
response_2_content = ""
|
response_2_content = ""
|
||||||
|
response_2_id = None
|
||||||
for chunk in response2:
|
for chunk in response2:
|
||||||
|
response_2_id = chunk.id
|
||||||
print(chunk)
|
print(chunk)
|
||||||
response_2_content += chunk.choices[0].delta.content or ""
|
response_2_content += chunk.choices[0].delta.content or ""
|
||||||
print("\nresponse 1", response_1_content)
|
print("\nresponse 1", response_1_content)
|
||||||
print("\nresponse 2", response_2_content)
|
print("\nresponse 2", response_2_content)
|
||||||
assert (
|
assert (
|
||||||
response_1_content == response_2_content
|
response_1_id == response_2_id
|
||||||
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
|
# assert (
|
||||||
|
# response_1_content == response_2_content
|
||||||
|
# ), f"Response 1 != Response 2. Same params, Response 1{response_1_content} != Response 2{response_2_content}"
|
||||||
litellm.success_callback = []
|
litellm.success_callback = []
|
||||||
litellm._async_success_callback = []
|
litellm._async_success_callback = []
|
||||||
litellm.cache = None
|
litellm.cache = None
|
||||||
|
|
|
@ -124,25 +124,12 @@ def test_generate_and_call_with_valid_key(prisma_client):
|
||||||
bearer_token = "Bearer " + generated_key
|
bearer_token = "Bearer " + generated_key
|
||||||
|
|
||||||
assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
|
assert generated_key not in user_api_key_cache.in_memory_cache.cache_dict
|
||||||
assert (
|
|
||||||
hash_token(generated_key)
|
|
||||||
in user_api_key_cache.in_memory_cache.cache_dict
|
|
||||||
)
|
|
||||||
|
|
||||||
cached_value = user_api_key_cache.in_memory_cache.cache_dict[
|
value_from_prisma = await prisma_client.get_data(
|
||||||
hash_token(generated_key)
|
|
||||||
]
|
|
||||||
|
|
||||||
print("cached value=", cached_value)
|
|
||||||
print("cached token", cached_value.token)
|
|
||||||
|
|
||||||
value_from_prisma = valid_token = await prisma_client.get_data(
|
|
||||||
token=generated_key,
|
token=generated_key,
|
||||||
)
|
)
|
||||||
print("token from prisma", value_from_prisma)
|
print("token from prisma", value_from_prisma)
|
||||||
|
|
||||||
assert value_from_prisma.token == cached_value.token
|
|
||||||
|
|
||||||
request = Request(scope={"type": "http"})
|
request = Request(scope={"type": "http"})
|
||||||
request._url = URL(url="/chat/completions")
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
@ -1241,7 +1228,7 @@ async def test_call_with_key_never_over_budget(prisma_client):
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key: {result}")
|
||||||
|
|
||||||
# update spend using track_cost callback, make 2nd request, it should fail
|
# update spend using track_cost callback, make 2nd request, it should fail
|
||||||
from litellm.proxy.proxy_server import (
|
from litellm.proxy.proxy_server import (
|
||||||
|
@ -1312,7 +1299,7 @@ async def test_call_with_key_over_budget_stream(prisma_client):
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
user_id = key.user_id
|
user_id = key.user_id
|
||||||
bearer_token = "Bearer " + generated_key
|
bearer_token = "Bearer " + generated_key
|
||||||
|
print(f"generated_key: {generated_key}")
|
||||||
request = Request(scope={"type": "http"})
|
request = Request(scope={"type": "http"})
|
||||||
request._url = URL(url="/chat/completions")
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
|
|
@ -99,6 +99,59 @@ async def test_pre_call_hook_rpm_limits():
|
||||||
assert e.status_code == 429
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pre_call_hook_team_rpm_limits():
|
||||||
|
"""
|
||||||
|
Test if error raised on hitting team rpm limits
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = True
|
||||||
|
_api_key = "sk-12345"
|
||||||
|
_team_id = "unique-team-id"
|
||||||
|
user_api_key_dict = UserAPIKeyAuth(
|
||||||
|
api_key=_api_key,
|
||||||
|
max_parallel_requests=1,
|
||||||
|
tpm_limit=9,
|
||||||
|
rpm_limit=10,
|
||||||
|
team_rpm_limit=1,
|
||||||
|
team_id=_team_id,
|
||||||
|
)
|
||||||
|
local_cache = DualCache()
|
||||||
|
parallel_request_handler = MaxParallelRequestsHandler()
|
||||||
|
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict, cache=local_cache, data={}, call_type=""
|
||||||
|
)
|
||||||
|
|
||||||
|
kwargs = {
|
||||||
|
"litellm_params": {
|
||||||
|
"metadata": {"user_api_key": _api_key, "user_api_key_team_id": _team_id}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await parallel_request_handler.async_log_success_event(
|
||||||
|
kwargs=kwargs,
|
||||||
|
response_obj="",
|
||||||
|
start_time="",
|
||||||
|
end_time="",
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"local_cache: {local_cache}")
|
||||||
|
|
||||||
|
## Expected cache val: {"current_requests": 0, "current_tpm": 0, "current_rpm": 1}
|
||||||
|
|
||||||
|
try:
|
||||||
|
await parallel_request_handler.async_pre_call_hook(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
cache=local_cache,
|
||||||
|
data={},
|
||||||
|
call_type="",
|
||||||
|
)
|
||||||
|
|
||||||
|
pytest.fail(f"Expected call to fail")
|
||||||
|
except Exception as e:
|
||||||
|
assert e.status_code == 429
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_pre_call_hook_tpm_limits():
|
async def test_pre_call_hook_tpm_limits():
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -1183,7 +1183,7 @@ class Logging:
|
||||||
verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
|
verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
|
||||||
## BUILD COMPLETE STREAMED RESPONSE
|
## BUILD COMPLETE STREAMED RESPONSE
|
||||||
complete_streaming_response = None
|
complete_streaming_response = None
|
||||||
if self.stream:
|
if self.stream and isinstance(result, ModelResponse):
|
||||||
if (
|
if (
|
||||||
result.choices[0].finish_reason is not None
|
result.choices[0].finish_reason is not None
|
||||||
): # if it's the last chunk
|
): # if it's the last chunk
|
||||||
|
@ -8682,6 +8682,8 @@ class CustomStreamWrapper:
|
||||||
|
|
||||||
completion_obj["content"] = response_obj["text"]
|
completion_obj["content"] = response_obj["text"]
|
||||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||||
|
if hasattr(chunk, "id"):
|
||||||
|
model_response.id = chunk.id
|
||||||
if response_obj["is_finished"]:
|
if response_obj["is_finished"]:
|
||||||
model_response.choices[0].finish_reason = response_obj[
|
model_response.choices[0].finish_reason = response_obj[
|
||||||
"finish_reason"
|
"finish_reason"
|
||||||
|
@ -8704,6 +8706,8 @@ class CustomStreamWrapper:
|
||||||
model_response.system_fingerprint = getattr(
|
model_response.system_fingerprint = getattr(
|
||||||
response_obj["original_chunk"], "system_fingerprint", None
|
response_obj["original_chunk"], "system_fingerprint", None
|
||||||
)
|
)
|
||||||
|
if hasattr(response_obj["original_chunk"], "id"):
|
||||||
|
model_response.id = response_obj["original_chunk"].id
|
||||||
if response_obj["logprobs"] is not None:
|
if response_obj["logprobs"] is not None:
|
||||||
model_response.choices[0].logprobs = response_obj["logprobs"]
|
model_response.choices[0].logprobs = response_obj["logprobs"]
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue