From fca2ffb4801fd41bb44bcc2871697f4169c7b01e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Mon, 17 Jun 2024 19:15:02 -0700
Subject: [PATCH] fix(utils.py): return cost above 128k from get_model_info

---
 litellm/utils.py | 35 +++++++++++++++++++++++++++++++++--
 1 file changed, 33 insertions(+), 2 deletions(-)

diff --git a/litellm/utils.py b/litellm/utils.py
index 8b640b16d..79269a4b8 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4376,7 +4376,27 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                         pass
                     else:
                         raise Exception
-                return _model_info
+                return ModelInfo(
+                    max_tokens=_model_info.get("max_tokens", None),
+                    max_input_tokens=_model_info.get("max_input_tokens", None),
+                    max_output_tokens=_model_info.get("max_output_tokens", None),
+                    input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                    input_cost_per_token_above_128k_tokens=_model_info.get(
+                        "input_cost_per_token_above_128k_tokens", None
+                    ),
+                    output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                    output_cost_per_token_above_128k_tokens=_model_info.get(
+                        "output_cost_per_token_above_128k_tokens", None
+                    ),
+                    litellm_provider=_model_info.get(
+                        "litellm_provider", custom_llm_provider
+                    ),
+                    mode=_model_info.get("mode"),
+                    supported_openai_params=supported_openai_params,
+                    supports_system_messages=_model_info.get(
+                        "supports_system_messages", None
+                    ),
+                )
             elif model in litellm.model_cost:
                 _model_info = litellm.model_cost[model]
                 _model_info["supported_openai_params"] = supported_openai_params
@@ -4395,7 +4415,13 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     max_input_tokens=_model_info.get("max_input_tokens", None),
                     max_output_tokens=_model_info.get("max_output_tokens", None),
                     input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                    input_cost_per_token_above_128k_tokens=_model_info.get(
+                        "input_cost_per_token_above_128k_tokens", None
+                    ),
                     output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                    output_cost_per_token_above_128k_tokens=_model_info.get(
+                        "output_cost_per_token_above_128k_tokens", None
+                    ),
                     litellm_provider=_model_info.get(
                         "litellm_provider", custom_llm_provider
                     ),
@@ -4405,7 +4431,6 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                         "supports_system_messages", None
                     ),
                 )
-                return _model_info
             elif split_model in litellm.model_cost:
                 _model_info = litellm.model_cost[split_model]
                 _model_info["supported_openai_params"] = supported_openai_params
@@ -4424,7 +4449,13 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
                     max_input_tokens=_model_info.get("max_input_tokens", None),
                     max_output_tokens=_model_info.get("max_output_tokens", None),
                     input_cost_per_token=_model_info.get("input_cost_per_token", 0),
+                    input_cost_per_token_above_128k_tokens=_model_info.get(
+                        "input_cost_per_token_above_128k_tokens", None
+                    ),
                     output_cost_per_token=_model_info.get("output_cost_per_token", 0),
+                    output_cost_per_token_above_128k_tokens=_model_info.get(
+                        "output_cost_per_token_above_128k_tokens", None
+                    ),
                     litellm_provider=_model_info.get(
                         "litellm_provider", custom_llm_provider
                     ),
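
Usage sketch (not part of the patch itself): how the new tiered-pricing fields surface to a caller once this change is applied, assuming a litellm.model_cost entry that defines above-128k pricing; the "gemini-1.5-pro" model name below is illustrative.

    import litellm

    # get_model_info() returns a ModelInfo mapping; with this patch applied,
    # the above-128k cost fields are carried through instead of being dropped.
    info = litellm.get_model_info("gemini-1.5-pro")  # illustrative model name

    print(info["input_cost_per_token"])                      # base per-token input price
    print(info["input_cost_per_token_above_128k_tokens"])    # input price past 128k tokens (None if no tiered price)
    print(info["output_cost_per_token_above_128k_tokens"])   # output price past 128k tokens (None if no tiered price)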