diff --git a/litellm/router.py b/litellm/router.py
index 384c7f3389..3e52acfdfa 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -38,6 +38,7 @@ from litellm.utils import (
 import copy
 from litellm._logging import verbose_router_logger
 import logging
+from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.router import (
     Deployment,
     ModelInfo,
@@ -3065,7 +3066,7 @@ class Router:
                 try:
                     model_info = litellm.get_model_info(model=litellm_params.model)
                 except Exception as e:
-                    continue
+                    model_info = None
                 # get llm provider
                 try:
                     model, llm_provider, _, _ = litellm.get_llm_provider(
@@ -3075,6 +3076,21 @@ class Router:
                 except litellm.exceptions.BadRequestError as e:
                     continue
 
+                if model_info is None:
+                    supported_openai_params = litellm.get_supported_openai_params(
+                        model=model, custom_llm_provider=llm_provider
+                    )
+                    model_info = ModelMapInfo(
+                        max_tokens=None,
+                        max_input_tokens=None,
+                        max_output_tokens=None,
+                        input_cost_per_token=0,
+                        output_cost_per_token=0,
+                        litellm_provider=llm_provider,
+                        mode="chat",
+                        supported_openai_params=supported_openai_params,
+                    )
+
                 if model_group_info is None:
                     model_group_info = ModelGroupInfo(
                         model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
@@ -3089,18 +3105,26 @@ class Router:
                     # supports_function_calling == True
                     if llm_provider not in model_group_info.providers:
                         model_group_info.providers.append(llm_provider)
-                    if model_info.get("max_input_tokens", None) is not None and (
-                        model_group_info.max_input_tokens is None
-                        or model_info["max_input_tokens"]
-                        > model_group_info.max_input_tokens
+                    if (
+                        model_info.get("max_input_tokens", None) is not None
+                        and model_info["max_input_tokens"] is not None
+                        and (
+                            model_group_info.max_input_tokens is None
+                            or model_info["max_input_tokens"]
+                            > model_group_info.max_input_tokens
+                        )
                     ):
                         model_group_info.max_input_tokens = model_info[
                             "max_input_tokens"
                         ]
-                    if model_info.get("max_output_tokens", None) is not None and (
-                        model_group_info.max_output_tokens is None
-                        or model_info["max_output_tokens"]
-                        > model_group_info.max_output_tokens
+                    if (
+                        model_info.get("max_output_tokens", None) is not None
+                        and model_info["max_output_tokens"] is not None
+                        and (
+                            model_group_info.max_output_tokens is None
+                            or model_info["max_output_tokens"]
+                            > model_group_info.max_output_tokens
+                        )
                     ):
                         model_group_info.max_output_tokens = model_info[
                             "max_output_tokens"
@@ -3137,7 +3161,10 @@ class Router:
                         and model_info["supports_function_calling"] is True  # type: ignore
                     ):
                         model_group_info.supports_function_calling = True
-                    if model_info.get("supported_openai_params", None) is not None:
+                    if (
+                        model_info.get("supported_openai_params", None) is not None
+                        and model_info["supported_openai_params"] is not None
+                    ):
                         model_group_info.supported_openai_params = model_info[
                             "supported_openai_params"
                         ]
diff --git a/litellm/types/utils.py b/litellm/types/utils.py
index cc08361329..7efc628ca7 100644
--- a/litellm/types/utils.py
+++ b/litellm/types/utils.py
@@ -19,11 +19,13 @@ class ModelInfo(TypedDict):
     Model info for a given model, this is information found in litellm.model_prices_and_context_window.json
     """
 
-    max_tokens: int
-    max_input_tokens: int
-    max_output_tokens: int
+    max_tokens: Optional[int]
+    max_input_tokens: Optional[int]
+    max_output_tokens: Optional[int]
     input_cost_per_token: float
     output_cost_per_token: float
     litellm_provider: str
-    mode: str
-    supported_openai_params: List[str]
+    mode: Literal[
+        "completion", "embedding", "image_generation", "chat", "audio_transcription"
+    ]
+    supported_openai_params: Optional[List[str]]