Litellm dev 12 31 2024 p1 (#7488)

* fix(internal_user_endpoints.py): fix team list sort - handle team_alias being a mix of set and None values

* fix(key_management_endpoints.py): allow team admin to create key for member via admin ui

Fixes https://github.com/BerriAI/litellm/issues/7482

* fix(proxy_server.py): allow querying info on specific model group via `/model_group/info`

allows a client-side user to get model info from the proxy (see the client-side sketch below)

* fix(proxy_server.py): add docstring on `/model_group/info` showing how to filter by model name

* test(test_proxy_utils.py): add unit test for returning model group info filtered

* fix(proxy_server.py): fix query param

* fix(test_Get_model_info.py): handle no whitelisted bedrock models
Krish Dholakia · 2024-12-31 23:21:51 -08:00 (committed by GitHub)
commit 39cbd9d878 · parent 080de89cfb
9 changed files with 124 additions and 18 deletions
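
A minimal client-side sketch of the new query parameter, assuming a proxy running on localhost:4000 and a virtual key `sk-1234` (both taken from the endpoint docstring examples in the diff below):

```python
# Query /model_group/info for a single model group (hypothetical host/key,
# mirroring the curl examples in the endpoint docstring).
import requests

resp = requests.get(
    "http://localhost:4000/model_group/info",
    params={"model_group": "rerank-english-v3.0"},  # omit to list every group
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()

# The endpoint returns {"data": [ModelGroupInfo, ...]} serialized as JSON.
for group in resp.json()["data"]:
    print(group["model_group"], group.get("mode"))
```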


@@ -2237,3 +2237,6 @@ class ProxyStateVariables(TypedDict):
     """
     spend_logs_row_count: int
+
+
+UI_TEAM_ID = "litellm-dashboard"


@@ -370,7 +370,7 @@ async def user_info(
         ## REMOVE HASHED TOKEN INFO before returning ##
         returned_keys = _process_keys_for_user_info(keys=keys, all_teams=teams_1)
-        team_list.sort(key=lambda x: (getattr(x, "team_alias", "")))
+        team_list.sort(key=lambda x: (getattr(x, "team_alias", "") or ""))
         _user_info = (
             user_info.model_dump() if isinstance(user_info, BaseModel) else user_info
         )
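
Why the `or ""` coalesce is needed: a minimal repro, assuming pydantic team objects whose `team_alias` defaults to `None` (the attribute exists on the model, so `getattr`'s fallback never fires):

```python
from typing import Optional
from pydantic import BaseModel

class Team(BaseModel):  # stand-in for the proxy's team model
    team_id: str
    team_alias: Optional[str] = None

team_list = [Team(team_id="a", team_alias="zeta"), Team(team_id="b")]

# Old key: getattr returns the real attribute value, i.e. None for "b", and
# sorting raises TypeError: '<' not supported between 'NoneType' and 'str'.
# team_list.sort(key=lambda x: (getattr(x, "team_alias", "")))

# Fixed key: coalesce None to "" so every comparison is str vs. str.
team_list.sort(key=lambda x: (getattr(x, "team_alias", "") or ""))
print([t.team_id for t in team_list])  # ['b', 'a'] -- "" sorts first
```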


@@ -85,6 +85,11 @@ def _is_allowed_to_create_key(
     )
     if team_id is not None:
+        if (
+            user_api_key_dict.team_id is not None
+            and user_api_key_dict.team_id == UI_TEAM_ID
+        ):
+            return True  # handle https://github.com/BerriAI/litellm/issues/7482
         assert (
             user_api_key_dict.team_id == team_id
         ), "User can only create keys for their own team. Got={}, Your Team ID={}".format(


@@ -295,6 +295,7 @@ from fastapi import (
     Header,
     HTTPException,
     Path,
+    Query,
     Request,
     Response,
     UploadFile,
@@ -6622,6 +6623,20 @@ async def model_info_v1(  # noqa: PLR0915
     return {"data": all_models}
 
 
+def _get_model_group_info(
+    llm_router: Router, all_models_str: List[str], model_group: Optional[str]
+) -> List[ModelGroupInfo]:
+    model_groups: List[ModelGroupInfo] = []
+    for model in all_models_str:
+        if model_group is not None and model_group != model:
+            continue
+        _model_group_info = llm_router.get_model_group_info(model_group=model)
+        if _model_group_info is not None:
+            model_groups.append(_model_group_info)
+
+    return model_groups
+
+
 @router.get(
     "/model_group/info",
     tags=["model management"],
@@ -6629,14 +6644,17 @@ async def model_group_info(
 )
 async def model_group_info(
     user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
+    model_group: Optional[str] = None,
 ):
     """
     Get information about all the deployments on litellm proxy, including config.yaml descriptions (except api key and api base)
 
+    - /models returns all deployments. Proxy Admins can use this to list all deployments setup on the proxy
     - /model_group/info returns all model groups. End users of proxy should use /model_group/info since those models will be used for /chat/completions, /embeddings, etc.
+    - /model_group/info?model_group=rerank-english-v3.0 returns all model groups for a specific model group (`model_name` in config.yaml)
+
+    Example Request (All Models):
     ```shell
     curl -X 'GET' \
     'http://localhost:4000/model_group/info' \
@@ -6644,6 +6662,24 @@ async def model_group_info(
     -H 'x-api-key: sk-1234'
     ```
 
+    Example Request (Specific Model Group):
+    ```shell
+    curl -X 'GET' \
+    'http://localhost:4000/model_group/info?model_group=rerank-english-v3.0' \
+    -H 'accept: application/json' \
+    -H 'Authorization: Bearer sk-1234'
+    ```
+
+    Example Request (Specific Wildcard Model Group): (e.g. `model_name: openai/*` on config.yaml)
+    ```shell
+    curl -X 'GET' \
+    'http://localhost:4000/model_group/info?model_group=openai/tts-1' \
+    -H 'accept: application/json' \
+    -H 'Authorization: Bearer sk-1234'
+    ```
+
+    Learn how to use and set wildcard models [here](https://docs.litellm.ai/docs/wildcard_routing)
+
     Example Response:
     ```json
     {
@@ -6796,13 +6832,9 @@ async def model_group_info(
         infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
     )
 
-    model_groups: List[ModelGroupInfo] = []
-
-    for model in all_models_str:
-        _model_group_info = llm_router.get_model_group_info(model_group=model)
-        if _model_group_info is not None:
-            model_groups.append(_model_group_info)
+    model_groups: List[ModelGroupInfo] = _get_model_group_info(
+        llm_router=llm_router, all_models_str=all_models_str, model_group=model_group
+    )
 
     return {"data": model_groups}


@@ -509,6 +509,8 @@ class ModelGroupInfo(BaseModel):
     input_cost_per_token: Optional[float] = None
     output_cost_per_token: Optional[float] = None
     mode: Optional[
+        Union[
+            str,
         Literal[
             "chat",
             "embedding",
@@ -516,6 +518,8 @@ class ModelGroupInfo(BaseModel):
             "image_generation",
             "audio_transcription",
             "rerank",
+            "moderations",
+        ],
         ]
     ] = Field(default="chat")
     tpm: Optional[int] = None
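
The effect of widening `mode` from a bare `Literal[...]` to `Union[str, Literal[...]]`: arbitrary mode strings now validate instead of erroring, while the `Literal` arm keeps the known values documented. A simplified sketch (abbreviated field, not the full model):

```python
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field

class ModelGroupInfoSketch(BaseModel):
    mode: Optional[
        Union[str, Literal["chat", "embedding", "rerank", "moderations"]]
    ] = Field(default="chat")

print(ModelGroupInfoSketch().mode)                    # "chat" (default)
print(ModelGroupInfoSketch(mode="moderations").mode)  # accepted via Literal
print(ModelGroupInfoSketch(mode="custom-mode").mode)  # also accepted via str
```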


@@ -209,7 +209,6 @@ def test_model_info_bedrock_converse(monkeypatch):
     """
     monkeypatch.setenv("LITELLM_LOCAL_MODEL_COST_MAP", "True")
    litellm.model_cost = litellm.get_model_cost_map(url="")
-
    try:
        # Load whitelist models from file
        with open("whitelisted_bedrock_models.txt", "r") as file:


@@ -1914,8 +1914,10 @@ async def test_proxy_model_group_alias_checks(prisma_client, hidden):
     resp = await model_group_info(
         user_api_key_dict=UserAPIKeyAuth(models=[]),
     )
+    print(f"resp: {resp}")
     models = resp["data"]
     is_model_alias_in_list = False
+    print(f"model_alias: {model_alias}, models: {models}")
     for item in models:
         if model_alias == item.model_group:
             is_model_alias_in_list = True


@@ -1191,3 +1191,64 @@ def test_litellm_verification_token_view_response_with_budget_table(
         getattr(resp, expected_user_api_key_auth_key)
         == expected_user_api_key_auth_value
     )
+
+
+def test_is_allowed_to_create_key():
+    from litellm.proxy._types import LitellmUserRoles
+    from litellm.proxy.management_endpoints.key_management_endpoints import (
+        _is_allowed_to_create_key,
+    )
+
+    assert (
+        _is_allowed_to_create_key(
+            user_api_key_dict=UserAPIKeyAuth(
+                user_id="test_user_id", user_role=LitellmUserRoles.PROXY_ADMIN
+            ),
+            user_id="test_user_id",
+            team_id="test_team_id",
+        )
+        is True
+    )
+
+    assert (
+        _is_allowed_to_create_key(
+            user_api_key_dict=UserAPIKeyAuth(
+                user_id="test_user_id",
+                user_role=LitellmUserRoles.INTERNAL_USER,
+                team_id="litellm-dashboard",
+            ),
+            user_id="test_user_id",
+            team_id="test_team_id",
+        )
+        is True
+    )
+
+
+def test_get_model_group_info():
+    from litellm.proxy.proxy_server import _get_model_group_info
+    from litellm import Router
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "openai/tts-1",
+                "litellm_params": {
+                    "model": "openai/tts-1",
+                    "api_key": "sk-1234",
+                },
+            },
+            {
+                "model_name": "openai/gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "openai/gpt-3.5-turbo",
+                    "api_key": "sk-1234",
+                },
+            },
+        ]
+    )
+
+    model_list = _get_model_group_info(
+        llm_router=router,
+        all_models_str=["openai/tts-1", "openai/gpt-3.5-turbo"],
+        model_group="openai/tts-1",
+    )
+
+    assert len(model_list) == 1