(fix proxy) model_group/info support rerank models (#5955)

* fix /model_group/info on rerank

* add test test_proxy_model_group_info_rerank
This commit is contained in:
Ishaan Jaff 2024-09-28 10:54:43 -07:00 committed by GitHub
parent 088d906276
commit 8bf7573fd8
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 65 additions and 1 deletions

View file

@ -1515,3 +1515,62 @@ async def test_proxy_model_group_alias_checks(prisma_client, hidden):
assert is_model_alias_in_list is False assert is_model_alias_in_list is False
else: else:
assert is_model_alias_in_list, f"models: {models}" assert is_model_alias_in_list, f"models: {models}"
@pytest.mark.asyncio
async def test_proxy_model_group_info_rerank(prisma_client):
    """
    Check that a rerank model is returned on the following endpoints

    `/v1/models`
    `/v1/model/info`
    `/v1/model_group/info`

    The test installs a router holding a single deployment whose
    `model_info.mode` is "rerank" on the proxy server module, then awaits the
    `model_list`, `model_info_v1` and `model_group_info` handlers directly
    (no HTTP client is involved) and checks that the rerank model and its
    mode show up in each response.
    """
    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list

    # Point the proxy server module at the test database and a fixed master key.
    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
    await litellm.proxy.proxy_server.prisma_client.connect()

    # A single deployment explicitly tagged with mode="rerank".
    _model_list = [
        {
            "model_name": "rerank-english-v3.0",
            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
            "model_info": {
                "mode": "rerank",
            },
        }
    ]
    router = litellm.Router(model_list=_model_list)
    setattr(litellm.proxy.proxy_server, "llm_router", router)
    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)

    # NOTE(review): models=[] presumably means the key carries no per-model
    # restriction - confirm against UserAPIKeyAuth.
    resp = await model_list(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    assert len(resp["data"]) == 1
    print(resp)

    resp = await model_info_v1(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    models = resp["data"]
    # Fail with a readable message instead of an IndexError on an empty result.
    assert models, f"no models returned by model_info_v1: {resp}"
    assert models[0]["model_info"]["mode"] == "rerank"

    resp = await model_group_info(
        user_api_key_dict=UserAPIKeyAuth(models=[]),
    )
    print(resp)
    models = resp["data"]
    assert models, f"no model groups returned by model_group_info: {resp}"
    # Entries here are accessed by attribute, unlike the dict entries above.
    assert models[0].mode == "rerank"

View file

@ -482,7 +482,12 @@ class ModelGroupInfo(BaseModel):
output_cost_per_token: Optional[float] = None output_cost_per_token: Optional[float] = None
mode: Optional[ mode: Optional[
Literal[ Literal[
"chat", "embedding", "completion", "image_generation", "audio_transcription" "chat",
"embedding",
"completion",
"image_generation",
"audio_transcription",
"rerank",
] ]
] = Field(default="chat") ] = Field(default="chat")
tpm: Optional[int] = None tpm: Optional[int] = None