From d4ffc98a395c4853d297f2e5463127ab967c3979 Mon Sep 17 00:00:00 2001 From: ishaan-jaff Date: Wed, 14 Feb 2024 17:20:50 -0800 Subject: [PATCH] (feat) model access groups --- litellm/proxy/proxy_server.py | 47 ++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0d22921bd..24bd59b8a 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -403,34 +403,43 @@ async def user_api_key_auth( verbose_proxy_logger.debug( f"LLM Model List pre access group check: {llm_model_list}" ) - access_groups = [] + from collections import defaultdict + + access_groups = defaultdict(list) if llm_model_list is not None: for m in llm_model_list: for group in m.get("model_info", {}).get("access_groups", []): - access_groups.append((m["model_name"], group)) + model_name = m["model_name"] + access_groups[group].append(model_name) - allowed_models = valid_token.models - access_group_idx = set() + models_in_current_access_groups = [] if ( len(access_groups) > 0 ): # check if token contains any model access groups - for idx, m in enumerate(valid_token.models): - for model_name, group in access_groups: - if m == group: - access_group_idx.add(idx) - allowed_models.append(model_name) + for idx, m in enumerate( + valid_token.models + ): # loop token models, if any of them are an access group add the access group + if m in access_groups: + # if it is an access group we need to remove it from valid_token.models + models_in_group = access_groups[m] + models_in_current_access_groups.extend(models_in_group) + + # Filter out models that are access_groups + filtered_models = [ + m for m in valid_token.models if m not in access_groups + ] + + filtered_models += models_in_current_access_groups verbose_proxy_logger.debug( - f"model: {model}; allowed_models: {allowed_models}" + f"model: {model}; allowed_models: {filtered_models}" ) - if model is not None and model not in allowed_models: + if model is not None and model not in filtered_models: raise ValueError( f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" ) - for val in access_group_idx: - allowed_models.pop(val) - valid_token.models = allowed_models + valid_token.models = filtered_models verbose_proxy_logger.debug( - f"filtered allowed_models: {allowed_models}; valid_token.models: {valid_token.models}" + f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}" ) # Check 2. If user_id for this token is in budget @@ -2087,14 +2096,6 @@ def model_list( if user_model is not None: all_models += [user_model] verbose_proxy_logger.debug(f"all_models: {all_models}") - ### CHECK OLLAMA MODELS ### - try: - response = requests.get("http://0.0.0.0:11434/api/tags") - models = response.json()["models"] - ollama_models = ["ollama/" + m["name"].replace(":latest", "") for m in models] - all_models.extend(ollama_models) - except Exception as e: - pass return dict( data=[ {