Mirror of https://github.com/BerriAI/litellm.git
fix('get_model_group_info'): return a default value for unmapped model groups; allows the model hub to return all model groups
parent 54e4a2f7ac
commit 67da24f144
2 changed files with 44 additions and 15 deletions
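In short: Router.get_model_group_info previously skipped any deployment whose model has no entry in litellm's pricing map, so unmapped model groups never reached the model hub; with this change the router builds a default entry instead. A minimal sketch of the intended behavior, assuming a two-entry model_list where the second group points at a model that is not in model_prices_and_context_window.json (the group and model names are made up, and get_model_group_info is assumed to take the group name):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {"model": "gpt-3.5-turbo"},
        },
        {
            # hypothetical group whose underlying model is not in the pricing map
            "model_name": "my-custom-group",
            "litellm_params": {"model": "openai/my-finetuned-model"},
        },
    ]
)

# Before this commit the unmapped group was dropped from the result; now it is
# returned with None token limits, zero costs, and mode "chat".
info = router.get_model_group_info(model_group="my-custom-group")
print(info)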
@@ -38,6 +38,7 @@ from litellm.utils import (
 import copy
 from litellm._logging import verbose_router_logger
 import logging
+from litellm.types.utils import ModelInfo as ModelMapInfo
 from litellm.types.router import (
     Deployment,
     ModelInfo,
@@ -3065,7 +3066,7 @@ class Router:
                 try:
                     model_info = litellm.get_model_info(model=litellm_params.model)
                 except Exception as e:
-                    continue
+                    model_info = None
                 # get llm provider
                 try:
                     model, llm_provider, _, _ = litellm.get_llm_provider(
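The hunk above replaces the early continue with a sentinel: a model that is missing from the pricing map no longer knocks its deployment out of the group summary. Roughly the same pattern in isolation (the model name is invented):

import litellm

try:
    # raises for models with no entry in model_prices_and_context_window.json
    model_info = litellm.get_model_info(model="my-unmapped-model")
except Exception:
    # old behavior: continue (the deployment disappeared from the group info)
    # new behavior: record that nothing is mapped and build a default later
    model_info = None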
@@ -3075,6 +3076,21 @@ class Router:
                 except litellm.exceptions.BadRequestError as e:
                     continue

+                if model_info is None:
+                    supported_openai_params = litellm.get_supported_openai_params(
+                        model=model, custom_llm_provider=llm_provider
+                    )
+                    model_info = ModelMapInfo(
+                        max_tokens=None,
+                        max_input_tokens=None,
+                        max_output_tokens=None,
+                        input_cost_per_token=0,
+                        output_cost_per_token=0,
+                        litellm_provider=llm_provider,
+                        mode="chat",
+                        supported_openai_params=supported_openai_params,
+                    )
+
                 if model_group_info is None:
                     model_group_info = ModelGroupInfo(
                         model_group=model_group, providers=[llm_provider], **model_info  # type: ignore
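The block added above synthesizes that default entry: token limits stay None, costs default to 0, mode is assumed to be "chat", and the supported OpenAI params are derived from the provider via litellm.get_supported_openai_params, exactly as the hunk calls it. A standalone sketch of that call with an invented model name (the returned list depends on the provider and litellm version):

import litellm

supported_openai_params = litellm.get_supported_openai_params(
    model="my-unmapped-model", custom_llm_provider="openai"
)
# e.g. ["temperature", "max_tokens", "stream", ...]
print(supported_openai_params)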
@@ -3089,18 +3105,26 @@ class Router:
                     # supports_function_calling == True
                     if llm_provider not in model_group_info.providers:
                         model_group_info.providers.append(llm_provider)
-                    if model_info.get("max_input_tokens", None) is not None and (
-                        model_group_info.max_input_tokens is None
-                        or model_info["max_input_tokens"]
-                        > model_group_info.max_input_tokens
+                    if (
+                        model_info.get("max_input_tokens", None) is not None
+                        and model_info["max_input_tokens"] is not None
+                        and (
+                            model_group_info.max_input_tokens is None
+                            or model_info["max_input_tokens"]
+                            > model_group_info.max_input_tokens
+                        )
                     ):
                         model_group_info.max_input_tokens = model_info[
                             "max_input_tokens"
                         ]
-                    if model_info.get("max_output_tokens", None) is not None and (
-                        model_group_info.max_output_tokens is None
-                        or model_info["max_output_tokens"]
-                        > model_group_info.max_output_tokens
+                    if (
+                        model_info.get("max_output_tokens", None) is not None
+                        and model_info["max_output_tokens"] is not None
+                        and (
+                            model_group_info.max_output_tokens is None
+                            or model_info["max_output_tokens"]
+                            > model_group_info.max_output_tokens
+                        )
                     ):
                         model_group_info.max_output_tokens = model_info[
                             "max_output_tokens"
@@ -3137,7 +3161,10 @@ class Router:
                         and model_info["supports_function_calling"] is True  # type: ignore
                     ):
                         model_group_info.supports_function_calling = True
-                    if model_info.get("supported_openai_params", None) is not None:
+                    if (
+                        model_info.get("supported_openai_params", None) is not None
+                        and model_info["supported_openai_params"] is not None
+                    ):
                         model_group_info.supported_openai_params = model_info[
                             "supported_openai_params"
                         ]
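The last two router hunks widen the guards around the merge logic: because the default entry can now carry None for its token limits and its supported_openai_params, each field is checked for None before it is compared or copied into ModelGroupInfo. The max_input_tokens / max_output_tokens logic boils down to something like this standalone helper, which is not taken from the codebase:

from typing import Optional


def merge_max(current: Optional[int], candidate: Optional[int]) -> Optional[int]:
    """Keep the larger of two optional limits, treating None as 'unknown'."""
    if candidate is None:
        return current
    if current is None or candidate > current:
        return candidate
    return current


assert merge_max(None, None) is None
assert merge_max(None, 4096) == 4096
assert merge_max(8192, 4096) == 8192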
@@ -19,11 +19,13 @@ class ModelInfo(TypedDict):
     Model info for a given model, this is information found in litellm.model_prices_and_context_window.json
     """

-    max_tokens: int
-    max_input_tokens: int
-    max_output_tokens: int
+    max_tokens: Optional[int]
+    max_input_tokens: Optional[int]
+    max_output_tokens: Optional[int]
     input_cost_per_token: float
     output_cost_per_token: float
     litellm_provider: str
-    mode: str
-    supported_openai_params: List[str]
+    mode: Literal[
+        "completion", "embedding", "image_generation", "chat", "audio_transcription"
+    ]
+    supported_openai_params: Optional[List[str]]
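The TypedDict is loosened to match: the default entry built in the router sets the token limits to None and uses mode "chat", which the old int / str annotations would not have allowed. A hand-written value that satisfies the annotations shown in this hunk, assuming these are the only declared keys (the numbers and parameter list are illustrative only):

from typing import List, Literal, Optional

from typing_extensions import TypedDict


# Illustrative stand-in mirroring the fields in the hunk above; the real class
# lives inside litellm.
class ModelInfo(TypedDict):
    max_tokens: Optional[int]
    max_input_tokens: Optional[int]
    max_output_tokens: Optional[int]
    input_cost_per_token: float
    output_cost_per_token: float
    litellm_provider: str
    mode: Literal[
        "completion", "embedding", "image_generation", "chat", "audio_transcription"
    ]
    supported_openai_params: Optional[List[str]]


default_entry: ModelInfo = {
    "max_tokens": None,
    "max_input_tokens": None,
    "max_output_tokens": None,
    "input_cost_per_token": 0.0,
    "output_cost_per_token": 0.0,
    "litellm_provider": "openai",
    "mode": "chat",
    "supported_openai_params": ["temperature", "max_tokens"],
}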