fix(fix-'get_model_group_info'-to-return-a-default-value-if-unmapped-model-group): allows model hub to return all model groups

Krrish Dholakia 2024-05-27 13:53:01 -07:00
parent 54e4a2f7ac
commit 67da24f144
2 changed files with 44 additions and 15 deletions
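
For context on the failure mode: litellm.get_model_info raises when a model has no entry in litellm's model cost map, and get_model_group_info previously answered that with a bare continue, silently dropping the group from the model hub listing. A minimal sketch, assuming litellm is installed (my-finetuned-gpt is a hypothetical, unmapped model name):

import litellm

# Mapped model: returns token limits, costs, provider, mode, etc.
print(litellm.get_model_info(model="gpt-3.5-turbo"))

# Unmapped model: raises, since there is no cost-map entry to return.
try:
    litellm.get_model_info(model="my-finetuned-gpt")
except Exception as e:
    print(f"unmapped: {e}")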

litellm/router.py

@@ -38,6 +38,7 @@ from litellm.utils import (
 import copy
 from litellm._logging import verbose_router_logger
 import logging
+from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.router import (
     Deployment,
     ModelInfo,
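
The alias exists because two modules export a class named ModelInfo: loosely, litellm.types.router.ModelInfo describes a deployment, while litellm.types.utils.ModelInfo types the cost-map entries that get_model_info returns. Side by side, the two imports this hunk leaves in place:

from litellm.types.router import ModelInfo                  # per-deployment metadata
from litellm.types.utils import ModelInfo as ModelMapInfo   # cost-map entry type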
@@ -3065,7 +3066,7 @@ class Router:
                 try:
                     model_info = litellm.get_model_info(model=litellm_params.model)
                 except Exception as e:
-                    continue
+                    model_info = None
                 # get llm provider
                 try:
                     model, llm_provider, _, _ = litellm.get_llm_provider(
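
The behavioral change is in the except branch: continue dropped the whole deployment from the aggregation, whereas the None sentinel lets a later branch substitute defaults. A self-contained sketch of the pattern (lookup and the dict shapes are stand-ins, not litellm APIs):

from typing import Optional

COST_MAP = {"gpt-3.5-turbo": {"max_input_tokens": 16385}}

def lookup(model: str) -> dict:
    # stand-in for litellm.get_model_info: raises on unmapped models
    if model not in COST_MAP:
        raise ValueError(f"no cost-map entry for {model}")
    return COST_MAP[model]

def collect(models: list) -> list:
    infos = []
    for name in models:
        try:
            info: Optional[dict] = lookup(name)
        except Exception:
            info = None  # old code: `continue` -- the model vanished here
        if info is None:
            info = {"max_input_tokens": None}  # new code: default entry
        infos.append(info)
    return infos

# Both models survive now; before the fix only the mapped one did.
print(collect(["gpt-3.5-turbo", "my-finetuned-gpt"]))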
@@ -3075,6 +3076,21 @@ class Router:
                 except litellm.exceptions.BadRequestError as e:
                     continue
+                if model_info is None:
+                    supported_openai_params = litellm.get_supported_openai_params(
+                        model=model, custom_llm_provider=llm_provider
+                    )
+                    model_info = ModelMapInfo(
+                        max_tokens=None,
+                        max_input_tokens=None,
+                        max_output_tokens=None,
+                        input_cost_per_token=0,
+                        output_cost_per_token=0,
+                        litellm_provider=llm_provider,
+                        mode="chat",
+                        supported_openai_params=supported_openai_params,
+                    )
+
                 if model_group_info is None:
                     model_group_info = ModelGroupInfo(
                         model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
                     )
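
When the cost map has nothing, the default ModelMapInfo is seeded conservatively: unknown token limits stay None, costs are zeroed, and the mode is assumed to be "chat". The supported parameter list is still real, probed per provider via an existing litellm helper, which can be called directly (output varies by provider and litellm version; the model name is hypothetical):

import litellm

params = litellm.get_supported_openai_params(
    model="my-finetuned-gpt", custom_llm_provider="openai"
)
print(params)  # e.g. ['temperature', 'max_tokens', 'stream', ...]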
@@ -3089,18 +3105,26 @@ class Router:
                     # supports_function_calling == True
                     if llm_provider not in model_group_info.providers:
                         model_group_info.providers.append(llm_provider)
-                    if model_info.get("max_input_tokens", None) is not None and (
-                        model_group_info.max_input_tokens is None
-                        or model_info["max_input_tokens"]
-                        > model_group_info.max_input_tokens
+                    if (
+                        model_info.get("max_input_tokens", None) is not None
+                        and model_info["max_input_tokens"] is not None
+                        and (
+                            model_group_info.max_input_tokens is None
+                            or model_info["max_input_tokens"]
+                            > model_group_info.max_input_tokens
+                        )
                     ):
                         model_group_info.max_input_tokens = model_info[
                             "max_input_tokens"
                         ]
-                    if model_info.get("max_output_tokens", None) is not None and (
-                        model_group_info.max_output_tokens is None
-                        or model_info["max_output_tokens"]
-                        > model_group_info.max_output_tokens
+                    if (
+                        model_info.get("max_output_tokens", None) is not None
+                        and model_info["max_output_tokens"] is not None
+                        and (
+                            model_group_info.max_output_tokens is None
+                            or model_info["max_output_tokens"]
+                            > model_group_info.max_output_tokens
+                        )
                     ):
                         model_group_info.max_output_tokens = model_info[
                             "max_output_tokens"
@@ -3137,7 +3161,10 @@ class Router:
                         and model_info["supports_function_calling"] is True  # type: ignore
                     ):
                         model_group_info.supports_function_calling = True
-                    if model_info.get("supported_openai_params", None) is not None:
+                    if (
+                        model_info.get("supported_openai_params", None) is not None
+                        and model_info["supported_openai_params"] is not None
+                    ):
                         model_group_info.supported_openai_params = model_info[
                             "supported_openai_params"
                         ]
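
End to end, the commit means a router containing an unmapped model now yields a ModelGroupInfo instead of skipping the group. A quick check with a hypothetical group and model name (field values beyond the group name and provider depend on the defaults introduced above):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "my-custom-group",
            "litellm_params": {"model": "openai/my-finetuned-gpt"},
        }
    ]
)

info = router.get_model_group_info(model_group="my-custom-group")
assert info is not None          # before this commit: None
assert info.model_group == "my-custom-group"
assert "openai" in info.providers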