forked from phoenix/litellm-mirror

(fix proxy) model_group/info support rerank models (#5955)

* fix /model_group/info on rerank
* add test test_proxy_model_group_info_rerank

parent 088d906276
commit 8bf7573fd8

2 changed files with 65 additions and 1 deletion
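In practice, the change means a rerank deployment registered on the proxy now reports its mode on /model_group/info. Below is a minimal sketch (not part of the diff) of checking this against a running proxy; the base URL http://localhost:4000 is an assumption, the sk-1234 key is borrowed from the test below, and the {"data": [...]} response shape matches what the test asserts on.

# Hypothetical check against a running LiteLLM proxy that has a rerank model
# registered (URL and key are assumptions, see note above).
import requests

resp = requests.get(
    "http://localhost:4000/model_group/info",
    headers={"Authorization": "Bearer sk-1234"},
    timeout=10,
)
resp.raise_for_status()

model_groups = resp.json()["data"]
# after this fix, the rerank deployment should appear with mode == "rerank"
rerank_groups = [g for g in model_groups if g.get("mode") == "rerank"]
print(rerank_groups)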
@@ -1515,3 +1515,62 @@ async def test_proxy_model_group_alias_checks(prisma_client, hidden):
         assert is_model_alias_in_list is False
     else:
         assert is_model_alias_in_list, f"models: {models}"
+
+
+@pytest.mark.asyncio
+async def test_proxy_model_group_info_rerank(prisma_client):
+    """
+    Check if rerank model is returned on the following endpoints
+
+    `/v1/models`
+    `/v1/model/info`
+    `/v1/model_group/info`
+    """
+    import json
+
+    from fastapi import HTTPException, Request, Response
+    from starlette.datastructures import URL
+
+    from litellm.proxy.proxy_server import model_group_info, model_info_v1, model_list
+
+    setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
+    setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
+    await litellm.proxy.proxy_server.prisma_client.connect()
+
+    proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
+
+    _model_list = [
+        {
+            "model_name": "rerank-english-v3.0",
+            "litellm_params": {"model": "cohere/rerank-english-v3.0"},
+            "model_info": {
+                "mode": "rerank",
+            },
+        }
+    ]
+    router = litellm.Router(model_list=_model_list)
+    setattr(litellm.proxy.proxy_server, "llm_router", router)
+    setattr(litellm.proxy.proxy_server, "llm_model_list", _model_list)
+
+    request = Request(scope={"type": "http", "method": "POST", "headers": {}})
+    request._url = URL(url="/v1/models")
+
+    resp = await model_list(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    assert len(resp["data"]) == 1
+    print(resp)
+
+    resp = await model_info_v1(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+    models = resp["data"]
+    assert models[0]["model_info"]["mode"] == "rerank"
+    resp = await model_group_info(
+        user_api_key_dict=UserAPIKeyAuth(models=[]),
+    )
+
+    print(resp)
+    models = resp["data"]
+    assert models[0].mode == "rerank"
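The new test above drives the proxy endpoint functions directly. For context, here is a standalone sketch of the same router registration outside the proxy: litellm.Router and the _model_list entry are taken from the diff, while get_model_list() and the exact shape of the returned deployments are assumptions.

# Standalone sketch of the rerank model registration used in the new test.
import litellm

_model_list = [
    {
        "model_name": "rerank-english-v3.0",
        "litellm_params": {"model": "cohere/rerank-english-v3.0"},
        "model_info": {"mode": "rerank"},
    }
]

router = litellm.Router(model_list=_model_list)

# Assumption: get_model_list() echoes back each deployment with the
# model_info dict it was registered with, including "mode".
for deployment in router.get_model_list():
    print(deployment["model_name"], deployment.get("model_info", {}).get("mode"))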
@@ -482,7 +482,12 @@ class ModelGroupInfo(BaseModel):
     output_cost_per_token: Optional[float] = None
     mode: Optional[
         Literal[
-            "chat", "embedding", "completion", "image_generation", "audio_transcription"
+            "chat",
+            "embedding",
+            "completion",
+            "image_generation",
+            "audio_transcription",
+            "rerank",
         ]
     ] = Field(default="chat")
     tpm: Optional[int] = None
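The widened Literal is what lets a rerank deployment's mode survive pydantic validation in ModelGroupInfo. A small sketch of the effect follows; the litellm.types.router import path and the model_group/providers required fields are assumptions, since the diff shows neither the file name nor the rest of the class.

# Sketch: with "rerank" in the Literal, this now validates; before the change
# pydantic would reject mode="rerank" with a ValidationError.
from litellm.types.router import ModelGroupInfo  # assumed import path

info = ModelGroupInfo(
    model_group="rerank-english-v3.0",  # assumed required field
    providers=["cohere"],               # assumed required field
    mode="rerank",
)
print(info.mode)  # -> "rerank"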