Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 11:14:04 +00:00
fix(router.py): correctly handle retrieving model info in get_model_group_info

Fixes an issue where the model hub showed None prices.
parent a6084fa37d
commit 2f85b5f6e1
3 changed files with 45 additions and 22 deletions
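For context: the model hub reads its price fields from Router.get_model_group_info(). A minimal sketch of the symptom, mirroring the test added at the end of this diff (it assumes only static pricing metadata is read, with no provider call; that is an inference, not something the commit states):

    from litellm import Router

    router = Router(
        model_list=[{"model_name": "gpt-4", "litellm_params": {"model": "gpt-4"}}]
    )
    info = router.get_model_group_info(model_group="gpt-4")
    assert info is not None
    # Before this commit both of these could come back as None, which is
    # what the model hub rendered; after it they come from the cost map.
    print(info.input_cost_per_token, info.output_cost_per_token)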
@@ -13,7 +13,7 @@ import os
 import httpx


-def get_model_cost_map(url: str):
+def get_model_cost_map(url: str) -> dict:
     if (
         os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False)
         or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
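The only change here is the -> dict return annotation; get_model_cost_map() already returns the parsed cost-map JSON. A hedged usage sketch of the env-var branch visible in the hunk (the gpt-4 key is an assumption about the bundled map's contents):

    import os

    # Must be set before importing litellm: the cost map loads at import time.
    os.environ["LITELLM_LOCAL_MODEL_COST_MAP"] = "True"

    import litellm  # now uses the bundled local cost map, no HTTP fetch

    # litellm.model_cost holds the dict that get_model_cost_map() returned.
    print(litellm.model_cost["gpt-4"]["input_cost_per_token"])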
@@ -339,9 +339,9 @@ class Router:
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
         ### CACHING ###
-        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
-            "local"  # default to an in-memory cache
-        )
+        cache_type: Literal[
+            "local", "redis", "redis-semantic", "s3", "disk"
+        ] = "local"  # default to an in-memory cache
         redis_cache = None
         cache_config: Dict[str, Any] = {}
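This hunk and the similarly shaped ones below (model_group_retry_policy, default_litellm_params, and the two fallback_model_group blocks) are formatting-only: the wrapping parenthesis around the right-hand side is dropped and the subscript is split across lines instead. Behavior is unchanged in all of them.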
|
@@ -562,9 +562,9 @@ class Router:
                 )
             )

-        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
-            model_group_retry_policy
-        )
+        self.model_group_retry_policy: Optional[
+            Dict[str, RetryPolicy]
+        ] = model_group_retry_policy

         self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
         if allowed_fails_policy is not None:
|
@@ -1105,9 +1105,9 @@ class Router:
         """
        Adds default litellm params to kwargs, if set.
        """
-        self.default_litellm_params[metadata_variable_name] = (
-            self.default_litellm_params.pop("metadata", {})
-        )
+        self.default_litellm_params[
+            metadata_variable_name
+        ] = self.default_litellm_params.pop("metadata", {})
         for k, v in self.default_litellm_params.items():
             if (
                 k not in kwargs and v is not None
|
@@ -3243,12 +3243,12 @@ class Router:

             if isinstance(e, litellm.ContextWindowExceededError):
                 if context_window_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
                         fallbacks=context_window_fallbacks,
                         model_group=model_group,
                     )
-                    )

                     if fallback_model_group is None:
                         raise original_exception
|
@@ -3279,12 +3279,12 @@ class Router:
                 e.message += "\n{}".format(error_message)
             elif isinstance(e, litellm.ContentPolicyViolationError):
                 if content_policy_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
                         fallbacks=content_policy_fallbacks,
                         model_group=model_group,
                     )
-                    )

                     if fallback_model_group is None:
                         raise original_exception
|
@@ -4856,7 +4856,7 @@ class Router:
         litellm_model_name_model_info: Optional[ModelInfo] = None

         try:
-            model_info = litellm.get_model_info(model=model_id)
+            model_info = litellm.model_cost.get(model_id)
         except Exception:
             pass
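This is the substantive fix. litellm.get_model_info() raises for models it cannot resolve (swallowed by the except here), and, judging by the commit message, could yield price fields of None when handed a raw deployment id. A plain dict lookup on litellm.model_cost makes a miss an explicit None instead, which the next hunk turns into a fallback. A minimal sketch of the miss semantics (the deployment id string is hypothetical):

    import litellm

    # Dict lookup: a miss is a clean None rather than an exception or a stub.
    model_info = litellm.model_cost.get("3f0a4b1c-my-deployment-id")  # hypothetical
    if model_info is None:
        # Fall back to the underlying litellm model name instead.
        model_info = litellm.get_model_info(model="gpt-4")
    print(model_info["input_cost_per_token"])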
|
@@ -4870,9 +4870,11 @@ class Router:
                 ModelInfo,
                 _update_dictionary(
                     cast(dict, litellm_model_name_model_info).copy(),
-                    cast(dict, model_info),
+                    model_info,
                 ),
             )
+        elif litellm_model_name_model_info is not None:
+            model_info = litellm_model_name_model_info

         return model_info
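The cast is dropped because model_cost.get() already returns a plain dict, and the new elif gives get_model_group_info a resolution order: deployment-id entry overlaid on the base model's info when both exist; else the base model's info alone; else whatever the id lookup produced (possibly None). A runnable sketch of that order, with a dict-splat standing in for _update_dictionary (an assumption about its overlay direction):

    from typing import Optional

    def resolve_model_info(
        model_info: Optional[dict],
        litellm_model_name_model_info: Optional[dict],
    ) -> Optional[dict]:
        # Names mirror the diff; this is a sketch, not the actual method body.
        if model_info is not None and litellm_model_name_model_info is not None:
            # Deployment-specific values win over the base model's defaults.
            return {**litellm_model_name_model_info, **model_info}
        elif litellm_model_name_model_info is not None:
            # New branch: fall back to the base model's info on an id miss.
            return litellm_model_name_model_info
        return model_info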
@@ -2767,3 +2767,24 @@ def test_router_dynamic_credentials():
     deployment = router.get_deployment(model_id=original_model_id)
     assert deployment is not None
     assert deployment.litellm_params.api_key == original_api_key
+
+
+def test_router_get_model_group_info():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {"model": "gpt-4"},
+            },
+        ],
+    )
+
+    model_group_info = router.get_model_group_info(model_group="gpt-4")
+    assert model_group_info is not None
+    assert model_group_info.model_group == "gpt-4"
+    assert model_group_info.input_cost_per_token > 0
+    assert model_group_info.output_cost_per_token > 0
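The assertions pin the fix: input_cost_per_token and output_cost_per_token must be strictly positive, which fails if get_model_group_info regresses to None prices. Assuming the test lands next to test_router_dynamic_credentials in the existing router test module (as the hunk context suggests), pytest -k test_router_get_model_group_info selects it on its own.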