Merge pull request #5263 from BerriAI/litellm_support_access_groups

[Feat-Proxy] Use model access groups for teams
Ishaan Jaff committed 2024-08-17 17:11:11 -07:00
commit 83515e88ce
4 changed files with 171 additions and 6 deletions


@@ -10,7 +10,7 @@ Run checks for:
 """
 import time
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Literal, Optional
+from typing import TYPE_CHECKING, Any, List, Literal, Optional

 import litellm
 from litellm._logging import verbose_proxy_logger
@@ -77,6 +77,11 @@ def common_checks(
         if "all-proxy-models" in team_object.models:
             # this means the team has access to all models on the proxy
             pass
+        # check if the team model is an access_group
+        elif model_in_access_group(_model, team_object.models) is True:
+            pass
+        elif _model and "*" in _model:
+            pass
         else:
             raise Exception(
                 f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
@@ -327,6 +332,39 @@ async def get_end_user_object(
     return None


+def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
+    from collections import defaultdict
+
+    from litellm.proxy.proxy_server import llm_router
+
+    if team_models is None:
+        return True
+    if model in team_models:
+        return True
+
+    access_groups = defaultdict(list)
+    if llm_router:
+        access_groups = llm_router.get_model_access_groups()
+
+    models_in_current_access_groups = []
+    if len(access_groups) > 0:  # check if token contains any model access groups
+        for idx, m in enumerate(
+            team_models
+        ):  # loop token models, if any of them are an access group add the access group
+            if m in access_groups:
+                # if it is an access group we need to remove it from valid_token.models
+                models_in_group = access_groups[m]
+                models_in_current_access_groups.extend(models_in_group)
+
+    # Filter out models that are access_groups
+    filtered_models = [m for m in team_models if m not in access_groups]
+    filtered_models += models_in_current_access_groups
+
+    if model in filtered_models:
+        return True
+    return False
+
+
 @log_to_opentelemetry
 async def get_user_object(
     user_id: str,
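
The helper returns early on a direct match, then expands any access-group names in team_models into their member models and checks again. A self-contained illustration of that expansion, with a hard-coded group map standing in for llm_router.get_model_access_groups(), which is only available inside a running proxy; "gpt-4" here is just a placeholder model name:

# Hypothetical group map; the real one is derived from the router config.
access_groups = {"beta-models": ["fireworks-llama-v3-70b-instruct"]}
team_models = ["beta-models", "gpt-4"]

# Drop group names from the team's list, then add each group's members.
filtered_models = [m for m in team_models if m not in access_groups]
for m in team_models:
    if m in access_groups:
        filtered_models.extend(access_groups[m])

assert "fireworks-llama-v3-70b-instruct" in filtered_models  # via the access group
assert "gpt-4" in filtered_models  # granted by direct listing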
@@ -543,12 +581,11 @@ async def can_key_call_model(
     )
     from collections import defaultdict

     from litellm.proxy.proxy_server import llm_router

     access_groups = defaultdict(list)
-    if llm_model_list is not None:
-        for m in llm_model_list:
-            for group in m.get("model_info", {}).get("access_groups", []):
-                model_name = m["model_name"]
-                access_groups[group].append(model_name)
+    if llm_router:
+        access_groups = llm_router.get_model_access_groups()

     models_in_current_access_groups = []
     if len(access_groups) > 0:  # check if token contains any model access groups
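
Both call sites now delegate to llm_router.get_model_access_groups() rather than rebuilding the mapping from llm_model_list by hand. Judging from the loop removed here, the mapping goes from group name to member model names; a sketch equivalent to the deleted code:

from collections import defaultdict

def build_access_groups(llm_model_list):
    # Collect access_groups declared under each deployment's model_info
    # into {group_name: [model_name, ...]}, as the removed loop did.
    access_groups = defaultdict(list)
    for m in llm_model_list or []:
        for group in m.get("model_info", {}).get("access_groups", []):
            access_groups[group].append(m["model_name"])
    return access_groups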


@@ -4,10 +4,14 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+    model_info:
+      access_groups: ["beta-models"]
   - model_name: fireworks-llama-v3-70b-instruct
     litellm_params:
       model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
       api_key: "os.environ/FIREWORKS"
+    model_info:
+      access_groups: ["beta-models"]
   - model_name: "*"
     litellm_params:
       model: "*"