forked from phoenix/litellm-mirror
Merge pull request #3536 from BerriAI/litellm_region_based_routing
feat(proxy_server.py): add CRUD endpoints for 'end_user' management
Commit 66a1b581e5
14 changed files with 606 additions and 44 deletions
.gitignore (vendored, 4 additions)

@@ -1,5 +1,6 @@
 .venv
 .env
+litellm/proxy/myenv/*
 litellm_uuid.txt
 __pycache__/
 *.pyc
@@ -52,3 +53,6 @@ litellm/proxy/_new_secret_config.yaml
 litellm/proxy/_new_secret_config.yaml
 litellm/proxy/_super_secret_config.yaml
 litellm/proxy/_super_secret_config.yaml
+litellm/proxy/myenv/bin/activate
+litellm/proxy/myenv/bin/Activate.ps1
+myenv/*
@@ -1,5 +1,5 @@
-from typing import Optional, Union, Any
-import types, requests  # type: ignore
+from typing import Optional, Union, Any, Literal
+import types, requests
 from .base import BaseLLM
 from litellm.utils import (
     ModelResponse,
@@ -952,6 +952,81 @@ class AzureChatCompletion(BaseLLM):
             )
             raise e

+    def get_headers(
+        self,
+        model: Optional[str],
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ) -> dict:
+        client_session = litellm.client_session or httpx.Client(
+            transport=CustomHTTPTransport(),  # handle dall-e-2 calls
+        )
+        if "gateway.ai.cloudflare.com" in api_base:
+            ## build base url - assume api base includes resource name
+            if not api_base.endswith("/"):
+                api_base += "/"
+            api_base += f"{model}"
+            client = AzureOpenAI(
+                base_url=api_base,
+                api_version=api_version,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+            model = None
+            # cloudflare ai gateway, needs model=None
+        else:
+            client = AzureOpenAI(
+                api_version=api_version,
+                azure_endpoint=api_base,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+
+            # only run this check if it's not cloudflare ai gateway
+            if model is None and mode != "image_generation":
+                raise Exception("model is not set")
+
+        completion = None
+
+        if messages is None:
+            messages = [{"role": "user", "content": "Hey"}]
+        try:
+            completion = client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        except Exception as e:
+            raise e
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+
+        if completion.headers.get("x-ms-region", None) is not None:
+            response["x-ms-region"] = completion.headers["x-ms-region"]
+
+        return response
+
     async def ahealth_check(
         self,
         model: Optional[str],
@@ -963,7 +1038,7 @@ class AzureChatCompletion(BaseLLM):
         messages: Optional[list] = None,
         input: Optional[list] = None,
         prompt: Optional[str] = None,
-    ):
+    ) -> dict:
         client_session = litellm.aclient_session or httpx.AsyncClient(
             transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
         )
@@ -1040,4 +1115,8 @@ class AzureChatCompletion(BaseLLM):
         response["x-ratelimit-remaining-tokens"] = completion.headers[
             "x-ratelimit-remaining-tokens"
         ]
+
+        if completion.headers.get("x-ms-region", None) is not None:
+            response["x-ms-region"] = completion.headers["x-ms-region"]
+
        return response
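The new `get_headers` helper is the hook the rest of this PR builds on for region detection (see `get_model_region` further down). A minimal usage sketch, assuming a real Azure deployment; the deployment name and environment variable names here are placeholders:

```python
# Sketch only: deployment and env var names are placeholders.
import os
import litellm

headers = litellm.AzureChatCompletion().get_headers(
    model="chatgpt-v-2",
    api_key=os.environ["AZURE_API_KEY"],
    api_base=os.environ["AZURE_API_BASE"],
    api_version="2023-07-01-preview",
    timeout=10,
    mode="chat",
)

# rate-limit headers are skipped for dall-e style requests;
# "x-ms-region" is only present when Azure returns it.
print(headers.get("x-ms-region"))  # e.g. "West Europe"
```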
@@ -651,6 +651,8 @@ def completion(
         "base_model",
         "stream_timeout",
         "supports_system_message",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -2721,6 +2723,8 @@ def embedding(
         "ttl",
         "cache",
         "no-log",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3595,6 +3599,8 @@ def image_generation(
         "caching_groups",
         "ttl",
         "cache",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -458,6 +458,27 @@ class UpdateUserRequest(GenerateRequestBase):
         return values


+class NewEndUserRequest(LiteLLMBase):
+    user_id: str
+    alias: Optional[str] = None  # human-friendly alias
+    blocked: bool = False  # allow/disallow requests for this end-user
+    max_budget: Optional[float] = None
+    budget_id: Optional[str] = None  # give either a budget_id or max_budget
+    allowed_model_region: Optional[Literal["eu"]] = (
+        None  # require all user requests to use models in this specific region
+    )
+    default_model: Optional[str] = (
+        None  # if no equivalent model in allowed region - default all requests to this model
+    )
+
+    @root_validator(pre=True)
+    def check_user_info(cls, values):
+        if values.get("max_budget") is not None and values.get("budget_id") is not None:
+            raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
+
+        return values
+
+
 class Member(LiteLLMBase):
     role: Literal["admin", "user"]
     user_id: Optional[str] = None
@@ -838,6 +859,7 @@ class UserAPIKeyAuth(

     api_key: Optional[str] = None
     user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
+    allowed_model_region: Optional[Literal["eu"]] = None

     @root_validator(pre=True)
     def check_api_key(cls, values):
@@ -883,6 +905,8 @@ class LiteLLM_EndUserTable(LiteLLMBase):
     blocked: bool
     alias: Optional[str] = None
     spend: float = 0.0
+    allowed_model_region: Optional[Literal["eu"]] = None
+    default_model: Optional[str] = None
     litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

     @root_validator(pre=True)
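A quick sketch of how the new request model behaves, assuming these models are importable from `litellm.proxy._types` (illustrative values):

```python
from litellm.proxy._types import NewEndUserRequest

# region-pinned end-user; max_budget will create a budget row in /end_user/new
req = NewEndUserRequest(
    user_id="ishaan-jaff-3",
    allowed_model_region="eu",
    default_model="azure/gpt-3.5-turbo-eu",
    max_budget=10.0,
)

# the root validator rejects giving both budget fields at once
try:
    NewEndUserRequest(user_id="x", max_budget=10.0, budget_id="budget-123")
except ValueError as err:  # pydantic's ValidationError subclasses ValueError
    print(err)
```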
@@ -208,7 +208,9 @@ async def get_end_user_object(
         return None

     # check if in cache
-    cached_user_obj = user_api_key_cache.async_get_cache(key=end_user_id)
+    cached_user_obj = user_api_key_cache.async_get_cache(
+        key="end_user_id:{}".format(end_user_id)
+    )
     if cached_user_obj is not None:
         if isinstance(cached_user_obj, dict):
             return LiteLLM_EndUserTable(**cached_user_obj)
@@ -223,7 +225,14 @@ async def get_end_user_object(
         if response is None:
             raise Exception

-        return LiteLLM_EndUserTable(**response.dict())
+        # save the end-user object to cache
+        await user_api_key_cache.async_set_cache(
+            key="end_user_id:{}".format(end_user_id), value=response
+        )
+
+        _response = LiteLLM_EndUserTable(**response.dict())
+
+        return _response
     except Exception as e:  # if end-user not in db
         return None
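The cache read and the new cache write now share one key convention, which the `user_api_key_auth` changes below rely on as well. Purely illustrative (the diff inlines the `format()` call rather than using a helper):

```python
def end_user_cache_key(end_user_id: str) -> str:
    # hypothetical helper, shown only to make the key convention explicit
    return "end_user_id:{}".format(end_user_id)

assert end_user_cache_key("ishaan-jaff-3") == "end_user_id:ishaan-jaff-3"
```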
@@ -231,6 +231,11 @@ class SpecialModelNames(enum.Enum):
     all_team_models = "all-team-models"


+class CommonProxyErrors(enum.Enum):
+    db_not_connected_error = "DB not connected"
+    no_llm_router = "No models configured on proxy"
+
+
 @app.exception_handler(ProxyException)
 async def openai_exception_handler(request: Request, exc: ProxyException):
     # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
@@ -467,10 +472,6 @@ async def user_api_key_auth(
                 prisma_client=prisma_client,
                 user_api_key_cache=user_api_key_cache,
             )
-            # save the end-user object to cache
-            await user_api_key_cache.async_set_cache(
-                key=end_user_id, value=end_user_object
-            )

         global_proxy_spend = None
         if litellm.max_budget > 0:  # user set proxy max budget
@@ -952,13 +953,16 @@ async def user_api_key_auth(

             _end_user_object = None
             if "user" in request_data:
-                _id = "end_user_id:{}".format(request_data["user"])
-                _end_user_object = await user_api_key_cache.async_get_cache(key=_id)
-                if _end_user_object is not None:
-                    _end_user_object = LiteLLM_EndUserTable(**_end_user_object)
+                _end_user_object = await get_end_user_object(
+                    end_user_id=request_data["user"],
+                    prisma_client=prisma_client,
+                    user_api_key_cache=user_api_key_cache,
+                )

             global_proxy_spend = None
-            if litellm.max_budget > 0:  # user set proxy max budget
+            if (
+                litellm.max_budget > 0 and prisma_client is not None
+            ):  # user set proxy max budget
                 # check cache
                 global_proxy_spend = await user_api_key_cache.async_get_cache(
                     key="{}:spend".format(litellm_proxy_admin_name)
@@ -1011,6 +1015,12 @@ async def user_api_key_auth(
             )
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
+
+            if _end_user_object is not None:
+                valid_token_dict["allowed_model_region"] = (
+                    _end_user_object.allowed_model_region
+                )
+
             """
             asyncio create task to update the user api key cache with the user db table as well
@@ -1035,10 +1045,7 @@ async def user_api_key_auth(
                 # check if user can access this route
                 query_params = request.query_params
                 key = query_params.get("key")
-                if (
-                    key is not None
-                    and prisma_client.hash_token(token=key) != api_key
-                ):
+                if key is not None and hash_token(token=key) != api_key:
                     raise HTTPException(
                         status_code=status.HTTP_403_FORBIDDEN,
                         detail="user not allowed to access this key's info",
@@ -1091,6 +1098,7 @@ async def user_api_key_auth(
                 # sso/login, ui/login, /key functions and /user functions
                 # this will never be allowed to call /chat/completions
                 token_team = getattr(valid_token, "team_id", None)
+
                 if token_team is not None and token_team == "litellm-dashboard":
                     # this token is only used for managing the ui
                     allowed_routes = [
@@ -3612,6 +3620,10 @@ async def chat_completion(
             **data,
         }  # add the team-specific configs to the completion call

+        ### END-USER SPECIFIC PARAMS ###
+        if user_api_key_dict.allowed_model_region is not None:
+            data["allowed_model_region"] = user_api_key_dict.allowed_model_region
+
         global user_temperature, user_request_timeout, user_max_tokens, user_api_base
         # override with user settings, these are params passed via cli
         if user_temperature:
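With the auth changes above, an end-user is selected per request via the standard OpenAI `user` field: the proxy looks the record up, copies `allowed_model_region` onto the request body, and the router (changes further down) filters deployments. A minimal sketch, mirroring the test added at the end of this PR (key and user_id are illustrative):

```python
import asyncio
from openai import AsyncOpenAI


async def main():
    client = AsyncOpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

    # "user" is the end_user_id created via /end_user/new
    resp = await client.chat.completions.with_raw_response.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey!"}],
        user="ishaan-jaff-3",
    )
    # the proxy reports which deployment served the call
    print(resp.headers.get("x-litellm-model-api-base"))


asyncio.run(main())
```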
@@ -5940,7 +5952,7 @@ async def global_predict_spend_logs(request: Request):
     return _forecast_daily_cost(data)


-#### USER MANAGEMENT ####
+#### INTERNAL USER MANAGEMENT ####
 @router.post(
     "/user/new",
     tags=["user management"],
@@ -6433,6 +6445,43 @@ async def user_get_requests():
         )


+@router.get(
+    "/user/get_users",
+    tags=["user management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def get_users(
+    role: str = fastapi.Query(
+        default=None,
+        description="Either 'proxy_admin', 'proxy_viewer', 'app_owner', 'app_user'",
+    )
+):
+    """
+    [BETA] This could change without notice. Give feedback - https://github.com/BerriAI/litellm/issues
+
+    Get all users who are a specific `user_role`.
+
+    Used by the UI to populate the user lists.
+
+    Currently - admin-only endpoint.
+    """
+    global prisma_client
+
+    if prisma_client is None:
+        raise HTTPException(
+            status_code=500,
+            detail={"error": f"No db connected. prisma client={prisma_client}"},
+        )
+    all_users = await prisma_client.get_data(
+        table_name="user", query_type="find_all", key_val={"user_role": role}
+    )
+
+    return all_users
+
+
+#### END-USER MANAGEMENT ####
+
+
 @router.post(
     "/end_user/block",
     tags=["End User Management"],
@@ -6523,38 +6572,140 @@ async def unblock_user(data: BlockUsers):
     return {"blocked_users": litellm.blocked_user_list}


-@router.get(
-    "/user/get_users",
-    tags=["user management"],
-    dependencies=[Depends(user_api_key_auth)],
-)
-async def get_users(
-    role: str = fastapi.Query(
-        default=None,
-        description="Either 'proxy_admin', 'proxy_viewer', 'app_owner', 'app_user'",
-    )
-):
-    """
-    [BETA] This could change without notice. Give feedback - https://github.com/BerriAI/litellm/issues
-
-    Get all users who are a specific `user_role`.
-
-    Used by the UI to populate the user lists.
-
-    Currently - admin-only endpoint.
-    """
-    global prisma_client
-
-    if prisma_client is None:
-        raise HTTPException(
-            status_code=500,
-            detail={"error": f"No db connected. prisma client={prisma_client}"},
-        )
-    all_users = await prisma_client.get_data(
-        table_name="user", query_type="find_all", key_val={"user_role": role}
-    )
-
-    return all_users
+@router.post(
+    "/end_user/new",
+    tags=["End User Management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def new_end_user(
+    data: NewEndUserRequest,
+    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+):
+    """
+    [TODO] Needs to be implemented.
+
+    Allow creating a new end-user
+
+    - Allow specifying allowed regions
+    - Allow specifying default model
+
+    Example curl:
+    ```
+    curl --location 'http://0.0.0.0:4000/end_user/new' \
+        --header 'Authorization: Bearer sk-1234' \
+        --header 'Content-Type: application/json' \
+        --data '{
+            "end_user_id" : "ishaan-jaff-3", <- specific customer
+
+            "allowed_region": "eu" <- set region for models
+
+            +
+
+            "default_model": "azure/gpt-3.5-turbo-eu" <- all calls from this user, use this model?
+        }'
+
+    # return end-user object
+    ```
+    """
+    global prisma_client, llm_router
+    """
+    Validation:
+    - check if default model exists
+    - create budget object if not already created
+
+    - Add user to end user table
+
+    Return
+    - end-user object
+    - currently allowed models
+    """
+    if prisma_client is None:
+        raise HTTPException(
+            status_code=500,
+            detail={"error": CommonProxyErrors.db_not_connected_error.value},
+        )
+
+    ## VALIDATION ##
+    if data.default_model is not None:
+        if llm_router is None:
+            raise HTTPException(
+                status_code=422, detail={"error": CommonProxyErrors.no_llm_router.value}
+            )
+        elif data.default_model not in llm_router.get_model_names():
+            raise HTTPException(
+                status_code=422,
+                detail={
+                    "error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format(
+                        data.default_model, set(llm_router.get_model_names())
+                    )
+                },
+            )
+
+    new_end_user_obj: Dict = {}
+
+    ## CREATE BUDGET ## if set
+    if data.max_budget is not None:
+        budget_record = await prisma_client.db.litellm_budgettable.create(
+            data={
+                "max_budget": data.max_budget,
+                "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,  # type: ignore
+                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
+            }
+        )
+
+        new_end_user_obj["budget_id"] = budget_record.budget_id
+    elif data.budget_id is not None:
+        new_end_user_obj["budget_id"] = data.budget_id
+
+    _user_data = data.dict(exclude_none=True)
+
+    for k, v in _user_data.items():
+        if k != "max_budget" and k != "budget_id":
+            new_end_user_obj[k] = v
+
+    ## WRITE TO DB ##
+    end_user_record = await prisma_client.db.litellm_endusertable.create(
+        data=new_end_user_obj  # type: ignore
+    )
+
+    return end_user_record
+
+
+@router.post(
+    "/end_user/info",
+    tags=["End User Management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def end_user_info():
+    """
+    [TODO] Needs to be implemented.
+    """
+    pass
+
+
+@router.post(
+    "/end_user/update",
+    tags=["End User Management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def update_end_user():
+    """
+    [TODO] Needs to be implemented.
+    """
+    pass
+
+
+@router.post(
+    "/end_user/delete",
+    tags=["End User Management"],
+    dependencies=[Depends(user_api_key_auth)],
+)
+async def delete_end_user():
+    """
+    [TODO] Needs to be implemented.
+    """
+    pass
+
+
 #### TEAM MANAGEMENT ####
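For reference, the same request as the curl example above, but using the field names that `NewEndUserRequest` actually validates and that the new test sends; a sketch, not part of the diff:

```python
import requests

resp = requests.post(
    "http://0.0.0.0:4000/end_user/new",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "user_id": "ishaan-jaff-3",
        "allowed_model_region": "eu",
        "default_model": "azure/gpt-3.5-turbo-eu",
        "max_budget": 10.0,  # creates a LiteLLM_BudgetTable row and links its budget_id
    },
)
print(resp.json())  # the created LiteLLM_EndUserTable record
```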
@@ -150,6 +150,8 @@ model LiteLLM_EndUserTable {
   user_id String @id
   alias String? // admin-facing alias
   spend Float @default(0.0)
+  allowed_model_region String? // require all user requests to use models in this specific region
+  default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
   budget_id String?
   litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
   blocked Boolean @default(false)
@@ -526,7 +526,7 @@ class PrismaClient:
         finally:
             os.chdir(original_dir)
         # Now you can import the Prisma Client
-        from prisma import Prisma  # type: ignore
+        from prisma import Prisma

         self.db = Prisma()  # Client to connect to Prisma db
@@ -32,6 +32,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     get_utc_datetime,
     calculate_max_parallel_requests,
+    _is_region_eu,
 )
 import copy
 from litellm._logging import verbose_router_logger
@@ -1999,7 +2000,11 @@ class Router:
             # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
             # we do this here because we init clients for Azure, OpenAI and we need to set the right key
             api_key = litellm_params.get("api_key") or default_api_key
-            if api_key and api_key.startswith("os.environ/"):
+            if (
+                api_key
+                and isinstance(api_key, str)
+                and api_key.startswith("os.environ/")
+            ):
                 api_key_env_name = api_key.replace("os.environ/", "")
                 api_key = litellm.get_secret(api_key_env_name)
                 litellm_params["api_key"] = api_key
@@ -2023,6 +2028,7 @@ class Router:
             if (
                 is_azure_ai_studio_model == True
                 and api_base is not None
+                and isinstance(api_base, str)
                 and not api_base.endswith("/v1/")
             ):
                 # check if it ends with a trailing slash
@@ -2103,13 +2109,14 @@ class Router:
                 organization = litellm.get_secret(organization_env_name)
                 litellm_params["organization"] = organization

-            if "azure" in model_name:
-                if api_base is None:
+            if "azure" in model_name and isinstance(api_key, str):
+                if api_base is None or not isinstance(api_base, str):
                     raise ValueError(
                         f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
                     )
                 if api_version is None:
                     api_version = "2023-07-01-preview"
+
                 if "gateway.ai.cloudflare.com" in api_base:
                     if not api_base.endswith("/"):
                         api_base += "/"
@@ -2532,7 +2539,7 @@ class Router:
             self.default_deployment = deployment.to_json(exclude_none=True)

         # Azure GPT-Vision Enhancements, users can pass os.environ/
-        data_sources = deployment.litellm_params.get("dataSources", [])
+        data_sources = deployment.litellm_params.get("dataSources", []) or []

         for data_source in data_sources:
             params = data_source.get("parameters", {})
@@ -2549,6 +2556,22 @@ class Router:
         # init OpenAI, Azure clients
         self.set_client(model=deployment.to_json(exclude_none=True))

+        # set region (if azure model)
+        try:
+            if "azure" in deployment.litellm_params.model:
+                region = litellm.utils.get_model_region(
+                    litellm_params=deployment.litellm_params, mode=None
+                )
+
+                deployment.litellm_params.region_name = region
+        except Exception as e:
+            verbose_router_logger.error(
+                "Unable to get the region for azure model - {}, {}".format(
+                    deployment.litellm_params.model, str(e)
+                )
+            )
+            pass  # [NON-BLOCKING]
+
         return deployment

     def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
@@ -2820,14 +2843,17 @@ class Router:
         model: str,
         healthy_deployments: List,
         messages: List[Dict[str, str]],
+        allowed_model_region: Optional[Literal["eu"]] = None,
     ):
         """
         Filter out model in model group, if:

         - model context window < message length
         - filter models above rpm limits
+        - if region given, filter out models not in that region / unknown region
         - [TODO] function call and model doesn't support function calling
         """

         verbose_router_logger.debug(
             f"Starting Pre-call checks for deployments in model={model}"
         )
@@ -2878,9 +2904,9 @@ class Router:
             except Exception as e:
                 verbose_router_logger.debug("An error occurs - {}".format(str(e)))

-            ## RPM CHECK ##
             _litellm_params = deployment.get("litellm_params", {})
             model_id = deployment.get("model_info", {}).get("id", "")
+            ## RPM CHECK ##
             ### get local router cache ###
             current_request_cache_local = (
                 self.cache.get_cache(key=model_id, local_only=True) or 0
@@ -2908,6 +2934,28 @@ class Router:
                     _rate_limit_error = True
                     continue

+            ## REGION CHECK ##
+            if allowed_model_region is not None:
+                if _litellm_params.get("region_name") is not None and isinstance(
+                    _litellm_params["region_name"], str
+                ):
+                    # check if in allowed_model_region
+                    if (
+                        _is_region_eu(model_region=_litellm_params["region_name"])
+                        == False
+                    ):
+                        invalid_model_indices.append(idx)
+                        continue
+                else:
+                    verbose_router_logger.debug(
+                        "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
+                            model_id, allowed_model_region
+                        )
+                    )
+                    # filter out since region unknown, and user wants to filter for specific region
+                    invalid_model_indices.append(idx)
+                    continue
+
         if len(invalid_model_indices) == len(_returned_deployments):
             """
             - no healthy deployments available b/c context window checks or rate limit error
@@ -3047,8 +3095,29 @@ class Router:

         # filter pre-call checks
         if self.enable_pre_call_checks and messages is not None:
-            healthy_deployments = self._pre_call_checks(
-                model=model, healthy_deployments=healthy_deployments, messages=messages
-            )
+            _allowed_model_region = (
+                request_kwargs.get("allowed_model_region")
+                if request_kwargs is not None
+                else None
+            )
+
+            if _allowed_model_region == "eu":
+                healthy_deployments = self._pre_call_checks(
+                    model=model,
+                    healthy_deployments=healthy_deployments,
+                    messages=messages,
+                    allowed_model_region=_allowed_model_region,
+                )
+            else:
+                verbose_router_logger.debug(
+                    "Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
+                        _allowed_model_region
+                    )
+                )
+                healthy_deployments = self._pre_call_checks(
+                    model=model,
+                    healthy_deployments=healthy_deployments,
+                    messages=messages,
+                )

         if len(healthy_deployments) == 0:
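Putting the router changes together: the region filter only runs when pre-call checks are enabled, only `eu` is honoured, and a deployment passes either because its `region_name` is set explicitly or because it was auto-detected for Azure models when the deployment was added. A rough sketch under those assumptions; endpoint values are placeholders and the referenced env vars are assumed to exist:

```python
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/gpt-35-turbo",
                "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
                "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                # set explicitly here; for azure models the router also tries to
                # auto-detect this via get_model_region() when the deployment is added
                "region_name": "West Europe",
            },
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",
                "api_base": "https://my-endpoint.openai.azure.com/",
                "api_key": "os.environ/AZURE_API_KEY",
            },
        },
    ],
    enable_pre_call_checks=True,  # the region filter only runs inside pre-call checks
)

# When a request carries allowed_model_region="eu" (the proxy copies this from the
# end-user record), deployments without a recognised EU region get filtered out:
eu_deployments = router._pre_call_checks(
    model="gpt-3.5-turbo",
    healthy_deployments=router.model_list,
    messages=[{"role": "user", "content": "Hey!"}],
    allowed_model_region="eu",
)
print([d["litellm_params"]["api_base"] for d in eu_deployments])
```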
@@ -123,6 +123,8 @@ class GenericLiteLLMParams(BaseModel):
     )
     max_retries: Optional[int] = None
     organization: Optional[str] = None  # for openai orgs
+    ## UNIFIED PROJECT/REGION ##
+    region_name: Optional[str] = None
     ## VERTEX AI ##
     vertex_project: Optional[str] = None
     vertex_location: Optional[str] = None
@@ -150,6 +152,8 @@ class GenericLiteLLMParams(BaseModel):
             None  # timeout when making stream=True calls, if str, pass in as os.environ/
         ),
         organization: Optional[str] = None,  # for openai orgs
+        ## UNIFIED PROJECT/REGION ##
+        region_name: Optional[str] = None,
         ## VERTEX AI ##
         vertex_project: Optional[str] = None,
         vertex_location: Optional[str] = None,
@@ -5866,6 +5866,40 @@ def calculate_max_parallel_requests(
     return None


+def _is_region_eu(model_region: str) -> bool:
+    EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
+    for region in EU_Regions:
+        if "europe" in model_region.lower():
+            return True
+    return False
+
+
+def get_model_region(
+    litellm_params: LiteLLM_Params, mode: Optional[str]
+) -> Optional[str]:
+    """
+    Pass the litellm params for an azure model, and get back the region
+    """
+    if (
+        "azure" in litellm_params.model
+        and isinstance(litellm_params.api_key, str)
+        and isinstance(litellm_params.api_base, str)
+    ):
+        _model = litellm_params.model.replace("azure/", "")
+        response: dict = litellm.AzureChatCompletion().get_headers(
+            model=_model,
+            api_key=litellm_params.api_key,
+            api_base=litellm_params.api_base,
+            api_version=litellm_params.api_version or "2023-07-01-preview",
+            timeout=10,
+            mode=mode or "chat",
+        )
+
+        region: Optional[str] = response.get("x-ms-region", None)
+        return region
+    return None
+
+
 def get_api_base(model: str, optional_params: dict) -> Optional[str]:
     """
     Returns the api base used for calling the model.
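How the two helpers fit together: `get_model_region` asks the Azure deployment for its `x-ms-region` response header (via the new `get_headers` above), and `_is_region_eu` is the predicate the router's region filter applies to a deployment's `region_name`. A small illustrative check of the predicate as written:

```python
from litellm.utils import _is_region_eu

# _is_region_eu is applied to the region string Azure reports in x-ms-region
print(_is_region_eu(model_region="West Europe"))     # True, "europe" is a substring
print(_is_region_eu(model_region="East US"))          # False
print(_is_region_eu(model_region="Sweden Central"))   # False with the check as written, which only tests for "europe"
```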
@@ -1,4 +1,9 @@
 model_list:
+- model_name: gpt-3.5-turbo
+  litellm_params:
+    model: azure/gpt-35-turbo
+    api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+    api_key: os.environ/AZURE_EUROPE_API_KEY
 - model_name: gpt-3.5-turbo
   litellm_params:
     model: azure/chatgpt-v-2
@@ -150,6 +150,8 @@ model LiteLLM_EndUserTable {
   user_id String @id
   alias String? // admin-facing alias
   spend Float @default(0.0)
+  allowed_model_region String? // require all user requests to use models in this specific region
+  default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
   budget_id String?
   litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
   blocked Boolean @default(false)
tests/test_end_users.py (new file, 173 additions)

@@ -0,0 +1,173 @@
# What is this?
## Unit tests for the /end_users/* endpoints
import pytest
import asyncio
import aiohttp
import time
import uuid
from openai import AsyncOpenAI
from typing import Optional

"""
- `/end_user/new`
- `/end_user/info`
"""


async def chat_completion_with_headers(session, key, model="gpt-4"):
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    data = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")

        response_header_check(
            response
        )  # calling the function to check response headers

        raw_headers = response.raw_headers
        raw_headers_json = {}

        for (
            item
        ) in (
            response.raw_headers
        ):  # ((b'date', b'Fri, 19 Apr 2024 21:17:29 GMT'), (), )
            raw_headers_json[item[0].decode("utf-8")] = item[1].decode("utf-8")

        return raw_headers_json


async def generate_key(
    session,
    i,
    budget=None,
    budget_duration=None,
    models=["azure-models", "gpt-4", "dall-e-3"],
    max_parallel_requests: Optional[int] = None,
    user_id: Optional[str] = None,
    team_id: Optional[str] = None,
    calling_key="sk-1234",
):
    url = "http://0.0.0.0:4000/key/generate"
    headers = {
        "Authorization": f"Bearer {calling_key}",
        "Content-Type": "application/json",
    }
    data = {
        "models": models,
        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
        "duration": None,
        "max_budget": budget,
        "budget_duration": budget_duration,
        "max_parallel_requests": max_parallel_requests,
        "user_id": user_id,
        "team_id": team_id,
    }

    print(f"data: {data}")

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(f"Response {i} (Status code: {status}):")
        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request {i} did not return a 200 status code: {status}")

        return await response.json()


async def new_end_user(
    session, i, user_id=str(uuid.uuid4()), model_region=None, default_model=None
):
    url = "http://0.0.0.0:4000/end_user/new"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {
        "user_id": user_id,
        "allowed_model_region": model_region,
        "default_model": default_model,
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(f"Response {i} (Status code: {status}):")
        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request {i} did not return a 200 status code: {status}")

        return await response.json()


@pytest.mark.asyncio
async def test_end_user_new():
    """
    Make 20 parallel calls to /user/new. Assert all worked.
    """
    async with aiohttp.ClientSession() as session:
        tasks = [new_end_user(session, i, str(uuid.uuid4())) for i in range(1, 11)]
        await asyncio.gather(*tasks)


@pytest.mark.asyncio
async def test_end_user_specific_region():
    """
    - Specify region user can make calls in
    - Make a generic call
    - assert returned api base is for model in region

    Repeat 3 times
    """
    key: str = ""
    ## CREATE USER ##
    async with aiohttp.ClientSession() as session:
        end_user_obj = await new_end_user(
            session=session,
            i=0,
            user_id=str(uuid.uuid4()),
            model_region="eu",
        )

        ## MAKE CALL ##
        key_gen = await generate_key(session=session, i=0, models=["gpt-3.5-turbo"])

        key = key_gen["key"]

        for _ in range(3):
            client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")

            print("SENDING USER PARAM - {}".format(end_user_obj["user_id"]))
            result = await client.chat.completions.with_raw_response.create(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Hey!"}],
                user=end_user_obj["user_id"],
            )

            assert (
                result.headers.get("x-litellm-model-api-base")
                == "https://my-endpoint-europe-berri-992.openai.azure.com/"
            )