Merge pull request #3536 from BerriAI/litellm_region_based_routing

feat(proxy_server.py): add CRUD endpoints for 'end_user' management
Krish Dholakia, 2024-05-08 22:23:40 -07:00, committed by GitHub
commit 66a1b581e5
14 changed files with 606 additions and 44 deletions
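
Taken together, the changes below let an admin register an end-user with a region restriction and have the proxy route that end-user's requests only to deployments in that region. A minimal sketch of the flow, assuming a locally running proxy at http://0.0.0.0:4000 with admin key sk-1234 (the values used in the tests at the end of this diff); everything else is illustrative:

```
import requests
import openai

# 1. register an end-user who may only be routed to EU deployments
resp = requests.post(
    "http://0.0.0.0:4000/end_user/new",
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "user_id": "ishaan-jaff-3",
        "allowed_model_region": "eu",
        "default_model": "azure/gpt-3.5-turbo-eu",
    },
)
resp.raise_for_status()

# 2. requests tagged with that end-user id are filtered to EU deployments
client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")
completion = client.chat.completions.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey!"}],
    user="ishaan-jaff-3",  # the end-user id registered above
)
print(completion.choices[0].message.content)
```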

.gitignore

@@ -1,5 +1,6 @@
.venv
.env
litellm/proxy/myenv/*
litellm_uuid.txt
__pycache__/
*.pyc
@@ -52,3 +53,6 @@ litellm/proxy/_new_secret_config.yaml
litellm/proxy/_new_secret_config.yaml
litellm/proxy/_super_secret_config.yaml
litellm/proxy/_super_secret_config.yaml
litellm/proxy/myenv/bin/activate
litellm/proxy/myenv/bin/Activate.ps1
myenv/*


@@ -1,5 +1,5 @@
from typing import Optional, Union, Any, Literal
import types, requests
from .base import BaseLLM
from litellm.utils import (
    ModelResponse,
@@ -952,6 +952,81 @@ class AzureChatCompletion(BaseLLM):
            )
            raise e

    def get_headers(
        self,
        model: Optional[str],
        api_key: str,
        api_base: str,
        api_version: str,
        timeout: float,
        mode: str,
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
    ) -> dict:
        client_session = litellm.client_session or httpx.Client(
            transport=CustomHTTPTransport(),  # handle dall-e-2 calls
        )
        if "gateway.ai.cloudflare.com" in api_base:
            ## build base url - assume api base includes resource name
            if not api_base.endswith("/"):
                api_base += "/"
            api_base += f"{model}"
            client = AzureOpenAI(
                base_url=api_base,
                api_version=api_version,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )
            model = None
            # cloudflare ai gateway, needs model=None
        else:
            client = AzureOpenAI(
                api_version=api_version,
                azure_endpoint=api_base,
                api_key=api_key,
                timeout=timeout,
                http_client=client_session,
            )

            # only run this check if it's not cloudflare ai gateway
            if model is None and mode != "image_generation":
                raise Exception("model is not set")

        completion = None

        if messages is None:
            messages = [{"role": "user", "content": "Hey"}]
        try:
            completion = client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        except Exception as e:
            raise e

        response = {}

        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        if (
            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
        ):  # not provided for dall-e requests
            response["x-ratelimit-remaining-requests"] = completion.headers[
                "x-ratelimit-remaining-requests"
            ]

        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]

        if completion.headers.get("x-ms-region", None) is not None:
            response["x-ms-region"] = completion.headers["x-ms-region"]

        return response

    async def ahealth_check(
        self,
        model: Optional[str],
@@ -963,7 +1038,7 @@ class AzureChatCompletion(BaseLLM):
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
    ) -> dict:
        client_session = litellm.aclient_session or httpx.AsyncClient(
            transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
        )
@@ -1040,4 +1115,8 @@ class AzureChatCompletion(BaseLLM):
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]

        if completion.headers.get("x-ms-region", None) is not None:
            response["x-ms-region"] = completion.headers["x-ms-region"]

        return response
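
Both `get_headers` and `ahealth_check` work the same way: issue a minimal chat call through the openai SDK's `with_raw_response` interface and read deployment metadata (rate-limit counters and, new in this PR, the serving region) out of the Azure response headers. A standalone sketch of that probe, assuming an Azure OpenAI deployment reachable via the illustrative environment variable names below (not part of this PR):

```
import os
from openai import AzureOpenAI

# Illustrative env var names; any Azure OpenAI deployment responds the same way.
client = AzureOpenAI(
    api_key=os.environ["AZURE_API_KEY"],
    azure_endpoint=os.environ["AZURE_API_BASE"],
    api_version="2023-07-01-preview",
)

raw = client.chat.completions.with_raw_response.create(
    model=os.environ["AZURE_DEPLOYMENT_NAME"],
    messages=[{"role": "user", "content": "Hey"}],
)

# Azure reports the serving region in the x-ms-region response header,
# alongside the rate-limit counters already collected by ahealth_check.
print(raw.headers.get("x-ms-region"))
print(raw.headers.get("x-ratelimit-remaining-requests"))
print(raw.headers.get("x-ratelimit-remaining-tokens"))
```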


@@ -651,6 +651,8 @@ def completion(
        "base_model",
        "stream_timeout",
        "supports_system_message",
        "region_name",
        "allowed_model_region",
    ]
    default_params = openai_params + litellm_params
    non_default_params = {
@@ -2721,6 +2723,8 @@ def embedding(
        "ttl",
        "cache",
        "no-log",
        "region_name",
        "allowed_model_region",
    ]
    default_params = openai_params + litellm_params
    non_default_params = {
@@ -3595,6 +3599,8 @@ def image_generation(
        "caching_groups",
        "ttl",
        "cache",
        "region_name",
        "allowed_model_region",
    ]
    default_params = openai_params + litellm_params
    non_default_params = {


@@ -458,6 +458,27 @@ class UpdateUserRequest(GenerateRequestBase):
        return values


class NewEndUserRequest(LiteLLMBase):
    user_id: str
    alias: Optional[str] = None  # human-friendly alias
    blocked: bool = False  # allow/disallow requests for this end-user
    max_budget: Optional[float] = None
    budget_id: Optional[str] = None  # give either a budget_id or max_budget
    allowed_model_region: Optional[Literal["eu"]] = (
        None  # require all user requests to use models in this specific region
    )
    default_model: Optional[str] = (
        None  # if no equivalent model in allowed region - default all requests to this model
    )

    @root_validator(pre=True)
    def check_user_info(cls, values):
        if values.get("max_budget") is not None and values.get("budget_id") is not None:
            raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
        return values


class Member(LiteLLMBase):
    role: Literal["admin", "user"]
    user_id: Optional[str] = None
@@ -838,6 +859,7 @@ class UserAPIKeyAuth(
    api_key: Optional[str] = None
    user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
    allowed_model_region: Optional[Literal["eu"]] = None

    @root_validator(pre=True)
    def check_api_key(cls, values):
@@ -883,6 +905,8 @@ class LiteLLM_EndUserTable(LiteLLMBase):
    blocked: bool
    alias: Optional[str] = None
    spend: float = 0.0
    allowed_model_region: Optional[Literal["eu"]] = None
    default_model: Optional[str] = None
    litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

    @root_validator(pre=True)
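
A quick sketch of how the new request model behaves, assuming it is imported from `litellm.proxy._types` like the proxy's other request models (the budget id value is illustrative):

```
from litellm.proxy._types import NewEndUserRequest

# Valid: the budget comes from an existing budget object, so max_budget stays unset.
end_user = NewEndUserRequest(
    user_id="ishaan-jaff-3",
    allowed_model_region="eu",
    default_model="azure/gpt-3.5-turbo-eu",
    budget_id="budget-123",
)

# Invalid: the root validator rejects setting max_budget and budget_id together.
try:
    NewEndUserRequest(user_id="x", max_budget=10.0, budget_id="budget-123")
except Exception as e:
    print(e)  # pydantic surfaces: Set either 'max_budget' or 'budget_id', not both.
```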


@@ -208,7 +208,9 @@ async def get_end_user_object(
        return None

    # check if in cache
    cached_user_obj = user_api_key_cache.async_get_cache(
        key="end_user_id:{}".format(end_user_id)
    )
    if cached_user_obj is not None:
        if isinstance(cached_user_obj, dict):
            return LiteLLM_EndUserTable(**cached_user_obj)
@@ -223,7 +225,14 @@ async def get_end_user_object(
        if response is None:
            raise Exception

        # save the end-user object to cache
        await user_api_key_cache.async_set_cache(
            key="end_user_id:{}".format(end_user_id), value=response
        )

        _response = LiteLLM_EndUserTable(**response.dict())

        return _response
    except Exception as e:  # if end-user not in db
        return None
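
The cache write moves out of `user_api_key_auth` and into `get_end_user_object`, and the key gains an `end_user_id:` prefix so end-user entries cannot collide with the hashed-token entries the same cache stores for API keys. A minimal sketch of that convention, assuming a `DualCache` instance like the proxy's `user_api_key_cache`:

```
from litellm.caching import DualCache

user_api_key_cache = DualCache()

async def cache_end_user(end_user_id: str, end_user_obj: dict) -> None:
    # "end_user_id:<id>" namespaces the entry away from hashed API-key entries
    await user_api_key_cache.async_set_cache(
        key="end_user_id:{}".format(end_user_id), value=end_user_obj
    )

async def get_cached_end_user(end_user_id: str):
    return await user_api_key_cache.async_get_cache(
        key="end_user_id:{}".format(end_user_id)
    )
```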


@@ -231,6 +231,11 @@ class SpecialModelNames(enum.Enum):
    all_team_models = "all-team-models"


class CommonProxyErrors(enum.Enum):
    db_not_connected_error = "DB not connected"
    no_llm_router = "No models configured on proxy"


@app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: ProxyException):
    # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
@@ -467,10 +472,6 @@ async def user_api_key_auth(
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
            )

        global_proxy_spend = None
        if litellm.max_budget > 0:  # user set proxy max budget
@@ -952,13 +953,16 @@ async def user_api_key_auth(
        _end_user_object = None
        if "user" in request_data:
            _end_user_object = await get_end_user_object(
                end_user_id=request_data["user"],
                prisma_client=prisma_client,
                user_api_key_cache=user_api_key_cache,
            )

        global_proxy_spend = None
        if (
            litellm.max_budget > 0 and prisma_client is not None
        ):  # user set proxy max budget
            # check cache
            global_proxy_spend = await user_api_key_cache.async_get_cache(
                key="{}:spend".format(litellm_proxy_admin_name)
@@ -1011,6 +1015,12 @@ async def user_api_key_auth(
            )

        valid_token_dict = _get_pydantic_json_dict(valid_token)
        valid_token_dict.pop("token", None)

        if _end_user_object is not None:
            valid_token_dict["allowed_model_region"] = (
                _end_user_object.allowed_model_region
            )

        """
        asyncio create task to update the user api key cache with the user db table as well
@@ -1035,10 +1045,7 @@ async def user_api_key_auth(
                # check if user can access this route
                query_params = request.query_params
                key = query_params.get("key")
                if key is not None and hash_token(token=key) != api_key:
                    raise HTTPException(
                        status_code=status.HTTP_403_FORBIDDEN,
                        detail="user not allowed to access this key's info",
@@ -1091,6 +1098,7 @@ async def user_api_key_auth(
            # sso/login, ui/login, /key functions and /user functions
            # this will never be allowed to call /chat/completions
            token_team = getattr(valid_token, "team_id", None)

            if token_team is not None and token_team == "litellm-dashboard":
                # this token is only used for managing the ui
                allowed_routes = [
@@ -3612,6 +3620,10 @@ async def chat_completion(
            **data,
        }  # add the team-specific configs to the completion call

        ### END-USER SPECIFIC PARAMS ###
        if user_api_key_dict.allowed_model_region is not None:
            data["allowed_model_region"] = user_api_key_dict.allowed_model_region

        global user_temperature, user_request_timeout, user_max_tokens, user_api_base
        # override with user settings, these are params passed via cli
        if user_temperature:
@@ -5940,7 +5952,7 @@ async def global_predict_spend_logs(request: Request):
    return _forecast_daily_cost(data)


#### INTERNAL USER MANAGEMENT ####
@router.post(
    "/user/new",
    tags=["user management"],
@@ -6433,6 +6445,43 @@ async def user_get_requests():
    )


@router.get(
    "/user/get_users",
    tags=["user management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def get_users(
    role: str = fastapi.Query(
        default=None,
        description="Either 'proxy_admin', 'proxy_viewer', 'app_owner', 'app_user'",
    )
):
    """
    [BETA] This could change without notice. Give feedback - https://github.com/BerriAI/litellm/issues

    Get all users who are a specific `user_role`.

    Used by the UI to populate the user lists.

    Currently - admin-only endpoint.
    """
    global prisma_client

    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": f"No db connected. prisma client={prisma_client}"},
        )
    all_users = await prisma_client.get_data(
        table_name="user", query_type="find_all", key_val={"user_role": role}
    )

    return all_users


#### END-USER MANAGEMENT ####


@router.post(
    "/end_user/block",
    tags=["End User Management"],
@@ -6523,38 +6572,140 @@ async def unblock_user(data: BlockUsers):
    return {"blocked_users": litellm.blocked_user_list}


@router.post(
    "/end_user/new",
    tags=["End User Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def new_end_user(
    data: NewEndUserRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Allow creating a new end-user

    - Allow specifying allowed regions
    - Allow specifying default model

    Example curl:
    ```
    curl --location 'http://0.0.0.0:4000/end_user/new' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "user_id" : "ishaan-jaff-3",                <- specific customer
            "allowed_model_region": "eu",               <- set region for models
            "default_model": "azure/gpt-3.5-turbo-eu"   <- fall back to this model if no model in the region is available
        }'

    # returns the created end-user object
    ```
    """
    global prisma_client, llm_router
    """
    Validation:
        - check if default model exists
        - create budget object if not already created
        - add user to end user table

    Return
    - end-user object
    - currently allowed models
    """
    if prisma_client is None:
        raise HTTPException(
            status_code=500,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    ## VALIDATION ##
    if data.default_model is not None:
        if llm_router is None:
            raise HTTPException(
                status_code=422, detail={"error": CommonProxyErrors.no_llm_router.value}
            )
        elif data.default_model not in llm_router.get_model_names():
            raise HTTPException(
                status_code=422,
                detail={
                    "error": "Default Model not on proxy. Configure via `/model/new` or config.yaml. Default_model={}, proxy_model_names={}".format(
                        data.default_model, set(llm_router.get_model_names())
                    )
                },
            )

    new_end_user_obj: Dict = {}

    ## CREATE BUDGET ## if set
    if data.max_budget is not None:
        budget_record = await prisma_client.db.litellm_budgettable.create(
            data={
                "max_budget": data.max_budget,
                "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,  # type: ignore
                "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
            }
        )
        new_end_user_obj["budget_id"] = budget_record.budget_id
    elif data.budget_id is not None:
        new_end_user_obj["budget_id"] = data.budget_id

    _user_data = data.dict(exclude_none=True)

    for k, v in _user_data.items():
        if k != "max_budget" and k != "budget_id":
            new_end_user_obj[k] = v

    ## WRITE TO DB ##
    end_user_record = await prisma_client.db.litellm_endusertable.create(
        data=new_end_user_obj  # type: ignore
    )

    return end_user_record


@router.post(
    "/end_user/info",
    tags=["End User Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def end_user_info():
    """
    [TODO] Needs to be implemented.
    """
    pass


@router.post(
    "/end_user/update",
    tags=["End User Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def update_end_user():
    """
    [TODO] Needs to be implemented.
    """
    pass


@router.post(
    "/end_user/delete",
    tags=["End User Management"],
    dependencies=[Depends(user_api_key_auth)],
)
async def delete_end_user():
    """
    [TODO] Needs to be implemented.
    """
    pass


#### TEAM MANAGEMENT ####


@@ -150,6 +150,8 @@ model LiteLLM_EndUserTable {
  user_id String @id
  alias String? // admin-facing alias
  spend Float @default(0.0)
  allowed_model_region String? // require all user requests to use models in this specific region
  default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
  budget_id String?
  litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
  blocked Boolean @default(false)


@@ -526,7 +526,7 @@ class PrismaClient:
        finally:
            os.chdir(original_dir)
        # Now you can import the Prisma Client
        from prisma import Prisma

        self.db = Prisma()  # Client to connect to Prisma db


@@ -32,6 +32,7 @@ from litellm.utils import (
    CustomStreamWrapper,
    get_utc_datetime,
    calculate_max_parallel_requests,
    _is_region_eu,
)
import copy
from litellm._logging import verbose_router_logger
@@ -1999,7 +2000,11 @@ class Router:
            # user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
            # we do this here because we init clients for Azure, OpenAI and we need to set the right key
            api_key = litellm_params.get("api_key") or default_api_key
            if (
                api_key
                and isinstance(api_key, str)
                and api_key.startswith("os.environ/")
            ):
                api_key_env_name = api_key.replace("os.environ/", "")
                api_key = litellm.get_secret(api_key_env_name)
                litellm_params["api_key"] = api_key
@@ -2023,6 +2028,7 @@ class Router:
                if (
                    is_azure_ai_studio_model == True
                    and api_base is not None
                    and isinstance(api_base, str)
                    and not api_base.endswith("/v1/")
                ):
                    # check if it ends with a trailing slash
@@ -2103,13 +2109,14 @@ class Router:
                    organization = litellm.get_secret(organization_env_name)
                    litellm_params["organization"] = organization

            if "azure" in model_name and isinstance(api_key, str):
                if api_base is None or not isinstance(api_base, str):
                    raise ValueError(
                        f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
                    )
                if api_version is None:
                    api_version = "2023-07-01-preview"

                if "gateway.ai.cloudflare.com" in api_base:
                    if not api_base.endswith("/"):
                        api_base += "/"
@@ -2532,7 +2539,7 @@ class Router:
            self.default_deployment = deployment.to_json(exclude_none=True)

        # Azure GPT-Vision Enhancements, users can pass os.environ/
        data_sources = deployment.litellm_params.get("dataSources", []) or []
        for data_source in data_sources:
            params = data_source.get("parameters", {})
@@ -2549,6 +2556,22 @@ class Router:
        # init OpenAI, Azure clients
        self.set_client(model=deployment.to_json(exclude_none=True))

        # set region (if azure model)
        try:
            if "azure" in deployment.litellm_params.model:
                region = litellm.utils.get_model_region(
                    litellm_params=deployment.litellm_params, mode=None
                )

                deployment.litellm_params.region_name = region
        except Exception as e:
            verbose_router_logger.error(
                "Unable to get the region for azure model - {}, {}".format(
                    deployment.litellm_params.model, str(e)
                )
            )
            pass  # [NON-BLOCKING]

        return deployment

    def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
@@ -2820,14 +2843,17 @@ class Router:
        model: str,
        healthy_deployments: List,
        messages: List[Dict[str, str]],
        allowed_model_region: Optional[Literal["eu"]] = None,
    ):
        """
        Filter out model in model group, if:

        - model context window < message length
        - filter models above rpm limits
        - if region given, filter out models not in that region / unknown region
        - [TODO] function call and model doesn't support function calling
        """

        verbose_router_logger.debug(
            f"Starting Pre-call checks for deployments in model={model}"
        )
@@ -2878,9 +2904,9 @@
            except Exception as e:
                verbose_router_logger.debug("An error occurred - {}".format(str(e)))

            _litellm_params = deployment.get("litellm_params", {})
            model_id = deployment.get("model_info", {}).get("id", "")

            ## RPM CHECK ##
            ### get local router cache ###
            current_request_cache_local = (
                self.cache.get_cache(key=model_id, local_only=True) or 0
@@ -2908,6 +2934,28 @@
                        _rate_limit_error = True
                        continue

            ## REGION CHECK ##
            if allowed_model_region is not None:
                if _litellm_params.get("region_name") is not None and isinstance(
                    _litellm_params["region_name"], str
                ):
                    # check if in allowed_model_region
                    if (
                        _is_region_eu(model_region=_litellm_params["region_name"])
                        == False
                    ):
                        invalid_model_indices.append(idx)
                        continue
                else:
                    verbose_router_logger.debug(
                        "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
                            model_id, allowed_model_region
                        )
                    )
                    # filter out since region unknown, and user wants to filter for specific region
                    invalid_model_indices.append(idx)
                    continue

            if len(invalid_model_indices) == len(_returned_deployments):
                """
                - no healthy deployments available b/c context window checks or rate limit error
@@ -3047,8 +3095,29 @@ class Router:
        # filter pre-call checks
        if self.enable_pre_call_checks and messages is not None:
            _allowed_model_region = (
                request_kwargs.get("allowed_model_region")
                if request_kwargs is not None
                else None
            )

            if _allowed_model_region == "eu":
                healthy_deployments = self._pre_call_checks(
                    model=model,
                    healthy_deployments=healthy_deployments,
                    messages=messages,
                    allowed_model_region=_allowed_model_region,
                )
            else:
                verbose_router_logger.debug(
                    "Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
                        _allowed_model_region
                    )
                )
                healthy_deployments = self._pre_call_checks(
                    model=model,
                    healthy_deployments=healthy_deployments,
                    messages=messages,
                )

        if len(healthy_deployments) == 0:
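
A condensed, standalone sketch of the filtering rule the region check above implements, using a local stand-in for `_is_region_eu` rather than the Router internals (the deployment dicts below are illustrative):

```
from typing import List, Literal, Optional

def _is_region_eu(model_region: str) -> bool:
    # stand-in mirroring litellm.utils._is_region_eu (see below)
    eu_regions = ["europe", "sweden", "switzerland", "france", "uk"]
    return any(region in model_region.lower() for region in eu_regions)

def filter_by_region(
    deployments: List[dict], allowed_model_region: Optional[Literal["eu"]]
) -> List[dict]:
    """Keep only deployments whose region is known and inside the allowed region."""
    if allowed_model_region is None:
        return deployments
    allowed = []
    for deployment in deployments:
        region = deployment.get("litellm_params", {}).get("region_name")
        # unknown region -> filtered out, matching the Router's behaviour
        if isinstance(region, str) and _is_region_eu(model_region=region):
            allowed.append(deployment)
    return allowed

deployments = [
    {"model_info": {"id": "eu-1"}, "litellm_params": {"region_name": "francecentral"}},
    {"model_info": {"id": "us-1"}, "litellm_params": {"region_name": "eastus"}},
    {"model_info": {"id": "unknown"}, "litellm_params": {}},
]
print([d["model_info"]["id"] for d in filter_by_region(deployments, "eu")])  # ['eu-1']
```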


@@ -123,6 +123,8 @@ class GenericLiteLLMParams(BaseModel):
    )
    max_retries: Optional[int] = None
    organization: Optional[str] = None  # for openai orgs
    ## UNIFIED PROJECT/REGION ##
    region_name: Optional[str] = None
    ## VERTEX AI ##
    vertex_project: Optional[str] = None
    vertex_location: Optional[str] = None
@@ -150,6 +152,8 @@ class GenericLiteLLMParams(BaseModel):
            None  # timeout when making stream=True calls, if str, pass in as os.environ/
        ),
        organization: Optional[str] = None,  # for openai orgs
        ## UNIFIED PROJECT/REGION ##
        region_name: Optional[str] = None,
        ## VERTEX AI ##
        vertex_project: Optional[str] = None,
        vertex_location: Optional[str] = None,


@@ -5866,6 +5866,40 @@ def calculate_max_parallel_requests(
    return None

def _is_region_eu(model_region: str) -> bool:
    EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
    for region in EU_Regions:
        if region in model_region.lower():
            return True
    return False
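
A couple of illustrative checks for the helper, assuming the per-region containment test above (any EU keyword appearing in the Azure region string counts as a match):

```
assert _is_region_eu("France Central") is True  # "france" matches
assert _is_region_eu("swedencentral") is True   # "sweden" matches
assert _is_region_eu("eastus") is False         # no EU keyword
```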

def get_model_region(
    litellm_params: LiteLLM_Params, mode: Optional[str]
) -> Optional[str]:
    """
    Pass the litellm params for an azure model, and get back the region
    """
    if (
        "azure" in litellm_params.model
        and isinstance(litellm_params.api_key, str)
        and isinstance(litellm_params.api_base, str)
    ):
        _model = litellm_params.model.replace("azure/", "")
        response: dict = litellm.AzureChatCompletion().get_headers(
            model=_model,
            api_key=litellm_params.api_key,
            api_base=litellm_params.api_base,
            api_version=litellm_params.api_version or "2023-07-01-preview",
            timeout=10,
            mode=mode or "chat",
        )

        region: Optional[str] = response.get("x-ms-region", None)

        return region
    return None

def get_api_base(model: str, optional_params: dict) -> Optional[str]:
    """
    Returns the api base used for calling the model.


@@ -1,4 +1,9 @@
model_list:
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/gpt-35-turbo
      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
      api_key: os.environ/AZURE_EUROPE_API_KEY
  - model_name: gpt-3.5-turbo
    litellm_params:
      model: azure/chatgpt-v-2
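
For Azure deployments the Router resolves the region automatically via `get_model_region` in `set_client` (see the router changes above), but since `GenericLiteLLMParams` now carries a `region_name` field, a deployment's region can plausibly also be pinned by hand. A hedged sketch of the equivalent `Router` setup in Python, assuming an explicitly set `region_name` is honoured by the pre-call check (values illustrative):

```
import litellm

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/gpt-35-turbo",
            "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
            "api_key": "os.environ/AZURE_EUROPE_API_KEY",  # resolved from the env by the Router
            "region_name": "francecentral",  # pin the region instead of probing x-ms-region
        },
    },
]

# region-based filtering only runs when pre-call checks are enabled
router = litellm.Router(model_list=model_list, enable_pre_call_checks=True)
```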


@@ -150,6 +150,8 @@ model LiteLLM_EndUserTable {
  user_id String @id
  alias String? // admin-facing alias
  spend Float @default(0.0)
  allowed_model_region String? // require all user requests to use models in this specific region
  default_model String? // use along with 'allowed_model_region'. if no available model in region, default to this model.
  budget_id String?
  litellm_budget_table LiteLLM_BudgetTable? @relation(fields: [budget_id], references: [budget_id])
  blocked Boolean @default(false)

tests/test_end_users.py

@@ -0,0 +1,173 @@
# What is this?
## Unit tests for the /end_users/* endpoints
import pytest
import asyncio
import aiohttp
import time
import uuid
from openai import AsyncOpenAI
from typing import Optional

"""
- `/end_user/new`
- `/end_user/info`
"""


async def chat_completion_with_headers(session, key, model="gpt-4"):
    url = "http://0.0.0.0:4000/chat/completions"
    headers = {
        "Authorization": f"Bearer {key}",
        "Content-Type": "application/json",
    }
    data = {
        "model": model,
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Hello!"},
        ],
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")

        response_header_check(
            response
        )  # calling the function to check response headers

        raw_headers = response.raw_headers
        raw_headers_json = {}
        for (
            item
        ) in (
            response.raw_headers
        ):  # ((b'date', b'Fri, 19 Apr 2024 21:17:29 GMT'), (), )
            raw_headers_json[item[0].decode("utf-8")] = item[1].decode("utf-8")

        return raw_headers_json


async def generate_key(
    session,
    i,
    budget=None,
    budget_duration=None,
    models=["azure-models", "gpt-4", "dall-e-3"],
    max_parallel_requests: Optional[int] = None,
    user_id: Optional[str] = None,
    team_id: Optional[str] = None,
    calling_key="sk-1234",
):
    url = "http://0.0.0.0:4000/key/generate"
    headers = {
        "Authorization": f"Bearer {calling_key}",
        "Content-Type": "application/json",
    }
    data = {
        "models": models,
        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
        "duration": None,
        "max_budget": budget,
        "budget_duration": budget_duration,
        "max_parallel_requests": max_parallel_requests,
        "user_id": user_id,
        "team_id": team_id,
    }

    print(f"data: {data}")

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(f"Response {i} (Status code: {status}):")
        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request {i} did not return a 200 status code: {status}")

        return await response.json()


async def new_end_user(
    session, i, user_id=str(uuid.uuid4()), model_region=None, default_model=None
):
    url = "http://0.0.0.0:4000/end_user/new"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {
        "user_id": user_id,
        "allowed_model_region": model_region,
        "default_model": default_model,
    }

    async with session.post(url, headers=headers, json=data) as response:
        status = response.status
        response_text = await response.text()

        print(f"Response {i} (Status code: {status}):")
        print(response_text)
        print()

        if status != 200:
            raise Exception(f"Request {i} did not return a 200 status code: {status}")

        return await response.json()


@pytest.mark.asyncio
async def test_end_user_new():
    """
    Make 10 parallel calls to /end_user/new. Assert all worked.
    """
    async with aiohttp.ClientSession() as session:
        tasks = [new_end_user(session, i, str(uuid.uuid4())) for i in range(1, 11)]
        await asyncio.gather(*tasks)


@pytest.mark.asyncio
async def test_end_user_specific_region():
    """
    - Specify region user can make calls in
    - Make a generic call
    - assert returned api base is for model in region

    Repeat 3 times
    """
    key: str = ""
    ## CREATE USER ##
    async with aiohttp.ClientSession() as session:
        end_user_obj = await new_end_user(
            session=session,
            i=0,
            user_id=str(uuid.uuid4()),
            model_region="eu",
        )

        ## MAKE CALL ##
        key_gen = await generate_key(session=session, i=0, models=["gpt-3.5-turbo"])

        key = key_gen["key"]

    for _ in range(3):
        client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")

        print("SENDING USER PARAM - {}".format(end_user_obj["user_id"]))
        result = await client.chat.completions.with_raw_response.create(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            user=end_user_obj["user_id"],
        )

        assert (
            result.headers.get("x-litellm-model-api-base")
            == "https://my-endpoint-europe-berri-992.openai.azure.com/"
        )