forked from phoenix/litellm-mirror
test(test_keys.py): reset proxy spend
This commit is contained in:
parent
34c4532e7e
commit
8e1157fc92
4 changed files with 77 additions and 16 deletions
|
@ -122,11 +122,12 @@ class ModelParams(LiteLLMBase):
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
class GenerateKeyRequest(LiteLLMBase):
|
class GenerateRequestBase(LiteLLMBase):
|
||||||
duration: Optional[str] = "1h"
|
"""
|
||||||
|
Overlapping schema between key and user generate/update requests
|
||||||
|
"""
|
||||||
|
|
||||||
models: Optional[list] = []
|
models: Optional[list] = []
|
||||||
aliases: Optional[dict] = {}
|
|
||||||
config: Optional[dict] = {}
|
|
||||||
spend: Optional[float] = 0
|
spend: Optional[float] = 0
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
user_id: Optional[str] = None
|
user_id: Optional[str] = None
|
||||||
|
@ -138,21 +139,18 @@ class GenerateKeyRequest(LiteLLMBase):
|
||||||
budget_duration: Optional[str] = None
|
budget_duration: Optional[str] = None
|
||||||
|
|
||||||
|
|
||||||
class UpdateKeyRequest(LiteLLMBase):
|
class GenerateKeyRequest(GenerateRequestBase):
|
||||||
|
duration: Optional[str] = "1h"
|
||||||
|
aliases: Optional[dict] = {}
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateKeyRequest(GenerateKeyRequest):
    # Schema for a key update request: `key` is required, every other field
    # declared here is optional.
    #
    # Note: the defaults of all Params here MUST BE NONE
    # else they will get overwritten
    key: str
    duration: Optional[str] = None
    # The parent classes default `models` to [] and `aliases` to {}. Those are
    # not None, so they must be overridden here to honour the rule above.
    models: Optional[list] = None
    aliases: Optional[dict] = None
    spend: Optional[float] = None
    metadata: Optional[dict] = None
    # NOTE(review): any other inherited field with a non-None default needs the
    # same override - confirm against the full GenerateRequestBase definition.
|
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
|
class UserAPIKeyAuth(LiteLLMBase): # the expected response object for user api key auth
|
||||||
|
@ -192,6 +190,14 @@ class NewUserResponse(GenerateKeyResponse):
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateUserRequest(GenerateRequestBase):
    # Schema for a user update request: `user_id` is required, every other
    # field declared here is optional.
    #
    # Note: the defaults of all Params here MUST BE NONE
    # else they will get overwritten
    user_id: str
    # GenerateRequestBase defaults `models` to [], which is not None; override
    # it so an update that omits `models` does not carry an empty list.
    models: Optional[list] = None
    spend: Optional[float] = None
    metadata: Optional[dict] = None
    # NOTE(review): any other inherited field with a non-None default needs the
    # same override - confirm against the full GenerateRequestBase definition.
|
||||||
|
|
||||||
|
|
||||||
class KeyManagementSystem(enum.Enum):
|
class KeyManagementSystem(enum.Enum):
|
||||||
GOOGLE_KMS = "google_kms"
|
GOOGLE_KMS = "google_kms"
|
||||||
AZURE_KEY_VAULT = "azure_key_vault"
|
AZURE_KEY_VAULT = "azure_key_vault"
|
||||||
|
|
|
@ -2799,11 +2799,42 @@ async def user_info(
|
||||||
@router.post(
    "/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)]
)
async def user_update(data: UpdateUserRequest):
    """
    Update a user's row in the DB (e.g. to change or reset the user's spend).

    Parameters:
    - data: UpdateUserRequest - `user_id` selects the user; only fields whose
      value is not None are passed on to `prisma_client.update_data`.

    Returns:
    - dict with `user_id` plus every non-None value that was sent to the DB.

    Raises:
    - ProxyException: on any failure (including no DB connection); reported
      with type "auth_error" and, unless the original error carries its own
      status code, HTTP 400.
    """
    global prisma_client
    try:
        # NOTE(review): `.items()` below requires a dict, so this relies on the
        # model's `json()` returning a dict rather than a JSON string - confirm
        # against LiteLLMBase.
        data_json: dict = data.json()
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        # None means "not provided": drop those keys so they are not written.
        non_default_values = {k: v for k, v in data_json.items() if v is not None}
        await prisma_client.update_data(
            user_id=data_json["user_id"],
            data=non_default_values,
            update_key_values=non_default_values,
        )
        return {"user_id": data_json["user_id"], **non_default_values}
    except Exception as e:
        traceback.print_exc()
        if isinstance(e, HTTPException):
            # Preserve the status code / detail of HTTP errors raised upstream.
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        # Anything else is wrapped and reported as a 400.
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||||
|
|
||||||
|
|
||||||
#### MODEL MANAGEMENT ####
|
#### MODEL MANAGEMENT ####
|
||||||
|
|
|
@ -655,13 +655,14 @@ class PrismaClient:
|
||||||
user_id is not None
|
user_id is not None
|
||||||
or (table_name is not None and table_name == "user")
|
or (table_name is not None and table_name == "user")
|
||||||
and query_type == "update"
|
and query_type == "update"
|
||||||
and update_key_values is not None
|
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
If data['spend'] + data['user'], update the user table with spend info as well
|
If data['spend'] + data['user'], update the user table with spend info as well
|
||||||
"""
|
"""
|
||||||
if user_id is None:
|
if user_id is None:
|
||||||
user_id = db_data["user_id"]
|
user_id = db_data["user_id"]
|
||||||
|
if update_key_values is None:
|
||||||
|
update_key_values = db_data
|
||||||
update_user_row = await self.db.litellm_usertable.upsert(
|
update_user_row = await self.db.litellm_usertable.upsert(
|
||||||
where={"user_id": user_id}, # type: ignore
|
where={"user_id": user_id}, # type: ignore
|
||||||
data={
|
data={
|
||||||
|
|
|
@ -67,6 +67,28 @@ async def update_key(session, get_key):
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def update_proxy_budget(session):
    """
    Reset the spend of the `litellm-proxy-budget` user to 0 via POST /user/update.

    Prints the raw response body. Raises an Exception if the response status
    is not 200; otherwise returns the decoded JSON response.
    """
    url = "http://0.0.0.0:4000/user/update"
    headers = {
        "Authorization": "Bearer sk-1234",
        "Content-Type": "application/json",
    }
    data = {"user_id": "litellm-proxy-budget", "spend": 0}

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()
|
||||||
|
|
||||||
|
|
||||||
async def chat_completion(session, key, model="gpt-4"):
|
async def chat_completion(session, key, model="gpt-4"):
|
||||||
url = "http://0.0.0.0:4000/chat/completions"
|
url = "http://0.0.0.0:4000/chat/completions"
|
||||||
headers = {
|
headers = {
|
||||||
|
@ -135,6 +157,7 @@ async def test_key_update():
|
||||||
session=session,
|
session=session,
|
||||||
get_key=key,
|
get_key=key,
|
||||||
)
|
)
|
||||||
|
await update_proxy_budget(session=session) # resets proxy spend
|
||||||
await chat_completion(session=session, key=key)
|
await chat_completion(session=session, key=key)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue