forked from phoenix/litellm-mirror
use model access groups for teams
This commit is contained in:
parent
d9c91838ce
commit
08db691dec
3 changed files with 64 additions and 6 deletions
|
@ -10,7 +10,7 @@ Run checks for:
|
|||
"""
|
||||
import time
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
@ -77,6 +77,11 @@ def common_checks(
|
|||
if "all-proxy-models" in team_object.models:
|
||||
# this means the team has access to all models on the proxy
|
||||
pass
|
||||
# check if the team model is an access_group
|
||||
elif model_in_access_group(_model, team_object.models) is True:
|
||||
pass
|
||||
elif _model and "*" in _model:
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
|
||||
|
@ -327,6 +332,39 @@ async def get_end_user_object(
|
|||
return None
|
||||
|
||||
|
||||
def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
||||
from collections import defaultdict
|
||||
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
if team_models is None:
|
||||
return True
|
||||
if model in team_models:
|
||||
return True
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups()
|
||||
|
||||
models_in_current_access_groups = []
|
||||
if len(access_groups) > 0: # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
team_models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
# if it is an access group we need to remove it from valid_token.models
|
||||
models_in_group = access_groups[m]
|
||||
models_in_current_access_groups.extend(models_in_group)
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [m for m in team_models if m not in access_groups]
|
||||
filtered_models += models_in_current_access_groups
|
||||
|
||||
if model in filtered_models:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
@log_to_opentelemetry
|
||||
async def get_user_object(
|
||||
user_id: str,
|
||||
|
@ -543,12 +581,11 @@ async def can_key_call_model(
|
|||
)
|
||||
from collections import defaultdict
|
||||
|
||||
from litellm.proxy.proxy_server import llm_router
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
if llm_model_list is not None:
|
||||
for m in llm_model_list:
|
||||
for group in m.get("model_info", {}).get("access_groups", []):
|
||||
model_name = m["model_name"]
|
||||
access_groups[group].append(model_name)
|
||||
if llm_router:
|
||||
access_groups = llm_router.get_model_access_groups()
|
||||
|
||||
models_in_current_access_groups = []
|
||||
if len(access_groups) > 0: # check if token contains any model access groups
|
||||
|
|
|
@ -4,10 +4,14 @@ model_list:
|
|||
model: openai/fake
|
||||
api_key: fake-key
|
||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||
model_info:
|
||||
access_groups: ["beta-models"]
|
||||
- model_name: fireworks-llama-v3-70b-instruct
|
||||
litellm_params:
|
||||
model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
|
||||
api_key: "os.environ/FIREWORKS"
|
||||
model_info:
|
||||
access_groups: ["beta-models"]
|
||||
- model_name: "*"
|
||||
litellm_params:
|
||||
model: "*"
|
||||
|
|
|
@ -421,6 +421,7 @@ class Router:
|
|||
routing_strategy=routing_strategy,
|
||||
routing_strategy_args=routing_strategy_args,
|
||||
)
|
||||
self.access_groups = None
|
||||
## USAGE TRACKING ##
|
||||
if isinstance(litellm._async_success_callback, list):
|
||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
||||
|
@ -4116,6 +4117,22 @@ class Router:
|
|||
return self.model_list
|
||||
return None
|
||||
|
||||
def get_model_access_groups(self):
|
||||
from collections import defaultdict
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
if self.access_groups:
|
||||
return self.access_groups
|
||||
|
||||
if self.model_list:
|
||||
for m in self.model_list:
|
||||
for group in m.get("model_info", {}).get("access_groups", []):
|
||||
model_name = m["model_name"]
|
||||
access_groups[group].append(model_name)
|
||||
# set access groups
|
||||
self.access_groups = access_groups
|
||||
return access_groups
|
||||
|
||||
def get_settings(self):
|
||||
"""
|
||||
Get router settings method, returns a dictionary of the settings and their values.
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue