forked from phoenix/litellm-mirror
Merge pull request #5263 from BerriAI/litellm_support_access_groups
[Feat-Proxy] Use model access groups for teams
This commit is contained in:
commit
83515e88ce
4 changed files with 171 additions and 6 deletions
|
@ -10,7 +10,7 @@ Run checks for:
|
||||||
"""
|
"""
|
||||||
import time
|
import time
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import TYPE_CHECKING, Any, Literal, Optional
|
from typing import TYPE_CHECKING, Any, List, Literal, Optional
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm._logging import verbose_proxy_logger
|
from litellm._logging import verbose_proxy_logger
|
||||||
|
@ -77,6 +77,11 @@ def common_checks(
|
||||||
if "all-proxy-models" in team_object.models:
|
if "all-proxy-models" in team_object.models:
|
||||||
# this means the team has access to all models on the proxy
|
# this means the team has access to all models on the proxy
|
||||||
pass
|
pass
|
||||||
|
# check if the team model is an access_group
|
||||||
|
elif model_in_access_group(_model, team_object.models) is True:
|
||||||
|
pass
|
||||||
|
elif _model and "*" in _model:
|
||||||
|
pass
|
||||||
else:
|
else:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
|
f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
|
||||||
|
@ -327,6 +332,39 @@ async def get_end_user_object(
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def model_in_access_group(model: str, team_models: Optional[List[str]]) -> bool:
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import llm_router
|
||||||
|
|
||||||
|
if team_models is None:
|
||||||
|
return True
|
||||||
|
if model in team_models:
|
||||||
|
return True
|
||||||
|
|
||||||
|
access_groups = defaultdict(list)
|
||||||
|
if llm_router:
|
||||||
|
access_groups = llm_router.get_model_access_groups()
|
||||||
|
|
||||||
|
models_in_current_access_groups = []
|
||||||
|
if len(access_groups) > 0: # check if token contains any model access groups
|
||||||
|
for idx, m in enumerate(
|
||||||
|
team_models
|
||||||
|
): # loop token models, if any of them are an access group add the access group
|
||||||
|
if m in access_groups:
|
||||||
|
# if it is an access group we need to remove it from valid_token.models
|
||||||
|
models_in_group = access_groups[m]
|
||||||
|
models_in_current_access_groups.extend(models_in_group)
|
||||||
|
|
||||||
|
# Filter out models that are access_groups
|
||||||
|
filtered_models = [m for m in team_models if m not in access_groups]
|
||||||
|
filtered_models += models_in_current_access_groups
|
||||||
|
|
||||||
|
if model in filtered_models:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
@log_to_opentelemetry
|
@log_to_opentelemetry
|
||||||
async def get_user_object(
|
async def get_user_object(
|
||||||
user_id: str,
|
user_id: str,
|
||||||
|
@ -543,12 +581,11 @@ async def can_key_call_model(
|
||||||
)
|
)
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from litellm.proxy.proxy_server import llm_router
|
||||||
|
|
||||||
access_groups = defaultdict(list)
|
access_groups = defaultdict(list)
|
||||||
if llm_model_list is not None:
|
if llm_router:
|
||||||
for m in llm_model_list:
|
access_groups = llm_router.get_model_access_groups()
|
||||||
for group in m.get("model_info", {}).get("access_groups", []):
|
|
||||||
model_name = m["model_name"]
|
|
||||||
access_groups[group].append(model_name)
|
|
||||||
|
|
||||||
models_in_current_access_groups = []
|
models_in_current_access_groups = []
|
||||||
if len(access_groups) > 0: # check if token contains any model access groups
|
if len(access_groups) > 0: # check if token contains any model access groups
|
||||||
|
|
|
@ -4,10 +4,14 @@ model_list:
|
||||||
model: openai/fake
|
model: openai/fake
|
||||||
api_key: fake-key
|
api_key: fake-key
|
||||||
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
api_base: https://exampleopenaiendpoint-production.up.railway.app/
|
||||||
|
model_info:
|
||||||
|
access_groups: ["beta-models"]
|
||||||
- model_name: fireworks-llama-v3-70b-instruct
|
- model_name: fireworks-llama-v3-70b-instruct
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
|
model: fireworks_ai/accounts/fireworks/models/llama-v3-70b-instruct
|
||||||
api_key: "os.environ/FIREWORKS"
|
api_key: "os.environ/FIREWORKS"
|
||||||
|
model_info:
|
||||||
|
access_groups: ["beta-models"]
|
||||||
- model_name: "*"
|
- model_name: "*"
|
||||||
litellm_params:
|
litellm_params:
|
||||||
model: "*"
|
model: "*"
|
||||||
|
|
|
@ -421,6 +421,7 @@ class Router:
|
||||||
routing_strategy=routing_strategy,
|
routing_strategy=routing_strategy,
|
||||||
routing_strategy_args=routing_strategy_args,
|
routing_strategy_args=routing_strategy_args,
|
||||||
)
|
)
|
||||||
|
self.access_groups = None
|
||||||
## USAGE TRACKING ##
|
## USAGE TRACKING ##
|
||||||
if isinstance(litellm._async_success_callback, list):
|
if isinstance(litellm._async_success_callback, list):
|
||||||
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
litellm._async_success_callback.append(self.deployment_callback_on_success)
|
||||||
|
@ -4116,6 +4117,22 @@ class Router:
|
||||||
return self.model_list
|
return self.model_list
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def get_model_access_groups(self):
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
access_groups = defaultdict(list)
|
||||||
|
if self.access_groups:
|
||||||
|
return self.access_groups
|
||||||
|
|
||||||
|
if self.model_list:
|
||||||
|
for m in self.model_list:
|
||||||
|
for group in m.get("model_info", {}).get("access_groups", []):
|
||||||
|
model_name = m["model_name"]
|
||||||
|
access_groups[group].append(model_name)
|
||||||
|
# set access groups
|
||||||
|
self.access_groups = access_groups
|
||||||
|
return access_groups
|
||||||
|
|
||||||
def get_settings(self):
|
def get_settings(self):
|
||||||
"""
|
"""
|
||||||
Get router settings method, returns a dictionary of the settings and their values.
|
Get router settings method, returns a dictionary of the settings and their values.
|
||||||
|
|
|
@ -2768,3 +2768,110 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
||||||
"model_tpm_limit": {"gpt-4": 200},
|
"model_tpm_limit": {"gpt-4": 200},
|
||||||
"model_rpm_limit": {"gpt-4": 3},
|
"model_rpm_limit": {"gpt-4": 3},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio()
|
||||||
|
async def test_team_access_groups(prisma_client):
|
||||||
|
"""
|
||||||
|
Test team based model access groups
|
||||||
|
|
||||||
|
- Test calling a model in the access group -> pass
|
||||||
|
- Test calling a model not in the access group -> fail
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = True
|
||||||
|
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
|
||||||
|
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
|
||||||
|
await litellm.proxy.proxy_server.prisma_client.connect()
|
||||||
|
# create router with access groups
|
||||||
|
litellm_router = litellm.Router(
|
||||||
|
model_list=[
|
||||||
|
{
|
||||||
|
"model_name": "gemini-pro-vision",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "vertex_ai/gemini-1.0-pro-vision-001",
|
||||||
|
},
|
||||||
|
"model_info": {"access_groups": ["beta-models"]},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"model_name": "gpt-4o",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
},
|
||||||
|
"model_info": {"access_groups": ["beta-models"]},
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
setattr(litellm.proxy.proxy_server, "llm_router", litellm_router)
|
||||||
|
|
||||||
|
# Create team with models=["beta-models"]
|
||||||
|
team_request = NewTeamRequest(
|
||||||
|
team_alias="testing-team",
|
||||||
|
models=["beta-models"],
|
||||||
|
)
|
||||||
|
|
||||||
|
new_team_response = await new_team(
|
||||||
|
data=team_request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
http_request=Request(scope={"type": "http"}),
|
||||||
|
)
|
||||||
|
print("new_team_response", new_team_response)
|
||||||
|
created_team_id = new_team_response["team_id"]
|
||||||
|
|
||||||
|
# create key with team_id=created_team_id
|
||||||
|
request = GenerateKeyRequest(
|
||||||
|
team_id=created_team_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
key = await generate_key_fn(
|
||||||
|
data=request,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
api_key="sk-1234",
|
||||||
|
user_id="1234",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
print(key)
|
||||||
|
|
||||||
|
generated_key = key.key
|
||||||
|
bearer_token = "Bearer " + generated_key
|
||||||
|
|
||||||
|
request = Request(scope={"type": "http"})
|
||||||
|
request._url = URL(url="/chat/completions")
|
||||||
|
|
||||||
|
for model in ["gpt-4o", "gemini-pro-vision"]:
|
||||||
|
# Expect these to pass
|
||||||
|
async def return_body():
|
||||||
|
return_string = f'{{"model": "{model}"}}'
|
||||||
|
# return string as bytes
|
||||||
|
return return_string.encode()
|
||||||
|
|
||||||
|
request.body = return_body
|
||||||
|
|
||||||
|
# use generated key to auth in
|
||||||
|
print(
|
||||||
|
"Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)
|
||||||
|
)
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
|
||||||
|
for model in ["gpt-4", "gpt-4o-mini", "gemini-experimental"]:
|
||||||
|
# Expect these to fail
|
||||||
|
async def return_body_2():
|
||||||
|
return_string = f'{{"model": "{model}"}}'
|
||||||
|
# return string as bytes
|
||||||
|
return return_string.encode()
|
||||||
|
|
||||||
|
request.body = return_body_2
|
||||||
|
|
||||||
|
# use generated key to auth in
|
||||||
|
print(
|
||||||
|
"Bearer token being sent to user_api_key_auth() - {}".format(bearer_token)
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
result = await user_api_key_auth(request=request, api_key=bearer_token)
|
||||||
|
pytest.fail(f"This should have failed!. IT's an invalid model")
|
||||||
|
except Exception as e:
|
||||||
|
print("got exception", e)
|
||||||
|
assert (
|
||||||
|
"not allowed to call model" in e.message
|
||||||
|
and "Allowed team models" in e.message
|
||||||
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue