feat(proxy_server.py): expose new /model_group/info endpoint

returns model-group-level info on supported params, max tokens, pricing, etc.
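
A quick way to smoke-test the endpoint once the proxy is running; a minimal sketch assuming a local proxy on the default port with a placeholder master key:

import requests

# assumption: proxy running locally on the default port, master key "sk-1234"
resp = requests.get(
    "http://0.0.0.0:4000/model_group/info",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()
print(resp.json())  # per-group token limits, pricing, providers, capability flags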
Krrish Dholakia 2024-05-26 14:07:35 -07:00
parent bec13d465a
commit 22b6b99b34
6 changed files with 191 additions and 16 deletions
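
The new Router.get_model_group_info in the diff below folds per-deployment model info into one ModelGroupInfo: providers are unioned, token limits and per-token costs keep the maximum seen across deployments, and boolean capabilities are OR'd. A standalone sketch of that merge rule, using hypothetical deployment dicts whose keys mirror litellm.get_model_info():

# hypothetical per-deployment info; values are illustrative only
deployments = [
    {"provider": "openai", "max_input_tokens": 128000,
     "input_cost_per_token": 1e-05, "supports_vision": True},
    {"provider": "azure", "max_input_tokens": 16385,
     "input_cost_per_token": 5e-06, "supports_vision": False},
]

group = {"providers": [], "max_input_tokens": None,
         "input_cost_per_token": None, "supports_vision": False}
for info in deployments:
    # providers: union across deployments
    if info["provider"] not in group["providers"]:
        group["providers"].append(info["provider"])
    # numeric fields: keep the largest non-None value
    for key in ("max_input_tokens", "input_cost_per_token"):
        if info.get(key) is not None and (
            group[key] is None or info[key] > group[key]
        ):
            group[key] = info[key]
    # capabilities: True if any deployment supports it
    if info.get("supports_vision") is True:
        group["supports_vision"] = True

print(group)
# -> {'providers': ['openai', 'azure'], 'max_input_tokens': 128000,
#     'input_cost_per_token': 1e-05, 'supports_vision': True}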

@@ -48,6 +48,7 @@ from litellm.types.router import (
RetryPolicy,
AlertingConfig,
DeploymentTypedDict,
ModelGroupInfo,
)
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -3045,6 +3046,100 @@ class Router:
return model
return None

    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
"""
For a given model group name, return the combined model info
Returns:
- ModelGroupInfo if able to construct a model group
- None if error constructing model group info
"""
model_group_info: Optional[ModelGroupInfo] = None
for model in self.model_list:
if "model_name" in model and model["model_name"] == model_group:
# model in model group found #
litellm_params = LiteLLM_Params(**model["litellm_params"])
# get model info
try:
model_info = litellm.get_model_info(model=litellm_params.model)
except Exception as e:
continue
# get llm provider
try:
model, llm_provider, _, _ = litellm.get_llm_provider(
model=litellm_params.model,
custom_llm_provider=litellm_params.custom_llm_provider,
)
except Exception as e:
continue
if model_group_info is None:
model_group_info = ModelGroupInfo(
model_group=model_group, providers=[llm_provider], **model_info # type: ignore
)
else:
# if max_input_tokens > curr
# if max_output_tokens > curr
# if input_cost_per_token > curr
# if output_cost_per_token > curr
# supports_parallel_function_calling == True
# supports_vision == True
# supports_function_calling == True
if llm_provider not in model_group_info.providers:
model_group_info.providers.append(llm_provider)
if model_info.get("max_input_tokens", None) is not None and (
model_group_info.max_input_tokens is None
or model_info["max_input_tokens"]
> model_group_info.max_input_tokens
):
model_group_info.max_input_tokens = model_info[
"max_input_tokens"
]
if model_info.get("max_output_tokens", None) is not None and (
model_group_info.max_output_tokens is None
or model_info["max_output_tokens"]
> model_group_info.max_output_tokens
):
model_group_info.max_output_tokens = model_info[
"max_output_tokens"
]
if model_info.get("input_cost_per_token", None) is not None and (
model_group_info.input_cost_per_token is None
or model_info["input_cost_per_token"]
> model_group_info.input_cost_per_token
):
model_group_info.input_cost_per_token = model_info[
"input_cost_per_token"
]
if model_info.get("output_cost_per_token", None) is not None and (
model_group_info.output_cost_per_token is None
or model_info["output_cost_per_token"]
> model_group_info.output_cost_per_token
):
model_group_info.output_cost_per_token = model_info[
"output_cost_per_token"
]
if (
model_info.get("supports_parallel_function_calling", None)
is not None
and model_info["supports_parallel_function_calling"] == True # type: ignore
):
model_group_info.supports_parallel_function_calling = True
if (
model_info.get("supports_vision", None) is not None
and model_info["supports_vision"] == True # type: ignore
):
model_group_info.supports_vision = True
if (
model_info.get("supports_function_calling", None) is not None
and model_info["supports_function_calling"] == True # type: ignore
):
model_group_info.supports_function_calling = True
return model_group_info

    def get_model_ids(self) -> List[str]:
"""
        Returns list of model IDs.