From 3d18897d69e017b1ea2a96f0b6098db5359b3500 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 8 May 2024 22:10:17 -0700
Subject: [PATCH] feat(router.py): enable filtering model group by
 'allowed_model_region'

---
 .gitignore                        |   1 +
 litellm/llms/azure.py             |  83 +++++++++++++-
 litellm/main.py                   |   6 ++
 litellm/proxy/_types.py           |  13 +--
 litellm/proxy/auth/auth_checks.py |  13 ++-
 litellm/proxy/proxy_server.py     |  33 +++---
 litellm/router.py                 |  83 ++++++++++++--
 litellm/types/router.py           |   4 +
 litellm/utils.py                  |  38 ++++++-
 proxy_server_config.yaml          |   5 +
 tests/test_end_users.py           | 173 ++++++++++++++++++++++++++++++
 11 files changed, 417 insertions(+), 35 deletions(-)
 create mode 100644 tests/test_end_users.py

diff --git a/.gitignore b/.gitignore
index 1f827e463..b75a92309 100644
--- a/.gitignore
+++ b/.gitignore
@@ -55,3 +55,4 @@ litellm/proxy/_super_secret_config.yaml
 litellm/proxy/_super_secret_config.yaml
 litellm/proxy/myenv/bin/activate
 litellm/proxy/myenv/bin/Activate.ps1
+myenv/*
\ No newline at end of file
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index e7af9d43b..c3e024525 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -1,4 +1,4 @@
-from typing import Optional, Union, Any
+from typing import Optional, Union, Any, Literal
 import types, requests
 from .base import BaseLLM
 from litellm.utils import (
@@ -952,6 +952,81 @@ class AzureChatCompletion(BaseLLM):
             )
             raise e

+    def get_headers(
+        self,
+        model: Optional[str],
+        api_key: str,
+        api_base: str,
+        api_version: str,
+        timeout: float,
+        mode: str,
+        messages: Optional[list] = None,
+        input: Optional[list] = None,
+        prompt: Optional[str] = None,
+    ) -> dict:
+        client_session = litellm.client_session or httpx.Client(
+            transport=CustomHTTPTransport(),  # handle dall-e-2 calls
+        )
+        if "gateway.ai.cloudflare.com" in api_base:
+            ## build base url - assume api base includes resource name
+            if not api_base.endswith("/"):
+                api_base += "/"
+            api_base += f"{model}"
+            client = AzureOpenAI(
+                base_url=api_base,
+                api_version=api_version,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+            model = None
+            # cloudflare ai gateway, needs model=None
+        else:
+            client = AzureOpenAI(
+                api_version=api_version,
+                azure_endpoint=api_base,
+                api_key=api_key,
+                timeout=timeout,
+                http_client=client_session,
+            )
+
+            # only run this check if it's not cloudflare ai gateway
+            if model is None and mode != "image_generation":
+                raise Exception("model is not set")
+
+        completion = None
+
+        if messages is None:
+            messages = [{"role": "user", "content": "Hey"}]
+        try:
+            completion = client.chat.completions.with_raw_response.create(
+                model=model,  # type: ignore
+                messages=messages,  # type: ignore
+            )
+        except Exception as e:
+            raise e
+        response = {}
+
+        if completion is None or not hasattr(completion, "headers"):
+            raise Exception("invalid completion response")
+
+        if (
+            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
+        ):  # not provided for dall-e requests
+            response["x-ratelimit-remaining-requests"] = completion.headers[
+                "x-ratelimit-remaining-requests"
+            ]
+
+        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
+            response["x-ratelimit-remaining-tokens"] = completion.headers[
+                "x-ratelimit-remaining-tokens"
+            ]
+
+        if completion.headers.get("x-ms-region", None) is not None:
+            response["x-ms-region"] = completion.headers["x-ms-region"]
+
+        return response
+
     async def ahealth_check(
         self,
         model: Optional[str],
@@ -963,7 +1038,7 @@ class AzureChatCompletion(BaseLLM):
         messages: Optional[list] = None,
         input: Optional[list] = None,
         prompt: Optional[str] = None,
-    ):
+    ) -> dict:
         client_session = litellm.aclient_session or httpx.AsyncClient(
             transport=AsyncCustomHTTPTransport(),  # handle dall-e-2 calls
         )
@@ -1040,4 +1115,8 @@ class AzureChatCompletion(BaseLLM):
             response["x-ratelimit-remaining-tokens"] = completion.headers[
                 "x-ratelimit-remaining-tokens"
             ]
+
+        if completion.headers.get("x-ms-region", None) is not None:
+            response["x-ms-region"] = completion.headers["x-ms-region"]
+
         return response
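Illustrative usage (not part of the diff): the new synchronous get_headers helper above is what the router later relies on to discover which Azure region a deployment lives in. A minimal sketch of calling it directly, assuming a reachable Azure OpenAI deployment; the deployment name, key, and endpoint below are placeholders:

    import litellm

    headers = litellm.AzureChatCompletion().get_headers(
        model="gpt-35-turbo",  # Azure deployment name, without the "azure/" prefix
        api_key="my-azure-api-key",  # placeholder credential
        api_base="https://my-endpoint-europe-berri-992.openai.azure.com/",
        api_version="2023-07-01-preview",
        timeout=10,
        mode="chat",
    )
    # get_headers issues a one-message chat request and returns selected response headers,
    # e.g. {"x-ratelimit-remaining-requests": "...", "x-ms-region": "France Central"}
    print(headers.get("x-ms-region"))
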
diff --git a/litellm/main.py b/litellm/main.py
index bff9886ac..99e5ec224 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -648,6 +648,8 @@ def completion(
         "base_model",
         "stream_timeout",
         "supports_system_message",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -2716,6 +2718,8 @@ def embedding(
         "ttl",
         "cache",
         "no-log",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
@@ -3589,6 +3593,8 @@ def image_generation(
         "caching_groups",
         "ttl",
         "cache",
+        "region_name",
+        "allowed_model_region",
     ]
     default_params = openai_params + litellm_params
     non_default_params = {
diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py
index 37e00e27e..e775070b4 100644
--- a/litellm/proxy/_types.py
+++ b/litellm/proxy/_types.py
@@ -476,16 +476,6 @@ class NewEndUserRequest(LiteLLMBase):
         if values.get("max_budget") is not None and values.get("budget_id") is not None:
             raise ValueError("Set either 'max_budget' or 'budget_id', not both.")

-        if (
-            values.get("allowed_model_region") is not None
-            and values.get("default_model") is None
-        ) or (
-            values.get("allowed_model_region") is None
-            and values.get("default_model") is not None
-        ):
-            raise ValueError(
-                "If 'allowed_model_region' is set, then 'default_model' must be set."
-            )

         return values

@@ -867,6 +857,7 @@ class UserAPIKeyAuth(
     api_key: Optional[str] = None
     user_role: Optional[Literal["proxy_admin", "app_owner", "app_user"]] = None
+    allowed_model_region: Optional[Literal["eu"]] = None

     @root_validator(pre=True)
     def check_api_key(cls, values):
@@ -912,6 +903,8 @@ class LiteLLM_EndUserTable(LiteLLMBase):
     blocked: bool
     alias: Optional[str] = None
     spend: float = 0.0
+    allowed_model_region: Optional[Literal["eu"]] = None
+    default_model: Optional[str] = None
     litellm_budget_table: Optional[LiteLLM_BudgetTable] = None

     @root_validator(pre=True)
diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index a393ec90a..920de3cc8 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -208,7 +208,9 @@ async def get_end_user_object(
         return None

     # check if in cache
-    cached_user_obj = user_api_key_cache.async_get_cache(key=end_user_id)
+    cached_user_obj = user_api_key_cache.async_get_cache(
+        key="end_user_id:{}".format(end_user_id)
+    )
     if cached_user_obj is not None:
         if isinstance(cached_user_obj, dict):
             return LiteLLM_EndUserTable(**cached_user_obj)
@@ -223,7 +225,14 @@ async def get_end_user_object(
         if response is None:
             raise Exception

-        return LiteLLM_EndUserTable(**response.dict())
+        # save the end-user object to cache
+        await user_api_key_cache.async_set_cache(
+            key="end_user_id:{}".format(end_user_id), value=response
+        )
+
+        _response = LiteLLM_EndUserTable(**response.dict())
+
+        return _response
     except Exception as e:
         # if end-user not in db
         return None
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 009ea279f..7f06ba6a2 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -472,10 +472,6 @@ async def user_api_key_auth(
                     prisma_client=prisma_client,
                     user_api_key_cache=user_api_key_cache,
                 )
-                # save the end-user object to cache
-                await user_api_key_cache.async_set_cache(
-                    key=end_user_id, value=end_user_object
-                )

                 global_proxy_spend = None
                 if litellm.max_budget > 0:  # user set proxy max budget
@@ -957,13 +953,16 @@ async def user_api_key_auth(
                 _end_user_object = None
                 if "user" in request_data:
-                    _id = "end_user_id:{}".format(request_data["user"])
-                    _end_user_object = await user_api_key_cache.async_get_cache(key=_id)
-                    if _end_user_object is not None:
-                        _end_user_object = LiteLLM_EndUserTable(**_end_user_object)
+                    _end_user_object = await get_end_user_object(
+                        end_user_id=request_data["user"],
+                        prisma_client=prisma_client,
+                        user_api_key_cache=user_api_key_cache,
+                    )

                 global_proxy_spend = None
-                if litellm.max_budget > 0:  # user set proxy max budget
+                if (
+                    litellm.max_budget > 0 and prisma_client is not None
+                ):  # user set proxy max budget
                     # check cache
                     global_proxy_spend = await user_api_key_cache.async_get_cache(
                         key="{}:spend".format(litellm_proxy_admin_name)
                     )
@@ -1016,6 +1015,12 @@ async def user_api_key_auth(
             )
             valid_token_dict = _get_pydantic_json_dict(valid_token)
             valid_token_dict.pop("token", None)
+
+            if _end_user_object is not None:
+                valid_token_dict["allowed_model_region"] = (
+                    _end_user_object.allowed_model_region
+                )
+
             """
             asyncio create task to update the user api key cache with the user db table as well
@@ -1040,10 +1045,7 @@ async def user_api_key_auth(
                 # check if user can access this route
                 query_params = request.query_params
                 key = query_params.get("key")
-                if (
-                    key is not None
-                    and prisma_client.hash_token(token=key) != api_key
-                ):
+                if key is not None and hash_token(token=key) != api_key:
                     raise HTTPException(
                         status_code=status.HTTP_403_FORBIDDEN,
                         detail="user not allowed to access this key's info",
@@ -1096,6 +1098,7 @@ async def user_api_key_auth(
             # sso/login, ui/login, /key functions and /user functions
             # this will never be allowed to call /chat/completions
             token_team = getattr(valid_token, "team_id", None)
+
             if token_team is not None and token_team == "litellm-dashboard":
                 # this token is only used for managing the ui
                 allowed_routes = [
@@ -3617,6 +3620,10 @@ async def chat_completion(
             **data,
         }  # add the team-specific configs to the completion call

+        ### END-USER SPECIFIC PARAMS ###
+        if user_api_key_dict.allowed_model_region is not None:
+            data["allowed_model_region"] = user_api_key_dict.allowed_model_region
+
         global user_temperature, user_request_timeout, user_max_tokens, user_api_base
         # override with user settings, these are params passed via cli
         if user_temperature:
diff --git a/litellm/router.py b/litellm/router.py
index f1ac0135a..68f49a0a0 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -32,6 +32,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     get_utc_datetime,
     calculate_max_parallel_requests,
+    _is_region_eu,
 )
 import copy
 from litellm._logging import verbose_router_logger
@@ -1999,7 +2000,11 @@ class Router:
         # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
         # we do this here because we init clients for Azure, OpenAI and we need to set the right key
         api_key = litellm_params.get("api_key") or default_api_key
-        if api_key and api_key.startswith("os.environ/"):
+        if (
+            api_key
+            and isinstance(api_key, str)
+            and api_key.startswith("os.environ/")
+        ):
             api_key_env_name = api_key.replace("os.environ/", "")
             api_key = litellm.get_secret(api_key_env_name)
             litellm_params["api_key"] = api_key
@@ -2023,6 +2028,7 @@ class Router:
         if (
             is_azure_ai_studio_model == True
             and api_base is not None
+            and isinstance(api_base, str)
             and not api_base.endswith("/v1/")
         ):
             # check if it ends with a trailing slash
@@ -2103,13 +2109,14 @@ class Router:
             organization = litellm.get_secret(organization_env_name)
             litellm_params["organization"] = organization

-        if "azure" in model_name:
-            if api_base is None:
+        if "azure" in model_name and isinstance(api_key, str):
+            if api_base is None or not isinstance(api_base, str):
                 raise ValueError(
                     f"api_base is required for Azure OpenAI. Set it on your config. Model - {model}"
                 )
             if api_version is None:
                 api_version = "2023-07-01-preview"
+
             if "gateway.ai.cloudflare.com" in api_base:
                 if not api_base.endswith("/"):
                     api_base += "/"
@@ -2532,7 +2539,7 @@ class Router:
             self.default_deployment = deployment.to_json(exclude_none=True)

         # Azure GPT-Vision Enhancements, users can pass os.environ/
-        data_sources = deployment.litellm_params.get("dataSources", [])
+        data_sources = deployment.litellm_params.get("dataSources", []) or []
         for data_source in data_sources:
             params = data_source.get("parameters", {})

@@ -2549,6 +2556,22 @@ class Router:
         # init OpenAI, Azure clients
         self.set_client(model=deployment.to_json(exclude_none=True))

+        # set region (if azure model)
+        try:
+            if "azure" in deployment.litellm_params.model:
+                region = litellm.utils.get_model_region(
+                    litellm_params=deployment.litellm_params, mode=None
+                )
+
+                deployment.litellm_params.region_name = region
+        except Exception as e:
+            verbose_router_logger.error(
+                "Unable to get the region for azure model - {}, {}".format(
+                    deployment.litellm_params.model, str(e)
+                )
+            )
+            pass  # [NON-BLOCKING]
+
         return deployment

     def add_deployment(self, deployment: Deployment) -> Optional[Deployment]:
@@ -2820,14 +2843,17 @@ class Router:
         model: str,
         healthy_deployments: List,
         messages: List[Dict[str, str]],
+        allowed_model_region: Optional[Literal["eu"]] = None,
     ):
         """
         Filter out model in model group, if:

         - model context window < message length
         - filter models above rpm limits
+        - if region given, filter out models not in that region / unknown region
         - [TODO] function call and model doesn't support function calling
         """
+
         verbose_router_logger.debug(
             f"Starting Pre-call checks for deployments in model={model}"
         )
@@ -2878,9 +2904,9 @@ class Router:
             except Exception as e:
                 verbose_router_logger.debug("An error occurs - {}".format(str(e)))
-            ## RPM CHECK ##
             _litellm_params = deployment.get("litellm_params", {})
             model_id = deployment.get("model_info", {}).get("id", "")
+            ## RPM CHECK ##
             ### get local router cache ###
             current_request_cache_local = (
                 self.cache.get_cache(key=model_id, local_only=True) or 0
             )
@@ -2908,6 +2934,28 @@ class Router:
                     _rate_limit_error = True
                     continue

+            ## REGION CHECK ##
+            if allowed_model_region is not None:
+                if _litellm_params.get("region_name") is not None and isinstance(
+                    _litellm_params["region_name"], str
+                ):
+                    # check if in allowed_model_region
+                    if (
+                        _is_region_eu(model_region=_litellm_params["region_name"])
+                        == False
+                    ):
+                        invalid_model_indices.append(idx)
+                        continue
+                else:
+                    verbose_router_logger.debug(
+                        "Filtering out model - {}, as model_region=None, and allowed_model_region={}".format(
+                            model_id, allowed_model_region
+                        )
+                    )
+                    # filter out since region unknown, and user wants to filter for specific region
+                    invalid_model_indices.append(idx)
+                    continue
+
         if len(invalid_model_indices) == len(_returned_deployments):
             """
             - no healthy deployments available b/c context window checks or rate limit error
@@ -3047,10 +3095,31 @@ class Router:

         # filter pre-call checks
         if self.enable_pre_call_checks and messages is not None:
-            healthy_deployments = self._pre_call_checks(
-                model=model, healthy_deployments=healthy_deployments, messages=messages
+            _allowed_model_region = (
+                request_kwargs.get("allowed_model_region")
+                if request_kwargs is not None
+                else None
             )
+            if _allowed_model_region == "eu":
+                healthy_deployments = self._pre_call_checks(
+                    model=model,
+                    healthy_deployments=healthy_deployments,
+                    messages=messages,
+                    allowed_model_region=_allowed_model_region,
+                )
+            else:
+                verbose_router_logger.debug(
+                    "Ignoring given 'allowed_model_region'={}. Only 'eu' is allowed".format(
+                        _allowed_model_region
+                    )
+                )
+                healthy_deployments = self._pre_call_checks(
+                    model=model,
+                    healthy_deployments=healthy_deployments,
+                    messages=messages,
+                )
+
         if len(healthy_deployments) == 0:
             raise ValueError(
                 f"{RouterErrors.no_deployments_available.value}, passed model={model}"
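Illustrative usage (not part of the diff): a rough sketch of how the region filter added to _pre_call_checks could be exercised at the router level. The deployment names, endpoints, and env vars are placeholders; region_name values are assumed to be Azure region strings such as "France Central" (which _is_region_eu matches), not the literal "eu". With enable_pre_call_checks=True and allowed_model_region="eu" passed through the request kwargs, non-EU and unknown-region deployments should be dropped from the candidate list before routing:

    import asyncio
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/gpt-35-turbo",
                    "api_base": "https://my-endpoint-europe-berri-992.openai.azure.com/",
                    "api_key": "os.environ/AZURE_EUROPE_API_KEY",
                    "region_name": "France Central",  # may also be auto-detected for Azure via get_headers()
                },
            },
            {
                "model_name": "gpt-3.5-turbo",
                "litellm_params": {
                    "model": "azure/chatgpt-v-2",
                    "api_base": "https://my-endpoint-us.openai.azure.com/",  # placeholder
                    "api_key": "os.environ/AZURE_API_KEY",
                    "region_name": "East US",
                },
            },
        ],
        enable_pre_call_checks=True,  # region filtering only runs inside _pre_call_checks
    )

    async def main():
        # the proxy forwards the end user's allowed_model_region in the request kwargs
        response = await router.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey!"}],
            allowed_model_region="eu",
        )
        print(response)

    asyncio.run(main())
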
diff --git a/litellm/types/router.py b/litellm/types/router.py
index ec7decf34..dbf36f17c 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -123,6 +123,8 @@ class GenericLiteLLMParams(BaseModel):
     )
     max_retries: Optional[int] = None
     organization: Optional[str] = None  # for openai orgs
+    ## UNIFIED PROJECT/REGION ##
+    region_name: Optional[str] = None
     ## VERTEX AI ##
     vertex_project: Optional[str] = None
     vertex_location: Optional[str] = None
@@ -150,6 +152,8 @@ class GenericLiteLLMParams(BaseModel):
             None  # timeout when making stream=True calls, if str, pass in as os.environ/
         ),
         organization: Optional[str] = None,  # for openai orgs
+        ## UNIFIED PROJECT/REGION ##
+        region_name: Optional[str] = None,
         ## VERTEX AI ##
         vertex_project: Optional[str] = None,
         vertex_location: Optional[str] = None,
diff --git a/litellm/utils.py b/litellm/utils.py
index 88e395233..f0d6805ab 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5844,6 +5844,40 @@ def calculate_max_parallel_requests(
     return None


+def _is_region_eu(model_region: str) -> bool:
+    EU_Regions = ["europe", "sweden", "switzerland", "france", "uk"]
+    for region in EU_Regions:
+        if region in model_region.lower():
+            return True
+    return False
+
+
+def get_model_region(
+    litellm_params: LiteLLM_Params, mode: Optional[str]
+) -> Optional[str]:
+    """
+    Pass the litellm params for an azure model, and get back the region
+    """
+    if (
+        "azure" in litellm_params.model
+        and isinstance(litellm_params.api_key, str)
+        and isinstance(litellm_params.api_base, str)
+    ):
+        _model = litellm_params.model.replace("azure/", "")
+        response: dict = litellm.AzureChatCompletion().get_headers(
+            model=_model,
+            api_key=litellm_params.api_key,
+            api_base=litellm_params.api_base,
+            api_version=litellm_params.api_version or "2023-07-01-preview",
+            timeout=10,
+            mode=mode or "chat",
+        )
+
+        region: Optional[str] = response.get("x-ms-region", None)
+        return region
+    return None
+
+
 def get_api_base(model: str, optional_params: dict) -> Optional[str]:
     """
     Returns the api base used for calling the model.
@@ -9423,7 +9457,9 @@ def get_secret(
     else:
         secret = os.environ.get(secret_name)
     try:
-        secret_value_as_bool = ast.literal_eval(secret) if secret is not None else None
+        secret_value_as_bool = (
+            ast.literal_eval(secret) if secret is not None else None
+        )
         if isinstance(secret_value_as_bool, bool):
             return secret_value_as_bool
         else:
diff --git a/proxy_server_config.yaml b/proxy_server_config.yaml
index fe58d64c6..046ed7e95 100644
--- a/proxy_server_config.yaml
+++ b/proxy_server_config.yaml
@@ -1,4 +1,9 @@
 model_list:
+  - model_name: gpt-3.5-turbo
+    litellm_params:
+      model: azure/gpt-35-turbo
+      api_base: https://my-endpoint-europe-berri-992.openai.azure.com/
+      api_key: os.environ/AZURE_EUROPE_API_KEY
   - model_name: gpt-3.5-turbo
     litellm_params:
       model: azure/chatgpt-v-2
diff --git a/tests/test_end_users.py b/tests/test_end_users.py
new file mode 100644
index 000000000..96cfc2bde
--- /dev/null
+++ b/tests/test_end_users.py
@@ -0,0 +1,173 @@
+# What is this?
+## Unit tests for the /end_users/* endpoints
+import pytest
+import asyncio
+import aiohttp
+import time
+import uuid
+from openai import AsyncOpenAI
+from typing import Optional
+
+"""
+- `/end_user/new`
+- `/end_user/info`
+"""
+
+
+async def chat_completion_with_headers(session, key, model="gpt-4"):
+    url = "http://0.0.0.0:4000/chat/completions"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "model": model,
+        "messages": [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": "Hello!"},
+        ],
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+
+        response_header_check(
+            response
+        )  # calling the function to check response headers
+
+        raw_headers = response.raw_headers
+        raw_headers_json = {}
+
+        for (
+            item
+        ) in (
+            response.raw_headers
+        ):  # ((b'date', b'Fri, 19 Apr 2024 21:17:29 GMT'), (), )
+            raw_headers_json[item[0].decode("utf-8")] = item[1].decode("utf-8")
+
+        return raw_headers_json
+
+
+async def generate_key(
+    session,
+    i,
+    budget=None,
+    budget_duration=None,
+    models=["azure-models", "gpt-4", "dall-e-3"],
+    max_parallel_requests: Optional[int] = None,
+    user_id: Optional[str] = None,
+    team_id: Optional[str] = None,
+    calling_key="sk-1234",
+):
+    url = "http://0.0.0.0:4000/key/generate"
+    headers = {
+        "Authorization": f"Bearer {calling_key}",
+        "Content-Type": "application/json",
+    }
+    data = {
+        "models": models,
+        "aliases": {"mistral-7b": "gpt-3.5-turbo"},
+        "duration": None,
+        "max_budget": budget,
+        "budget_duration": budget_duration,
+        "max_parallel_requests": max_parallel_requests,
+        "user_id": user_id,
+        "team_id": team_id,
+    }
+
+    print(f"data: {data}")
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(f"Response {i} (Status code: {status}):")
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request {i} did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
+async def new_end_user(
+    session, i, user_id=str(uuid.uuid4()), model_region=None, default_model=None
+):
+    url = "http://0.0.0.0:4000/end_user/new"
+    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
+    data = {
+        "user_id": user_id,
+        "allowed_model_region": model_region,
+        "default_model": default_model,
+    }
+
+    async with session.post(url, headers=headers, json=data) as response:
+        status = response.status
+        response_text = await response.text()
+
+        print(f"Response {i} (Status code: {status}):")
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request {i} did not return a 200 status code: {status}")
+
+        return await response.json()
+
+
+@pytest.mark.asyncio
+async def test_end_user_new():
+    """
+    Make 10 parallel calls to /end_user/new. Assert all worked.
+    """
+    async with aiohttp.ClientSession() as session:
+        tasks = [new_end_user(session, i, str(uuid.uuid4())) for i in range(1, 11)]
+        await asyncio.gather(*tasks)
+
+
+@pytest.mark.asyncio
+async def test_end_user_specific_region():
+    """
+    - Specify region user can make calls in
+    - Make a generic call
+    - assert returned api base is for model in region
+
+    Repeat 3 times
+    """
+    key: str = ""
+    ## CREATE USER ##
+    async with aiohttp.ClientSession() as session:
+        end_user_obj = await new_end_user(
+            session=session,
+            i=0,
+            user_id=str(uuid.uuid4()),
+            model_region="eu",
+        )
+
+        ## MAKE CALL ##
+        key_gen = await generate_key(session=session, i=0, models=["gpt-3.5-turbo"])
+
+        key = key_gen["key"]
+
+    for _ in range(3):
+        client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
+
+        print("SENDING USER PARAM - {}".format(end_user_obj["user_id"]))
+        result = await client.chat.completions.with_raw_response.create(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey!"}],
+            user=end_user_obj["user_id"],
+        )
+
+        assert (
+            result.headers.get("x-litellm-model-api-base")
+            == "https://my-endpoint-europe-berri-992.openai.azure.com/"
+        )