Merge pull request #3870 from BerriAI/litellm_rename_end_user

[Feat] Rename `/end_user/new` -> `/customer/new` (maintain backwards compatibility)
This commit is contained in:
Ishaan Jaff 2024-05-27 19:42:14 -07:00 committed by GitHub
commit e1b46d4b6e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 310 additions and 78 deletions

View file

@ -13,7 +13,7 @@ Requirements:
You can set budgets at 3 levels: You can set budgets at 3 levels:
- For the proxy - For the proxy
- For an internal user - For an internal user
- For an end-user - For a customer (end-user)
- For a key - For a key
- For a key (model specific budgets) - For a key (model specific budgets)
@ -173,7 +173,7 @@ curl --location 'http://localhost:4000/chat/completions' \
``` ```
</TabItem> </TabItem>
<TabItem value="per-user-chat" label="For End User"> <TabItem value="per-user-chat" label="For Customers">
Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user** Use this to budget `user` passed to `/chat/completions`, **without needing to create a key for every user**
@ -452,7 +452,7 @@ curl --location 'http://0.0.0.0:4000/key/generate' \
``` ```
</TabItem> </TabItem>
<TabItem value="per-end-user" label="For End User"> <TabItem value="per-end-user" label="For customers">
:::info :::info
@ -477,12 +477,12 @@ curl --location 'http://0.0.0.0:4000/budget/new' \
``` ```
#### Step 2. Create `End-User` with Budget #### Step 2. Create `Customer` with Budget
We use `budget_id="free-tier"` from Step 1 when creating this new end user We use `budget_id="free-tier"` from Step 1 when creating this new customer
```shell ```shell
curl --location 'http://0.0.0.0:4000/end_user/new' \ curl --location 'http://0.0.0.0:4000/customer/new' \
--header 'Authorization: Bearer sk-1234' \ --header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \ --header 'Content-Type: application/json' \
--data '{ --data '{
@ -492,7 +492,7 @@ curl --location 'http://0.0.0.0:4000/end_user/new' \
``` ```
#### Step 3. Pass end user id in `/chat/completions` requests #### Step 3. Pass `user_id` in `/chat/completions` requests
Pass the `user_id` from Step 2 as `user="palantir"` Pass the `user_id` from Step 2 as `user="palantir"`

View file

@ -519,7 +519,11 @@ class UpdateUserRequest(GenerateRequestBase):
return values return values
class NewEndUserRequest(LiteLLMBase): class NewCustomerRequest(LiteLLMBase):
"""
Create a new customer, allocate a budget to them
"""
user_id: str user_id: str
alias: Optional[str] = None # human-friendly alias alias: Optional[str] = None # human-friendly alias
blocked: bool = False # allow/disallow requests for this end-user blocked: bool = False # allow/disallow requests for this end-user
@ -540,6 +544,33 @@ class NewEndUserRequest(LiteLLMBase):
return values return values
class UpdateCustomerRequest(LiteLLMBase):
    """
    Request body for updating an existing customer, e.g. changing its budget.
    """

    # id of the customer row to update
    user_id: str
    # human-friendly alias
    alias: Optional[str] = None
    # allow/disallow requests for this end-user
    blocked: bool = False
    max_budget: Optional[float] = None
    # give either a budget_id or max_budget
    budget_id: Optional[str] = None
    # require all user requests to use models in this specific region
    allowed_model_region: Optional[Literal["eu"]] = None
    # if no equivalent model in allowed region - default all requests to this model
    default_model: Optional[str] = None
class DeleteCustomerRequest(LiteLLMBase):
    """
    Request body for deleting one or more customers in a single call.
    """

    # ids of the customers to delete
    user_ids: List[str]
class Member(LiteLLMBase): class Member(LiteLLMBase):
role: Literal["admin", "user"] role: Literal["admin", "user"]
user_id: Optional[str] = None user_id: Optional[str] = None

View file

@ -7137,13 +7137,15 @@ async def global_predict_spend_logs(request: Request):
#### INTERNAL USER MANAGEMENT #### #### INTERNAL USER MANAGEMENT ####
@router.post( @router.post(
"/user/new", "/user/new",
tags=["user management"], tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
response_model=NewUserResponse, response_model=NewUserResponse,
) )
async def new_user(data: NewUserRequest): async def new_user(data: NewUserRequest):
""" """
Use this to create a new user with a budget. This creates a new user and generates a new api key for the new user. The new api key is returned. Use this to create a new INTERNAL user with a budget.
Internal Users can access LiteLLM Admin UI to make keys, request access to models.
This creates a new user and generates a new api key for the new user. The new api key is returned.
Returns user id, budget + new key. Returns user id, budget + new key.
@ -7214,7 +7216,9 @@ async def new_user(data: NewUserRequest):
@router.post( @router.post(
"/user/auth", tags=["user management"], dependencies=[Depends(user_api_key_auth)] "/user/auth",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
) )
async def user_auth(request: Request): async def user_auth(request: Request):
""" """
@ -7280,7 +7284,9 @@ async def user_auth(request: Request):
@router.get( @router.get(
"/user/info", tags=["user management"], dependencies=[Depends(user_api_key_auth)] "/user/info",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
) )
async def user_info( async def user_info(
user_id: Optional[str] = fastapi.Query( user_id: Optional[str] = fastapi.Query(
@ -7452,7 +7458,9 @@ async def user_info(
@router.post( @router.post(
"/user/update", tags=["user management"], dependencies=[Depends(user_api_key_auth)] "/user/update",
tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)],
) )
async def user_update(data: UpdateUserRequest): async def user_update(data: UpdateUserRequest):
""" """
@ -7546,7 +7554,7 @@ async def user_update(data: UpdateUserRequest):
@router.post( @router.post(
"/user/request_model", "/user/request_model",
tags=["user management"], tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def user_request_model(request: Request): async def user_request_model(request: Request):
@ -7599,7 +7607,7 @@ async def user_request_model(request: Request):
@router.get( @router.get(
"/user/get_requests", "/user/get_requests",
tags=["user management"], tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def user_get_requests(): async def user_get_requests():
@ -7641,7 +7649,7 @@ async def user_get_requests():
@router.get( @router.get(
"/user/get_users", "/user/get_users",
tags=["user management"], tags=["Internal User management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def get_users( async def get_users(
@ -7678,7 +7686,13 @@ async def get_users(
@router.post( @router.post(
"/end_user/block", "/end_user/block",
tags=["End User Management"], tags=["Customer Management"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
@router.post(
"/customer/block",
tags=["Customer Management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def block_user(data: BlockUsers): async def block_user(data: BlockUsers):
@ -7721,9 +7735,15 @@ async def block_user(data: BlockUsers):
@router.post( @router.post(
"/end_user/unblock", "/end_user/unblock",
tags=["End User Management"], tags=["Customer Management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
@router.post(
"/customer/unblock",
tags=["Customer Management"],
dependencies=[Depends(user_api_key_auth)],
include_in_schema=False,
)
async def unblock_user(data: BlockUsers): async def unblock_user(data: BlockUsers):
""" """
[BETA] Unblock calls with this user id [BETA] Unblock calls with this user id
@ -7768,35 +7788,36 @@ async def unblock_user(data: BlockUsers):
@router.post( @router.post(
"/end_user/new", "/end_user/new",
tags=["End User Management"], tags=["Customer Management"],
include_in_schema=False,
dependencies=[Depends(user_api_key_auth)],
)
@router.post(
"/customer/new",
tags=["Customer Management"],
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def new_end_user( async def new_end_user(
data: NewEndUserRequest, data: NewCustomerRequest,
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
): ):
""" """
[TODO] Needs to be implemented. Allow creating a new Customer
NOTE: This used to be called `/end_user/new`, we will still be maintaining compatibility for /end_user/XXX for these endpoints
Allow creating a new end-user
- Allow specifying allowed regions - Allow specifying allowed regions
- Allow specifying default model - Allow specifying default model
Example curl: Example curl:
``` ```
curl --location 'http://0.0.0.0:4000/end_user/new' \ curl --location 'http://0.0.0.0:4000/customer/new' \
--header 'Authorization: Bearer sk-1234' \ --header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \ --header 'Content-Type: application/json' \
--data '{ --data '{
"end_user_id" : "ishaan-jaff-3", <- specific customer "user_id" : "ishaan-jaff-3",
"allowed_region": "eu",
"allowed_region": "eu" <- set region for models "budget_id": "free_tier",
+
"default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model? "default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model?
}' }'
# return end-user object # return end-user object
@ -7819,56 +7840,88 @@ async def new_end_user(
status_code=500, status_code=500,
detail={"error": CommonProxyErrors.db_not_connected_error.value}, detail={"error": CommonProxyErrors.db_not_connected_error.value},
) )
try:
## VALIDATION ## ## VALIDATION ##
if data.default_model is not None: if data.default_model is not None:
if llm_router is None: if llm_router is None:
raise HTTPException( raise HTTPException(
status_code=422, detail={"error": CommonProxyErrors.no_llm_router.value} status_code=422,
) detail={"error": CommonProxyErrors.no_llm_router.value},
elif data.default_model not in llm_router.get_model_names(): )
raise HTTPException( elif data.default_model not in llm_router.get_model_names():
status_code=422, raise HTTPException(
detail={ status_code=422,
"error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format( detail={
data.default_model, set(llm_router.get_model_names()) "error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format(
) data.default_model, set(llm_router.get_model_names())
}, )
},
)
new_end_user_obj: Dict = {}
## CREATE BUDGET ## if set
if data.max_budget is not None:
budget_record = await prisma_client.db.litellm_budgettable.create(
data={
"max_budget": data.max_budget,
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
}
) )
new_end_user_obj: Dict = {} new_end_user_obj["budget_id"] = budget_record.budget_id
elif data.budget_id is not None:
new_end_user_obj["budget_id"] = data.budget_id
## CREATE BUDGET ## if set _user_data = data.dict(exclude_none=True)
if data.max_budget is not None:
budget_record = await prisma_client.db.litellm_budgettable.create( for k, v in _user_data.items():
data={ if k != "max_budget" and k != "budget_id":
"max_budget": data.max_budget, new_end_user_obj[k] = v
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, # type: ignore
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, ## WRITE TO DB ##
} end_user_record = await prisma_client.db.litellm_endusertable.create(
data=new_end_user_obj # type: ignore
) )
new_end_user_obj["budget_id"] = budget_record.budget_id return end_user_record
elif data.budget_id is not None: except Exception as e:
new_end_user_obj["budget_id"] = data.budget_id if "Unique constraint failed on the fields: (`user_id`)" in str(e):
raise ProxyException(
message=f"Customer already exists, passed user_id={data.user_id}. Please pass a new user_id.",
type="bad_request",
code=400,
param="user_id",
)
_user_data = data.dict(exclude_none=True) if isinstance(e, HTTPException):
raise ProxyException(
for k, v in _user_data.items(): message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
if k != "max_budget" and k != "budget_id": type="internal_error",
new_end_user_obj[k] = v param=getattr(e, "param", "None"),
code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
## WRITE TO DB ## )
end_user_record = await prisma_client.db.litellm_endusertable.create( elif isinstance(e, ProxyException):
data=new_end_user_obj # type: ignore raise e
) raise ProxyException(
message="Internal Server Error, " + str(e),
return end_user_record type="internal_error",
param=getattr(e, "param", "None"),
code=status.HTTP_500_INTERNAL_SERVER_ERROR,
)
@router.get(
"/customer/info",
tags=["Customer Management"],
dependencies=[Depends(user_api_key_auth)],
)
@router.get( @router.get(
"/end_user/info", "/end_user/info",
tags=["End User Management"], tags=["Customer Management"],
include_in_schema=False,
dependencies=[Depends(user_api_key_auth)], dependencies=[Depends(user_api_key_auth)],
) )
async def end_user_info( async def end_user_info(
@ -7892,26 +7945,174 @@ async def end_user_info(
@router.post(
    "/customer/update",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/end_user/update",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def update_end_user(
    data: UpdateCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Update an existing customer, e.g. move them onto a different budget.

    `user_id` selects the row to update. Only fields that are explicitly present
    in the request body (and not null) are written, so omitted fields keep their
    stored values. `/end_user/update` is registered as a hidden alias of
    `/customer/update` for backwards compatibility.

    Example curl
    ```
    curl --location 'http://0.0.0.0:4000/customer/update' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "user_id": "test-litellm-user-4",
            "budget_id": "paid_tier"
        }'
    ```
    See below for all params

    Returns:
        The updated customer record returned by the DB client.

    Raises:
        ProxyException: if the DB is not connected, `user_id` is empty, the
            customer does not exist, or the DB update fails.
    """
    try:
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        verbose_proxy_logger.debug("/customer/update: Received data = %s", data)

        if not data.user_id:
            raise ValueError(f"user_id is required, passed user_id = {data.user_id}")

        # Build the update from the fields the caller actually sent. Filtering on
        # "falsy-looking" values instead would also drop an explicit
        # `"blocked": false` (False == 0), making it impossible to unblock a
        # customer through this endpoint.
        # NOTE(review): relies on LiteLLMBase exposing pydantic's
        # `.dict(exclude_unset=True)` - `.dict(exclude_none=True)` is already
        # used on these request models elsewhere in this file.
        update_values = {
            field: value
            for field, value in data.dict(exclude_unset=True).items()
            if value is not None
        }
        update_values["user_id"] = data.user_id

        verbose_proxy_logger.debug("In update customer, user_id condition block.")
        response = await prisma_client.db.litellm_endusertable.update(
            where={"user_id": data.user_id}, data=update_values  # type: ignore
        )
        if response is None:
            raise ValueError(
                f"Failed updating customer data. User ID does not exist passed user_id={data.user_id}"
            )
        verbose_proxy_logger.debug(
            "received response from updating prisma client. response=%s", response
        )
        return response
    except Exception as e:
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type="internal_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            ) from e
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type="internal_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        ) from e
@router.post(
    "/customer/delete",
    tags=["Customer Management"],
    dependencies=[Depends(user_api_key_auth)],
)
@router.post(
    "/end_user/delete",
    tags=["Customer Management"],
    include_in_schema=False,
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_end_user(
    data: DeleteCustomerRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Delete one or more customers by id.

    `/end_user/delete` is registered as a hidden alias of `/customer/delete`
    for backwards compatibility.

    Example curl
    ```
    curl --location 'http://0.0.0.0:4000/customer/delete' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "user_ids" :["ishaan-jaff-5"]
        }'
    ```
    See below for all params

    Returns:
        A dict with `deleted_customers` (number of rows removed) and a
        human-readable `message`.

    Raises:
        ProxyException: if the DB is not connected, `user_ids` is empty, or the
            number of deleted rows does not match the number of distinct ids
            passed.
    """
    try:
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        verbose_proxy_logger.debug("/customer/delete: Received data = %s", data)

        if not isinstance(data.user_ids, list) or not data.user_ids:
            raise ValueError(f"user_id is required, passed user_id = {data.user_ids}")

        # De-duplicate (order preserved): each row can only be deleted once, so
        # comparing the deleted count against the raw list length would report
        # a failure whenever the caller repeats an id.
        unique_ids = list(dict.fromkeys(data.user_ids))

        response = await prisma_client.db.litellm_endusertable.delete_many(
            where={"user_id": {"in": unique_ids}}
        )
        if response is None:
            raise ValueError(
                f"Failed deleting customer data. User ID does not exist passed user_id={data.user_ids}"
            )
        if response != len(unique_ids):
            # NOTE: rows that did match have already been deleted at this point.
            raise ValueError(
                f"Failed deleting all customer data. User ID does not exist passed user_id={data.user_ids}. Deleted {response} customers, passed {len(unique_ids)} customers"
            )
        verbose_proxy_logger.debug(
            "received response from deleting customers. response=%s", response
        )
        return {
            "deleted_customers": response,
            "message": "Successfully deleted customers with ids: "
            + str(data.user_ids),
        }
    except Exception as e:
        traceback.print_exc()
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Internal Server Error({str(e)})"),
                type="internal_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            ) from e
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Internal Server Error, " + str(e),
            type="internal_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        ) from e