feat(router.py): enable filtering model group by 'allowed_model_region'

Author: Krrish Dholakia
Date:   2024-05-08 22:10:17 -07:00
Parent: db666b01e5
Commit: 3d18897d69
11 changed files with 417 additions and 35 deletions

.gitignore (vendored, 1 line changed)

@@ -55,3 +55,4 @@ litellm/proxy/_super_secret_config.yaml
litellm/proxy/_super_secret_config.yaml
litellm/proxy/myenv/bin/activate
litellm/proxy/myenv/bin/Activate.ps1
myenv/*

litellm/llms/azure.py

@@ -1,4 +1,4 @@
from typing import Optional, Union, Any
from typing import Optional, Union, Any, Literal
import types, requests
from .base import BaseLLM
from litellm.utils import (
@@ -952,6 +952,81 @@ class AzureChatCompletion(BaseLLM):
)
raise e
def get_headers(
self,
model: Optional[str],
api_key: str,
api_base: str,
api_version: str,
timeout: float,
mode: str,
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
) -> dict:
client_session = litellm.client_session or httpx.Client(
transport=CustomHTTPTransport(), # handle dall-e-2 calls
)
if "gateway.ai.cloudflare.com" in api_base:
## build base url - assume api base includes resource name
if not api_base.endswith("/"):
api_base += "/"
api_base += f"{model}"
client = AzureOpenAI(
base_url=api_base,
api_version=api_version,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
model = None
# cloudflare ai gateway, needs model=None
else:
client = AzureOpenAI(
api_version=api_version,
azure_endpoint=api_base,
api_key=api_key,
timeout=timeout,
http_client=client_session,
)
# only run this check if it's not cloudflare ai gateway
if model is None and mode != "image_generation":
raise Exception("model is not set")
completion = None
if messages is None:
messages = [{"role": "user", "content": "Hey"}]
try:
completion = client.chat.completions.with_raw_response.create(
model=model, # type: ignore
messages=messages, # type: ignore
)
except Exception as e:
raise e
response = {}
if completion is None or not hasattr(completion, "headers"):
raise Exception("invalid completion response")
if (
completion.headers.get("x-ratelimit-remaining-requests", None) is not None
): # not provided for dall-e requests
response["x-ratelimit-remaining-requests"] = completion.headers[
"x-ratelimit-remaining-requests"
]
if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
if completion.headers.get("x-ms-region", None) is not None:
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
async def ahealth_check(
self,
model: Optional[str],
@@ -963,7 +1038,7 @@ class AzureChatCompletion(BaseLLM):
messages: Optional[list] = None,
input: Optional[list] = None,
prompt: Optional[str] = None,
):
) -> dict:
client_session = litellm.aclient_session or httpx.AsyncClient(
transport=AsyncCustomHTTPTransport(), # handle dall-e-2 calls
)
@@ -1040,4 +1115,8 @@ class AzureChatCompletion(BaseLLM):
response["x-ratelimit-remaining-tokens"] = completion.headers[
"x-ratelimit-remaining-tokens"
]
if completion.headers.get("x-ms-region", None) is not None:
response["x-ms-region"] = completion.headers["x-ms-region"]
return response
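
The new get_headers() helper fires a minimal chat request at the deployment purely to read the response headers, which now include x-ms-region. A minimal usage sketch, assuming a litellm build that includes this commit; the endpoint, key, and deployment name are placeholders, and the call does hit the live deployment:

    import litellm

    headers = litellm.AzureChatCompletion().get_headers(
        model="gpt-35-turbo",                              # deployment name (placeholder)
        api_key="my-azure-key",                            # placeholder
        api_base="https://my-endpoint.openai.azure.com/",  # placeholder
        api_version="2023-07-01-preview",
        timeout=10,
        mode="chat",
    )
    print(headers.get("x-ms-region"))  # e.g. "France Central"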

litellm/main.py

@@ -648,6 +648,8 @@ def completion(
"base_model",
"stream_timeout",
"supports_system_message",
"region_name",
"allowed_model_region",
]
default_params = openai_params + litellm_params
non_default_params = {
@@ -2716,6 +2718,8 @@ def embedding(
"ttl",
"cache",
"no-log",
"region_name",
"allowed_model_region",
]
default_params = openai_params + litellm_params
non_default_params = {
@@ -3589,6 +3593,8 @@ def image_generation(
"caching_groups",
"ttl",
"cache",
"region_name",
"allowed_model_region",
]
default_params = openai_params + litellm_params
non_default_params = {
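
Adding "region_name" and "allowed_model_region" to these litellm_params lists keeps them out of the kwargs forwarded to the provider: completion(), embedding(), and image_generation() all subtract default_params from the incoming kwargs. A simplified sketch of that filtering, with both lists truncated for illustration:

    openai_params = ["temperature", "max_tokens"]                            # truncated
    litellm_params = ["base_model", "region_name", "allowed_model_region"]   # truncated
    default_params = openai_params + litellm_params

    kwargs = {"temperature": 0.2, "allowed_model_region": "eu", "my_custom_param": 1}
    non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
    assert non_default_params == {"my_custom_param": 1}  # region params stay internal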

litellm/proxy/_types.py

@@ -476,16 +476,6 @@ class NewEndUserRequest(LiteLLMBase):
if values.get("max_budget") is not None and values.get("budget_id") is not None:
raise ValueError("Set either 'max_budget' or 'budget_id', not both.")
if (
values.get("allowed_model_region") is not None
and values.get("default_model") is None
) or (
values.get("allowed_model_region") is None
and values.get("default_model") is not None
):
raise ValueError(
"If 'allowed_model_region' is set, then 'default_model' must be set."
)
return values
@@ -867,6 +857,7 @@ class UserAPIKeyAuth(
api_key: Optional[str] = None
user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
allowed_model_region: Optional[Literal["eu"]] = None
@root_validator(pre=True)
def check_api_key(cls, values):
@@ -912,6 +903,8 @@ class LiteLLM_EndUserTable(LiteLLMBase):
blocked: bool
alias: Optional[str] = None
spend: float = 0.0
allowed_model_region: Optional[Literal["eu"]] = None
default_model: Optional[str] = None
litellm_budget_table: Optional[LiteLLM_BudgetTable] = None
@root_validator(pre=True)
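
Both pydantic models now carry the region restriction: LiteLLM_EndUserTable mirrors the DB row, and UserAPIKeyAuth carries it through request auth. A small sketch with illustrative values (the ids and key are placeholders):

    from litellm.proxy._types import LiteLLM_EndUserTable, UserAPIKeyAuth

    end_user = LiteLLM_EndUserTable(
        user_id="customer-123",       # placeholder
        blocked=False,
        allowed_model_region="eu",    # only "eu" is supported today
        default_model="gpt-3.5-turbo",
    )
    auth = UserAPIKeyAuth(api_key="sk-placeholder", allowed_model_region="eu")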

litellm/proxy/auth/auth_checks.py

@@ -208,7 +208,9 @@ async def get_end_user_object(
return None
# check if in cache
cached_user_obj = user_api_key_cache.async_get_cache(key=end_user_id)
cached_user_obj = user_api_key_cache.async_get_cache(
key="end_user_id:{}".format(end_user_id)
)
if cached_user_obj is not None:
if isinstance(cached_user_obj, dict):
return LiteLLM_EndUserTable(**cached_user_obj)
@@ -223,7 +225,14 @@ async def get_end_user_object(
if response is None:
raise Exception
return LiteLLM_EndUserTable(**response.dict())
# save the end-user object to cache
await user_api_key_cache.async_set_cache(
key="end_user_id:{}".format(end_user_id), value=response
)
_response = LiteLLM_EndUserTable(**response.dict())
return _response
except Exception as e: # if end-user not in db
return None
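
get_end_user_object() now namespaces its cache key as "end_user_id:<id>" and writes DB hits back to the cache itself, instead of leaving that to the caller. A minimal sketch of the key convention against litellm's DualCache (ids and values are placeholders):

    import asyncio
    from litellm.caching import DualCache

    async def demo():
        cache = DualCache()
        key = "end_user_id:{}".format("customer-123")
        await cache.async_set_cache(
            key=key, value={"user_id": "customer-123", "allowed_model_region": "eu"}
        )
        print(await cache.async_get_cache(key=key))

    asyncio.run(demo())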

litellm/proxy/proxy_server.py

@@ -472,10 +472,6 @@ async def user_api_key_auth(
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
# save the end-user object to cache
await user_api_key_cache.async_set_cache(
key=end_user_id, value=end_user_object
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
@@ -957,13 +953,16 @@ async def user_api_key_auth(
_end_user_object = None
if "user" in request_data:
_id = "end_user_id:{}".format(request_data["user"])
_end_user_object = await user_api_key_cache.async_get_cache(key=_id)
if _end_user_object is not None:
_end_user_object = LiteLLM_EndUserTable(**_end_user_object)
_end_user_object = await get_end_user_object(
end_user_id=request_data["user"],
prisma_client=prisma_client,
user_api_key_cache=user_api_key_cache,
)
global_proxy_spend = None
if litellm.max_budget > 0: # user set proxy max budget
if (
litellm.max_budget > 0 and prisma_client is not None
): # user set proxy max budget
# check cache
global_proxy_spend = await user_api_key_cache.async_get_cache(
key="{}:spend".format(litellm_proxy_admin_name)
@@ -1016,6 +1015,12 @@ async def user_api_key_auth(
)
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
if _end_user_object is not None:
valid_token_dict["allowed_model_region"] = (
_end_user_object.allowed_model_region
)
"""
asyncio create task to update the user api key cache with the user db table as well
@@ -1040,10 +1045,7 @@ async def user_api_key_auth(
# check if user can access this route
query_params = request.query_params
key = query_params.get("key")
if (
key is not None
and prisma_client.hash_token(token=key) != api_key
):
if key is not None and hash_token(token=key) != api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="user not allowed to access this key's info",
@@ -1096,6 +1098,7 @@ async def user_api_key_auth(
# sso/login, ui/login, /key functions and /user functions
# this will never be allowed to call /chat/completions
token_team = getattr(valid_token, "team_id", None)
if token_team is not None and token_team == "litellm-dashboard":
# this token is only used for managing the ui
allowed_routes = [
@@ -3617,6 +3620,10 @@ async def chat_completion(
**data,
} # add the team-specific configs to the completion call
### END-USER SPECIFIC PARAMS ###
if user_api_key_dict.allowed_model_region is not None:
data["allowed_model_region"] = user_api_key_dict.allowed_model_region
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
if user_temperature:
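
End to end, the region restriction resolved during auth rides on user_api_key_dict and is copied into the request body, which the router later sees as request_kwargs. A sketch of that hand-off, using stand-in objects rather than the real FastAPI flow:

    from litellm.proxy._types import UserAPIKeyAuth

    user_api_key_dict = UserAPIKeyAuth(api_key="sk-placeholder", allowed_model_region="eu")

    data = {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hey!"}]}
    if user_api_key_dict.allowed_model_region is not None:
        data["allowed_model_region"] = user_api_key_dict.allowed_model_region
    # `data` is then passed into the router, where it surfaces as request_kwargs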

litellm/router.py

@@ -32,6 +32,7 @@ from litellm.utils import (
CustomStreamWrapper,
get_utc_datetime,
calculate_max_parallel_requests,
_is_region_eu,
)
import copy
from litellm._logging import verbose_router_logger
@@ -1999,7 +2000,11 @@ class Router:
# user can pass vars directly or they can pass os.environ/AZURE_API_KEY, in which case we will read the env
# we do this here because we init clients for Azure, OpenAI and we need to set the right key
api_key = litellm_params.get("api_key") or default_api_key
if api_key and api_key.startswith("os.environ/"):
if (
api_key
and isinstance(api_key, str)
and api_key.startswith("os.environ/")
):
api_key_env_name = api_key.replace("os.environ/", "")
api_key = litellm.get_secret(api_key_env_name)
litellm_params["api_key"] = api_key
@@ -2023,6 +2028,7 @@ class Router:
if (
is_azure_ai_studio_model == True
and api_base is not None
and isinstance(api_base, str)
and not api_base.endswith("/v1/")
):
# check if it ends with a trailing slash
@@ -2103,13 +2109,14 @@ class Router:
organization = litellm.get_secret(organization_env_name)
litellm_params["organization"] = organization
if "azure" in model_name:
if api_base is None:
if "azure" in model_name and isinstance(api_key, str):
if api_base is None or not isinstance(api_base, str):
raise ValueError(
f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
)
if api_version is None:
api_version = "2023-07-01-preview"
if "gateway.ai.cloudflare.com" in api_base:
if not api_base.endswith("/"):
api_base += "/"
@@ -2532,7 +2539,7 @@ class Router:
self.default_deployment = deployment.to_json(exclude_none=True)
# Azure GPT-Vision Enhancements, users can pass os.environ/
data_sources = deployment.litellm_params.get("dataSources", [])
data_sources = deployment.litellm_params.get("dataSources", []) or []
for data_source in data_sources:
params = data_source.get("parameters", {})
@@ -2549,6 +2556,22 @@ class Router:
# init OpenAI, Azure clients
self.set_client(model=deployment.to_json(exclude_none=True))
# set region (if azure model)
try:
if "azure" in deployment.litellm_params.model:
region = litellm.utils.get_model_region(
litellm_params=deployment.litellm_params, mode=None
)
deployment.litellm_params.region_name = region
except Exception as e:
verbose_router_logger.error(
"Unable to get the region for azure model - {}, {}".format(
deployment.litellm_params.model, str(e)
)
)
pass # [NON-BLOCKING]
return deployment
def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
@@ -2820,14 +2843,17 @@ class Router:
model: str,
healthy_deployments: List,
messages: List[Dict[str, str]],
allowed_model_region: Optional[Literal["eu"]] = None,
):
"""
Filter out model in model group, if:
- model context window < message length
- filter models above rpm limits
- if region given, filter out models not in that region / unknown region
- [TODO] function call and model doesn't support function calling
"""
verbose_router_logger.debug(
f"Starting Pre-call checks for deployments in model={model}"
)
@@ -2878,9 +2904,9 @@ class Router:
except Exception as e:
verbose_router_logger.debug("An error occurs - {}".format(str(e)))
## RPM CHECK ##
_litellm_params = deployment.get("litellm_params", {})
model_id = deployment.get("model_info", {}).get("id", "")
## RPM CHECK ##
### get local router cache ###
current_request_cache_local = (
self.cache.get_cache(key=model_id, local_only=True) or 0
@@ -2908,6 +2934,28 @@ class Router:
_rate_limit_error = True
continue
## REGION CHECK ##
if allowed_model_region is not None:
if _litellm_params.get("region_name") is not None and isinstance(
_litellm_params["region_name"], str
):
# check if in allowed_model_region
if (
_is_region_eu(model_region=_litellm_params["region_name"])
== False
):
invalid_model_indices.append(idx)
continue
else:
verbose_router_logger.debug(
"Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
model_id, allowed_model_region
)
)
# filter out since region unknown, and user wants to filter for specific region
invalid_model_indices.append(idx)
continue
if len(invalid_model_indices) == len(_returned_deployments):
"""
- no healthy deployments available b/c context window checks or rate limit error
@@ -3047,10 +3095,31 @@ class Router:
# filter pre-call checks
if self.enable_pre_call_checks and messages is not None:
healthy_deployments = self._pre_call_checks(
model=model, healthy_deployments=healthy_deployments, messages=messages
_allowed_model_region = (
request_kwargs.get("allowed_model_region")
if request_kwargs is not None
else None
)
if _allowed_model_region == "eu":
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
allowed_model_region=_allowed_model_region,
)
else:
verbose_router_logger.debug(
"Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
_allowed_model_region
)
)
healthy_deployments = self._pre_call_checks(
model=model,
healthy_deployments=healthy_deployments,
messages=messages,
)
if len(healthy_deployments) == 0:
raise ValueError(
f"{RouterErrors.no_deployments_available.value}, passed model={model}"

litellm/types/router.py

@@ -123,6 +123,8 @@ class GenericLiteLLMParams(BaseModel):
)
max_retries: Optional[int] = None
organization: Optional[str] = None # for openai orgs
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None
## VERTEX AI ##
vertex_project: Optional[str] = None
vertex_location: Optional[str] = None
@@ -150,6 +152,8 @@ class GenericLiteLLMParams(BaseModel):
None # timeout when making stream=True calls, if str, pass in as os.environ/
),
organization: Optional[str] = None, # for openai orgs
## UNIFIED PROJECT/REGION ##
region_name: Optional[str] = None,
## VERTEX AI ##
vertex_project: Optional[str] = None,
vertex_location: Optional[str] = None,

litellm/utils.py

@@ -5844,6 +5844,40 @@ def calculate_max_parallel_requests(
return None
def _is_region_eu(model_region: str) -> bool:
EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
for region in EU_Regions:
if "europe" in model_region.lower():
return True
return False
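
_is_region_eu() is a substring match against the region string Azure reports in x-ms-region. Illustrative checks, assuming the loop matches each entry of EU_Regions as written above (example values, not an exhaustive list of Azure regions):

    from litellm.utils import _is_region_eu

    assert _is_region_eu("France Central") is True
    assert _is_region_eu("Sweden Central") is True
    assert _is_region_eu("East US") is False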
def get_model_region(
litellm_params: LiteLLM_Params, mode: Optional[str]
) -> Optional[str]:
"""
Pass the litellm params for an azure model, and get back the region
"""
if (
"azure" in litellm_params.model
and isinstance(litellm_params.api_key, str)
and isinstance(litellm_params.api_base, str)
):
_model = litellm_params.model.replace("azure/", "")
response: dict = litellm.AzureChatCompletion().get_headers(
model=_model,
api_key=litellm_params.api_key,
api_base=litellm_params.api_base,
api_version=litellm_params.api_version or "2023-07-01-preview",
timeout=10,
mode=mode or "chat",
)
region: Optional[str] = response.get("x-ms-region", None)
return region
return None
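
This is the helper Router.set_model_list() calls at startup to stamp region_name onto each azure deployment. Note that it issues a small real chat request via get_headers(), so it needs a reachable deployment; the params below are placeholders:

    import litellm
    from litellm.types.router import LiteLLM_Params

    params = LiteLLM_Params(
        model="azure/gpt-35-turbo",
        api_key="placeholder-key",                                # placeholder
        api_base="https://my-endpoint-europe.openai.azure.com/",  # placeholder
    )
    region = litellm.utils.get_model_region(litellm_params=params, mode=None)
    print(region)  # e.g. "France Central", or None if the lookup failed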
def get_api_base(model: str, optional_params: dict) -> Optional[str]:
"""
Returns the api base used for calling the model.
@@ -9423,7 +9457,9 @@ def get_secret(
else:
secret = os.environ.get(secret_name)
try:
secret_value_as_bool = ast.literal_eval(secret) if secret is not None else None
secret_value_as_bool = (
ast.literal_eval(secret) if secret is not None else None
)
if isinstance(secret_value_as_bool, bool):
return secret_value_as_bool
else:

litellm/proxy/_super_secret_config.yaml

@@ -1,4 +1,9 @@
model_list:
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/gpt-35-turbo
api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
api_key: os.environ/AZURE_EUROPE_API_KEY
- model_name: gpt-3.5-turbo
litellm_params:
model: azure/chatgpt-v-2

tests/test_end_users.py (new file, 173 lines)

@@ -0,0 +1,173 @@
# What is this?
## Unit tests for the /end_users/* endpoints
import pytest
import asyncio
import aiohttp
import time
import uuid
from openai import AsyncOpenAI
from typing import Optional
"""
- `/end_user/new`
- `/end_user/info`
"""
async def chat_completion_with_headers(session, key, model="gpt-4"):
url = "http://0.0.0.0:4000/chat/completions"
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
data = {
"model": model,
"messages": [
{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "Hello!"},
],
}
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(response_text)
print()
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
response_header_check(
response
) # calling the function to check response headers
raw_headers = response.raw_headers
raw_headers_json = {}
for (
item
) in (
response.raw_headers
): # ((b'date', b'Fri, 19 Apr 2024 21:17:29 GMT'), (), )
raw_headers_json[item[0].decode("utf-8")] = item[1].decode("utf-8")
return raw_headers_json
async def generate_key(
session,
i,
budget=None,
budget_duration=None,
models=["azure-models", "gpt-4", "dall-e-3"],
max_parallel_requests: Optional[int] = None,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
calling_key="sk-1234",
):
url = "http://0.0.0.0:4000/key/generate"
headers = {
"Authorization": f"Bearer {calling_key}",
"Content-Type": "application/json",
}
data = {
"models": models,
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
"duration": None,
"max_budget": budget,
"budget_duration": budget_duration,
"max_parallel_requests": max_parallel_requests,
"user_id": user_id,
"team_id": team_id,
}
print(f"data: {data}")
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Response {i} (Status code: {status}):")
print(response_text)
print()
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
async def new_end_user(
session, i, user_id=str(uuid.uuid4()), model_region=None, default_model=None
):
url = "http://0.0.0.0:4000/end_user/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"user_id": user_id,
"allowed_model_region": model_region,
"default_model": default_model,
}
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Response {i} (Status code: {status}):")
print(response_text)
print()
if status != 200:
raise Exception(f"Request {i} did not return a 200 status code: {status}")
return await response.json()
@pytest.mark.asyncio
async def test_end_user_new():
"""
Make 10 parallel calls to /end_user/new. Assert all worked.
"""
async with aiohttp.ClientSession() as session:
tasks = [new_end_user(session, i, str(uuid.uuid4())) for i in range(1, 11)]
await asyncio.gather(*tasks)
@pytest.mark.asyncio
async def test_end_user_specific_region():
"""
- Specify region user can make calls in
- Make a generic call
- assert returned api base is for model in region
Repeat 3 times
"""
key: str = ""
## CREATE USER ##
async with aiohttp.ClientSession() as session:
end_user_obj = await new_end_user(
session=session,
i=0,
user_id=str(uuid.uuid4()),
model_region="eu",
)
## MAKE CALL ##
key_gen = await generate_key(session=session, i=0, models=["gpt-3.5-turbo"])
key = key_gen["key"]
for _ in range(3):
client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
print("SENDING USER PARAM - {}".format(end_user_obj["user_id"]))
result = await client.chat.completions.with_raw_response.create(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey!"}],
user=end_user_obj["user_id"],
)
assert (
result.headers.get("x-litellm-model-api-base")
== "https://my-endpoint-europe-berri-992.openai.azure.com/"
)