Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
feat(proxy_server.py): expose new /model_group/info endpoint

Returns model-group level info on supported params, max tokens, pricing, etc.
This commit is contained in: parent bec13d465a, commit 22b6b99b34
6 changed files with 191 additions and 16 deletions
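Once a proxy running this commit is up, the new route can be queried directly. A minimal sketch follows; only the /model_group/info path comes from this commit, while the base URL, the master-key auth header, and the {"data": [...]} response envelope are assumptions for illustration:

    # Query the new /model_group/info endpoint on a running LiteLLM proxy.
    # Base URL, API key, and response shape are assumed, not confirmed by this diff.
    import requests

    resp = requests.get(
        "http://0.0.0.0:4000/model_group/info",       # assumed local proxy address
        headers={"Authorization": "Bearer sk-1234"},  # assumed proxy master key
        timeout=10,
    )
    resp.raise_for_status()
    for group in resp.json().get("data", []):         # assumed {"data": [...]} envelope
        print(group.get("model_group"), group.get("max_input_tokens"))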
@@ -48,6 +48,7 @@ from litellm.types.router import (
     RetryPolicy,
     AlertingConfig,
     DeploymentTypedDict,
+    ModelGroupInfo,
 )
 from litellm.integrations.custom_logger import CustomLogger
 from litellm.llms.azure import get_azure_ad_token_from_oidc
@@ -3045,6 +3046,100 @@ class Router:
                 return model
         return None

+    def get_model_group_info(self, model_group: str) -> Optional[ModelGroupInfo]:
+        """
+        For a given model group name, return the combined model info.
+
+        Returns:
+        - ModelGroupInfo if able to construct a model group
+        - None if error constructing model group info
+        """
+
+        model_group_info: Optional[ModelGroupInfo] = None
+
+        for model in self.model_list:
+            if "model_name" in model and model["model_name"] == model_group:
+                # deployment belongs to this model group
+                litellm_params = LiteLLM_Params(**model["litellm_params"])
+                # get model info; skip deployments whose model is unknown
+                try:
+                    model_info = litellm.get_model_info(model=litellm_params.model)
+                except Exception:
+                    continue
+                # get llm provider; skip deployments whose provider can't be resolved
+                try:
+                    model, llm_provider, _, _ = litellm.get_llm_provider(
+                        model=litellm_params.model,
+                        custom_llm_provider=litellm_params.custom_llm_provider,
+                    )
+                except Exception:
+                    continue
+
+                if model_group_info is None:
+                    # first deployment seen: seed the group info from its model info
+                    model_group_info = ModelGroupInfo(
+                        model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
+                    )
+                else:
+                    # merge this deployment into the group-level info:
+                    # take the max of token limits and per-token costs,
+                    # and OR together the supports_* capability flags
+                    if llm_provider not in model_group_info.providers:
+                        model_group_info.providers.append(llm_provider)
+                    if model_info.get("max_input_tokens") is not None and (
+                        model_group_info.max_input_tokens is None
+                        or model_info["max_input_tokens"]
+                        > model_group_info.max_input_tokens
+                    ):
+                        model_group_info.max_input_tokens = model_info[
+                            "max_input_tokens"
+                        ]
+                    if model_info.get("max_output_tokens") is not None and (
+                        model_group_info.max_output_tokens is None
+                        or model_info["max_output_tokens"]
+                        > model_group_info.max_output_tokens
+                    ):
+                        model_group_info.max_output_tokens = model_info[
+                            "max_output_tokens"
+                        ]
+                    if model_info.get("input_cost_per_token") is not None and (
+                        model_group_info.input_cost_per_token is None
+                        or model_info["input_cost_per_token"]
+                        > model_group_info.input_cost_per_token
+                    ):
+                        model_group_info.input_cost_per_token = model_info[
+                            "input_cost_per_token"
+                        ]
+                    if model_info.get("output_cost_per_token") is not None and (
+                        model_group_info.output_cost_per_token is None
+                        or model_info["output_cost_per_token"]
+                        > model_group_info.output_cost_per_token
+                    ):
+                        model_group_info.output_cost_per_token = model_info[
+                            "output_cost_per_token"
+                        ]
+                    if model_info.get("supports_parallel_function_calling") is True:  # type: ignore
+                        model_group_info.supports_parallel_function_calling = True
+                    if model_info.get("supports_vision") is True:  # type: ignore
+                        model_group_info.supports_vision = True
+                    if model_info.get("supports_function_calling") is True:  # type: ignore
+                        model_group_info.supports_function_calling = True
+
+        return model_group_info
+
     def get_model_ids(self) -> List[str]:
         """
         Returns list of model IDs.
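For reference, a minimal usage sketch of the new Router method. The two-deployment config below is illustrative: the API keys and the Azure api_base are placeholders, while the method name and the ModelGroupInfo fields shown come from this diff:

    # Two deployments sharing the model group "gpt-4"; keys/URLs are placeholders.
    from litellm import Router

    router = Router(
        model_list=[
            {
                "model_name": "gpt-4",  # model group name
                "litellm_params": {"model": "gpt-4", "api_key": "sk-openai-placeholder"},
            },
            {
                "model_name": "gpt-4",  # second deployment in the same group
                "litellm_params": {
                    "model": "azure/gpt-4",
                    "api_key": "sk-azure-placeholder",
                    "api_base": "https://example.azure.com",
                },
            },
        ]
    )

    info = router.get_model_group_info(model_group="gpt-4")
    if info is not None:
        # providers collected from both deployments; limits/costs are the per-field
        # maxima; supports_* flags are ORed across deployments
        print(info.providers, info.max_input_tokens, info.supports_function_calling)

Because the merge loop keeps the maximum of each limit and cost and ORs the capability flags, the returned ModelGroupInfo describes the most capable (and most expensive) deployment in the group rather than any single one.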