From 08db691decf5480b2e2b2077b347d81ea6c2f0eb Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 17 Aug 2024 16:45:53 -0700
Subject: [PATCH 1/2] use model access groups for teams

---
 litellm/proxy/auth/auth_checks.py | 49 +++++++++++++++++++++++++++++++++++++++++++++------
 litellm/proxy/proxy_config.yaml   |  4 ++++
 litellm/router.py                 | 17 +++++++++++++++++
 3 files changed, 64 insertions(+), 6 deletions(-)

diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py
index 366eb1fb2..cf5065c2e 100644
--- a/litellm/proxy/auth/auth_checks.py
+++ b/litellm/proxy/auth/auth_checks.py
@@ -10,7 +10,7 @@ Run checks for:
 """
 import time
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, List, Literal, Optional
 
 import litellm
 from litellm._logging import verbose_proxy_logger
@@ -77,6 +77,11 @@ def common_checks(
         if "all-proxy-models" in team_object.models:
             # this means the team has access to all models on the proxy
             pass
+        # check if the team model is an access_group
+        elif model_in_access_group(_model, team_object.models) is True:
+            pass
+        elif _model and "*" in _model:
+            pass
         else:
             raise Exception(
                 f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
@@ -327,6 +332,39 @@ async def get_end_user_object(
     return None
 
 
+def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
+    from collections import defaultdict
+
+    from litellm.proxy.proxy_server import llm_router
+
+    if team_models is None:
+        return True
+    if model in team_models:
+        return True
+
+    access_groups = defaultdict(list)
+    if llm_router:
+        access_groups = llm_router.get_model_access_groups()
+
+    models_in_current_access_groups = []
+    if len(access_groups) > 0:  # check if team models contain any access groups
+        for idx, m in enumerate(
+            team_models
+        ):  # loop team models; if any entry is an access group, expand it
+            if m in access_groups:
+                # it is an access group, so collect the models it contains
+                models_in_group = access_groups[m]
+                models_in_current_access_groups.extend(models_in_group)
+
+    # Filter out entries that are access groups, keep concrete model names
+    filtered_models = [m for m in team_models if m not in access_groups]
+    filtered_models += models_in_current_access_groups
+
+    if model in filtered_models:
+        return True
+    return False
+
+
 @log_to_opentelemetry
 async def get_user_object(
     user_id: str,
@@ -543,12 +581,11 @@
         )
     from collections import defaultdict
 
+    from litellm.proxy.proxy_server import llm_router
+
     access_groups = defaultdict(list)
-    if llm_model_list is not None:
-        for m in llm_model_list:
-            for group in m.get("model_info", {}).get("access_groups", []):
-                model_name = m["model_name"]
-                access_groups[group].append(model_name)
+    if llm_router:
+        access_groups = llm_router.get_model_access_groups()
 
     models_in_current_access_groups = []
     if len(access_groups) > 0:  # check if token contains any model access groups
diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml
index 336cd8458..e08be88aa 100644
--- a/litellm/proxy/proxy_config.yaml
+++ b/litellm/proxy/proxy_config.yaml
@@ -4,10 +4,14 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+    model_info:
+      access_groups: ["beta-models"]
   - model_name: fireworks-llama-v3-70b-instruct
     litellm_params:
       model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
       api_key: "os.environ/FIREWORKS"
+    model_info:
+      access_groups: ["beta-models"]
   - model_name: "*"
     litellm_params:
       model: "*"
diff --git a/litellm/router.py b/litellm/router.py
index 7bc8acae4..2d7bd517b 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -421,6 +421,7 @@ class Router:
             routing_strategy=routing_strategy,
             routing_strategy_args=routing_strategy_args,
         )
+        self.access_groups = None
         ## USAGE TRACKING ##
         if isinstance(litellm._async_success_callback, list):
             litellm._async_success_callback.append(self.deployment_callback_on_success)
@@ -4116,6 +4117,22 @@
             return self.model_list
         return None
 
+    def get_model_access_groups(self):
+        from collections import defaultdict
+
+        access_groups = defaultdict(list)
+        if self.access_groups:
+            return self.access_groups
+
+        if self.model_list:
+            for m in self.model_list:
+                for group in m.get("model_info", {}).get("access_groups", []):
+                    model_name = m["model_name"]
+                    access_groups[group].append(model_name)
+        # cache the computed access groups on the router instance
+        self.access_groups = access_groups
+        return access_groups
+
     def get_settings(self):
         """
         Get router settings method, returns a dictionary of the settings and their values.
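Patch 1 in summary: a team's `models` list may now name an access group instead of a concrete model. `common_checks()` calls the new `model_in_access_group()`, which pulls the group-to-models mapping from `Router.get_model_access_groups()` and expands any group entries before matching the requested model. A minimal, self-contained sketch of that expansion follows; the group and model names are illustrative, not taken from a live proxy:

from typing import Dict, List


def expand_team_models(
    team_models: List[str], access_groups: Dict[str, List[str]]
) -> List[str]:
    # keep entries that are plain model names
    expanded = [m for m in team_models if m not in access_groups]
    # replace each access-group entry with the models in that group
    for m in team_models:
        if m in access_groups:
            expanded.extend(access_groups[m])
    return expanded


# shape of the mapping built by Router.get_model_access_groups()
access_groups = {"beta-models": ["gpt-4o", "gemini-pro-vision"]}
team_models = ["beta-models", "gpt-3.5-turbo"]

allowed = expand_team_models(team_models, access_groups)
assert "gpt-4o" in allowed          # granted via the access group
assert "gpt-3.5-turbo" in allowed   # granted directly
assert "gpt-4" not in allowed       # common_checks() would reject this

One design note: `get_model_access_groups()` caches its result on `self.access_groups`, so the mapping is computed once per Router instance; deployments added after the first call will not appear in existing groups until the router is rebuilt.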
From 6fee350938b9c8cc05ebb721540b9662919ecdc0 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Sat, 17 Aug 2024 17:10:10 -0700
Subject: [PATCH 2/2] feat: add model access groups for teams

---
 litellm/tests/test_key_generate_prisma.py | 107 ++++++++++++++++++++++
 1 file changed, 107 insertions(+)

diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py
index 66d49f886..64800c99e 100644
--- a/litellm/tests/test_key_generate_prisma.py
+++ b/litellm/tests/test_key_generate_prisma.py
@@ -2767,3 +2767,110 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
         "model_tpm_limit": {"gpt-4": 200},
         "model_rpm_limit": {"gpt-4": 3},
     }
+
+
+@pytest.mark.asyncio()
+async def test_team_access_groups(prisma_client):
+    """
+    Test team-based model access groups
+
+    - Test calling a model in the access group -> pass
+    - Test calling a model not in the access group -> fail
+    """
+    litellm.set_verbose = True
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+    # create router with access groups
+    litellm_router = litellm.Router(
+        model_list=[
+            {
+                "model_name": "gemini-pro-vision",
+                "litellm_params": {
+                    "model": "vertex_ai/gemini-1.0-pro-vision-001",
+                },
+                "model_info": {"access_groups": ["beta-models"]},
+            },
+            {
+                "model_name": "gpt-4o",
+                "litellm_params": {
+                    "model": "gpt-4o",
+                },
+                "model_info": {"access_groups": ["beta-models"]},
+            },
+        ]
+    )
+    setattr(litellm.proxy.proxy_server, "llm_router", litellm_router)
+
+    # Create team with models=["beta-models"]
+    team_request = NewTeamRequest(
+        team_alias="testing-team",
+        models=["beta-models"],
+    )
+
+    new_team_response = await new_team(
+        data=team_request,
+        user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
+        http_request=Request(scope={"type": "http"}),
+    )
+    print("new_team_response", new_team_response)
+    created_team_id = new_team_response["team_id"]
+
+    # create key with team_id=created_team_id
+    request = GenerateKeyRequest(
+        team_id=created_team_id,
+    )
+
+    key = await generate_key_fn(
+        data=request,
+        user_api_key_dict=UserAPIKeyAuth(
+            user_role=LitellmUserRoles.PROXY_ADMIN,
+            api_key="sk-1234",
+            user_id="1234",
+        ),
+    )
+    print(key)
+
+    generated_key = key.key
+    bearer_token = "Bearer " + generated_key
+
+    request = Request(scope={"type": "http"})
+    request._url = URL(url="/chat/completions")
+
+    for model in ["gpt-4o", "gemini-pro-vision"]:
+        # Expect these to pass
+        async def return_body():
+            return_string = f'{{"model": "{model}"}}'
+            # return string as bytes
+            return return_string.encode()
+
+        request.body = return_body
+
+        # use the generated key to authenticate
+        print(
+            "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)
+        )
+        result = await user_api_key_auth(request=request, api_key=bearer_token)
+
+    for model in ["gpt-4", "gpt-4o-mini", "gemini-experimental"]:
+        # Expect these to fail
+        async def return_body_2():
+            return_string = f'{{"model": "{model}"}}'
+            # return string as bytes
+            return return_string.encode()
+
+        request.body = return_body_2
+
+        # use the generated key to authenticate
+        print(
+            "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)
+        )
+        try:
+            result = await user_api_key_auth(request=request, api_key=bearer_token)
+            pytest.fail("This should have failed! It's an invalid model")
+        except Exception as e:
+            print("got exception", e)
+            assert (
+                "not allowed to call model" in e.message
+                and "Allowed team models" in e.message
+            )
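For a reviewer who wants to exercise the feature end to end, the flow the new test covers looks roughly like this against a running proxy configured with the access_groups entries from patch 1. This is a sketch that assumes the proxy's /team/new and /key/generate endpoints; the base URL, master key, and team alias are placeholders:

import requests

BASE_URL = "http://localhost:4000"
HEADERS = {"Authorization": "Bearer sk-1234"}  # proxy master key

# 1. create a team whose models list names the access group
team = requests.post(
    f"{BASE_URL}/team/new",
    headers=HEADERS,
    json={"team_alias": "beta-testers", "models": ["beta-models"]},
).json()

# 2. generate a key scoped to that team
key = requests.post(
    f"{BASE_URL}/key/generate",
    headers=HEADERS,
    json={"team_id": team["team_id"]},
).json()

# 3. this key may call any deployment whose model_info lists
#    access_groups: ["beta-models"]; other models are rejected with
#    "not allowed to call model"
print(key["key"])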