diff --git a/litellm/litellm_core_utils/get_model_cost_map.py b/litellm/litellm_core_utils/get_model_cost_map.py
index b8bdaee19c..b6a3a243c4 100644
--- a/litellm/litellm_core_utils/get_model_cost_map.py
+++ b/litellm/litellm_core_utils/get_model_cost_map.py
@@ -13,7 +13,7 @@ import os
 import httpx
 
 
-def get_model_cost_map(url: str):
+def get_model_cost_map(url: str) -> dict:
     if (
         os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False)
         or os.getenv("LITELLM_LOCAL_MODEL_COST_MAP", False) == "True"
diff --git a/litellm/router.py b/litellm/router.py
index 4a466f4119..e71452bc00 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -339,9 +339,9 @@ class Router:
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
         ### CACHING ###
-        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
-            "local"  # default to an in-memory cache
-        )
+        cache_type: Literal[
+            "local", "redis", "redis-semantic", "s3", "disk"
+        ] = "local"  # default to an in-memory cache
         redis_cache = None
         cache_config: Dict[str, Any] = {}
 
@@ -562,9 +562,9 @@ class Router:
             )
         )
 
-        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
-            model_group_retry_policy
-        )
+        self.model_group_retry_policy: Optional[
+            Dict[str, RetryPolicy]
+        ] = model_group_retry_policy
 
         self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
         if allowed_fails_policy is not None:
@@ -1105,9 +1105,9 @@ class Router:
         """
         Adds default litellm params to kwargs, if set.
         """
-        self.default_litellm_params[metadata_variable_name] = (
-            self.default_litellm_params.pop("metadata", {})
-        )
+        self.default_litellm_params[
+            metadata_variable_name
+        ] = self.default_litellm_params.pop("metadata", {})
         for k, v in self.default_litellm_params.items():
             if (
                 k not in kwargs and v is not None
@@ -3243,11 +3243,11 @@ class Router:
 
         if isinstance(e, litellm.ContextWindowExceededError):
             if context_window_fallbacks is not None:
-                fallback_model_group: Optional[List[str]] = (
-                    self._get_fallback_model_group_from_fallbacks(
-                        fallbacks=context_window_fallbacks,
-                        model_group=model_group,
-                    )
+                fallback_model_group: Optional[
+                    List[str]
+                ] = self._get_fallback_model_group_from_fallbacks(
+                    fallbacks=context_window_fallbacks,
+                    model_group=model_group,
                 )
                 if fallback_model_group is None:
                     raise original_exception
@@ -3279,11 +3279,11 @@ class Router:
                     e.message += "\n{}".format(error_message)
         elif isinstance(e, litellm.ContentPolicyViolationError):
             if content_policy_fallbacks is not None:
-                fallback_model_group: Optional[List[str]] = (
-                    self._get_fallback_model_group_from_fallbacks(
-                        fallbacks=content_policy_fallbacks,
-                        model_group=model_group,
-                    )
+                fallback_model_group: Optional[
+                    List[str]
+                ] = self._get_fallback_model_group_from_fallbacks(
+                    fallbacks=content_policy_fallbacks,
+                    model_group=model_group,
                 )
                 if fallback_model_group is None:
                     raise original_exception
@@ -4856,7 +4856,7 @@ class Router:
 
         litellm_model_name_model_info: Optional[ModelInfo] = None
         try:
-            model_info = litellm.get_model_info(model=model_id)
+            model_info = litellm.model_cost.get(model_id)
         except Exception:
             pass
 
@@ -4870,9 +4870,11 @@ class Router:
                 ModelInfo,
                 _update_dictionary(
                     cast(dict, litellm_model_name_model_info).copy(),
-                    cast(dict, model_info),
+                    model_info,
                 ),
             )
+        elif litellm_model_name_model_info is not None:
+            model_info = litellm_model_name_model_info
 
         return model_info
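Commentary (not part of the patch): the substantive change in the last two router.py hunks is the lookup swap. litellm.get_model_info() raises for models it cannot map, while litellm.model_cost is the plain dict produced by get_model_cost_map() above, so .get() returns None for an unknown key and the new elif branch falls back to the model-name-level info. A minimal sketch of that difference, assuming an installed litellm whose bundled cost map has no entry for the made-up name "not-a-real-model":

import litellm

# Plain dict lookup: a missing key yields None, nothing to catch.
# ("not-a-real-model" is a placeholder assumed to be unmapped.)
assert litellm.model_cost.get("not-a-real-model") is None

# get_model_info: an unmapped model raises, which is why the
# surrounding try/except in router.py exists at all.
try:
    litellm.get_model_info(model="not-a-real-model")
except Exception as err:
    print(f"get_model_info raised: {err}")
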
diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py
index 68a79f94a6..13eaeb09ab 100644
--- a/tests/local_testing/test_router.py
+++ b/tests/local_testing/test_router.py
@@ -2767,3 +2767,24 @@ def test_router_dynamic_credentials():
     deployment = router.get_deployment(model_id=original_model_id)
     assert deployment is not None
     assert deployment.litellm_params.api_key == original_api_key
+
+
+def test_router_get_model_group_info():
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {"model": "gpt-3.5-turbo"},
+            },
+            {
+                "model_name": "gpt-4",
+                "litellm_params": {"model": "gpt-4"},
+            },
+        ],
+    )
+
+    model_group_info = router.get_model_group_info(model_group="gpt-4")
+    assert model_group_info is not None
+    assert model_group_info.model_group == "gpt-4"
+    assert model_group_info.input_cost_per_token > 0
+    assert model_group_info.output_cost_per_token > 0
\ No newline at end of file
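Commentary (not part of the patch): the "> 0" cost assertions in the new test hold because get_model_group_info() now draws per-token prices from litellm.model_cost. A quick sketch of where those numbers come from, assuming the packaged cost map contains a "gpt-4" entry with the usual per-token price fields (the printed values are illustrative, not pinned):

import litellm

# model_cost is a plain dict keyed by model name; mapped chat models
# carry per-token prices as small positive floats.
entry = litellm.model_cost.get("gpt-4")
if entry is not None:
    print(entry["input_cost_per_token"])   # e.g. 3e-05
    print(entry["output_cost_per_token"])  # e.g. 6e-05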