diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index 9271549d0..4afd0f839 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -1515,3 +1515,62 @@ async def test_proxy_model_group_alias_checks(prisma_client, hidden):
         assert is_model_alias_in_list is False
     else:
         assert is_model_alias_in_list, f"models: {models}"
+
+
+@pytest.mark.asyncio
+async def test_proxy_model_group_info_rerank(prisma_client):
+    """
+    Check if rerank model is returned on the following endpoints
+
+    `/v1/models`
+    `/v1/model/info`
+    `/v1/model_group/info`
+    """
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    _model_list = [
+        {
+            "model_name": "rerank-english-v3.0",
+            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
+            "model_info": {
+                "mode": "rerank",
+            },
+        }
+    ]
+    router = litellm.Router(model_list=_model_list)
+    setattr(litellm.proxy.proxy_server, "llm_router", router)
+    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/v1/models")
+
+    resp = await model_list(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    assert len(resp["data"]) == 1
+    print(resp)
+
+    resp = await model_info_v1(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+    models = resp["data"]
+    assert models[0]["model_info"]["mode"] == "rerank"
+    resp = await model_group_info(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    print(resp)
+    models = resp["data"]
+    assert models[0].mode == "rerank"
diff --git a/litellm/types/router.py b/litellm/types/router.py
index cfb90814b..1f2c7224f 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -482,7 +482,12 @@ class ModelGroupInfo(BaseModel):
     output_cost_per_token: Optional[float] = None
     mode: Optional[
         Literal[
-            "chat", "embedding", "completion", "image_generation", "audio_transcription"
+            "chat",
+            "embedding",
+            "completion",
+            "image_generation",
+            "audio_transcription",
+            "rerank",
         ]
     ] = Field(default="chat")
     tpm: Optional[int] = None