Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)
(fix proxy) model_group/info support rerank models (#5955)
* fix /model_group/info on rerank
* add test test_proxy_model_group_info_rerank
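For orientation before the diff: the proxy's model-discovery endpoints serve whatever deployments are registered on a litellm.Router, and a rerank deployment is one whose model_info sets mode to "rerank". The sketch below mirrors the router setup used in the new test; the Cohere model name comes from that test, while everything else (e.g. relying on COHERE_API_KEY in the environment at request time) is an assumption for illustration, not part of the commit.

# Sketch of a rerank deployment on the router, mirroring the new test below.
# Assumption: COHERE_API_KEY is set in the environment when requests are made.
import litellm

router = litellm.Router(
    model_list=[
        {
            "model_name": "rerank-english-v3.0",
            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
            # "mode" is what /model_group/info should now surface as "rerank"
            "model_info": {"mode": "rerank"},
        }
    ]
)

With a router like this attached to the proxy, `/v1/models`, `/v1/model/info`, and `/v1/model_group/info` should all list the deployment, and the last two should report its mode as "rerank", which is exactly what the new test asserts.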
parent 088d906276
commit 8bf7573fd8
2 changed files with 65 additions and 1 deletion
@@ -1515,3 +1515,62 @@ async def test_proxy_model_group_alias_checks(prisma_client, hidden):
         assert is_model_alias_in_list is False
     else:
         assert is_model_alias_in_list, f"models: {models}"
+
+
+@pytest.mark.asyncio
+async def test_proxy_model_group_info_rerank(prisma_client):
+    """
+    Check if rerank model is returned on the following endpoints
+
+    `/v1/models`
+    `/v1/model/info`
+    `/v1/model_group/info`
+    """
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    _model_list = [
+        {
+            "model_name": "rerank-english-v3.0",
+            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
+            "model_info": {
+                "mode": "rerank",
+            },
+        }
+    ]
+    router = litellm.Router(model_list=_model_list)
+    setattr(litellm.proxy.proxy_server, "llm_router", router)
+    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/v1/models")
+
+    resp = await model_list(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    assert len(resp["data"]) == 1
+    print(resp)
+
+    resp = await model_info_v1(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+    models = resp["data"]
+    assert models[0]["model_info"]["mode"] == "rerank"
+    resp = await model_group_info(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    print(resp)
+    models = resp["data"]
+    assert models[0].mode == "rerank"
@@ -482,7 +482,12 @@ class ModelGroupInfo(BaseModel):
     output_cost_per_token: Optional[float] = None
     mode: Optional[
         Literal[
-            "chat", "embedding", "completion", "image_generation", "audio_transcription"
+            "chat",
+            "embedding",
+            "completion",
+            "image_generation",
+            "audio_transcription",
+            "rerank",
         ]
     ] = Field(default="chat")
     tpm: Optional[int] = None
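A quick illustration of what the widened Literal changes: "rerank" is now an accepted value for mode, where previously it would fail validation. The MiniModelGroupInfo class below is a standalone sketch that mirrors only the field shown in this hunk; the real ModelGroupInfo carries many more fields.

# Standalone sketch of the "mode" field from the hunk above (not the full class).
from typing import Literal, Optional

from pydantic import BaseModel, Field


class MiniModelGroupInfo(BaseModel):
    mode: Optional[
        Literal[
            "chat",
            "embedding",
            "completion",
            "image_generation",
            "audio_transcription",
            "rerank",
        ]
    ] = Field(default="chat")


MiniModelGroupInfo(mode="rerank")  # validates with the new Literal
MiniModelGroupInfo()               # still defaults to mode="chat"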