diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index 66d49f886..64800c99e 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -2767,3 +2767,110 @@ async def test_generate_key_with_model_tpm_limit(prisma_client): "model_tpm_limit": {"gpt-4": 200}, "model_rpm_limit": {"gpt-4": 3}, } + + +@pytest.mark.asyncio() +async def test_team_access_groups(prisma_client): + """ + Test team based model access groups + + - Test calling a model in the access group -> pass + - Test calling a model not in the access group -> fail + """ + litellm.set_verbose = True + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + # create router with access groups + litellm_router = litellm.Router( + model_list=[ + { + "model_name": "gemini-pro-vision", + "litellm_params": { + "model": "vertex_ai/gemini-1.0-pro-vision-001", + }, + "model_info": {"access_groups": ["beta-models"]}, + }, + { + "model_name": "gpt-4o", + "litellm_params": { + "model": "gpt-4o", + }, + "model_info": {"access_groups": ["beta-models"]}, + }, + ] + ) + setattr(litellm.proxy.proxy_server, "llm_router", litellm_router) + + # Create team with models=["beta-models"] + team_request = NewTeamRequest( + team_alias="testing-team", + models=["beta-models"], + ) + + new_team_response = await new_team( + data=team_request, + user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN), + http_request=Request(scope={"type": "http"}), + ) + print("new_team_response", new_team_response) + created_team_id = new_team_response["team_id"] + + # create key with team_id=created_team_id + request = GenerateKeyRequest( + team_id=created_team_id, + ) + + key = await generate_key_fn( + data=request, + user_api_key_dict=UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + api_key="sk-1234", + user_id="1234", + ), + ) + print(key) + + generated_key = key.key + bearer_token = "Bearer " + generated_key + + request = Request(scope={"type": "http"}) + request._url = URL(url="/chat/completions") + + for model in ["gpt-4o", "gemini-pro-vision"]: + # Expect these to pass + async def return_body(): + return_string = f'{{"model": "{model}"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body + + # use generated key to auth in + print( + "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token) + ) + result = await user_api_key_auth(request=request, api_key=bearer_token) + + for model in ["gpt-4", "gpt-4o-mini", "gemini-experimental"]: + # Expect these to fail + async def return_body_2(): + return_string = f'{{"model": "{model}"}}' + # return string as bytes + return return_string.encode() + + request.body = return_body_2 + + # use generated key to auth in + print( + "Bearer token being sent to user_api_key_auth() - {}".format(bearer_token) + ) + try: + result = await user_api_key_auth(request=request, api_key=bearer_token) + pytest.fail(f"This should have failed!. IT's an invalid model") + except Exception as e: + print("got exception", e) + assert ( + "not allowed to call model" in e.message + and "Allowed team models" in e.message + )