From 0d200cd8dca29f32feacf6a237fb952e878784e9 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 29 Nov 2023 20:14:31 -0800
Subject: [PATCH] feat(main.py): allow updating model cost via completion()

---
 litellm/main.py                      | 16 +++++++++++++++-
 litellm/tests/test_register_model.py | 13 ++++++++++++-
 litellm/utils.py                     |  1 -
 3 files changed, 27 insertions(+), 3 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index ffe562641..00b328880 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -331,6 +331,9 @@ def completion(
     num_retries = kwargs.get("num_retries", None) ## deprecated
     max_retries = kwargs.get("max_retries", None)
     context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None)
+    ### CUSTOM MODEL COST ###
+    input_cost_per_token = kwargs.get("input_cost_per_token", None)
+    output_cost_per_token = kwargs.get("output_cost_per_token", None)
     ### CUSTOM PROMPT TEMPLATE ###
     initial_prompt_value = kwargs.get("initial_prompt_value", None)
     roles = kwargs.get("roles", None)
@@ -341,7 +344,7 @@ def completion(
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm"]
+    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -375,6 +378,17 @@ def completion(
             model=deployment_id
             custom_llm_provider="azure"
         model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base, api_key=api_key)
+
+        ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ###
+        if input_cost_per_token is not None and output_cost_per_token is not None:
+            litellm.register_model({
+                model: {
+                    "input_cost_per_token": input_cost_per_token,
+                    "output_cost_per_token": output_cost_per_token,
+                    "litellm_provider": custom_llm_provider
+                }
+            })
+
         ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ###
         custom_prompt_dict = {} # type: ignore
         if initial_prompt_value or roles or final_prompt_value or bos_token or eos_token:
             custom_prompt_dict = {model: {}}
diff --git a/litellm/tests/test_register_model.py b/litellm/tests/test_register_model.py
index 56d0c39d1..185e96c20 100644
--- a/litellm/tests/test_register_model.py
+++ b/litellm/tests/test_register_model.py
@@ -33,4 +33,15 @@ def test_update_model_cost_map_url():
     except Exception as e:
         pytest.fail(f"An error occurred: {e}")
 
-test_update_model_cost_map_url()
\ No newline at end of file
+# test_update_model_cost_map_url()
+
+def test_update_model_cost_via_completion():
+    try:
+        response = litellm.completion(model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey, how's it going?"}], input_cost_per_token=0.3, output_cost_per_token=0.4)
+        print(f"litellm.model_cost for gpt-3.5-turbo: {litellm.model_cost['gpt-3.5-turbo']}")
+        assert litellm.model_cost["gpt-3.5-turbo"]["input_cost_per_token"] == 0.3
+        assert litellm.model_cost["gpt-3.5-turbo"]["output_cost_per_token"] == 0.4
+    except Exception as e:
+        pytest.fail(f"An error occurred: {e}")
+
+test_update_model_cost_via_completion()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 924e14f5c..7947fbbe2 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1794,7 +1794,6 @@ def register_model(model_cost: Union[str, dict]):
     for key, value in loaded_model_cost.items():
         ## override / add new keys to the existing model cost dictionary
         litellm.model_cost[key] = loaded_model_cost[key]
-
         # add new model names to provider lists
         if value.get('litellm_provider') == 'openai':
             if key not in litellm.open_ai_chat_completion_models:
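
Usage note: a minimal sketch of the new call-time pricing kwargs, based on the test added in this patch. The per-token dollar values below are purely illustrative, not real provider rates.

    import litellm

    # Passing custom per-token pricing directly to completion() registers it
    # in litellm.model_cost for this model via litellm.register_model().
    response = litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        input_cost_per_token=0.000003,   # illustrative value
        output_cost_per_token=0.000006,  # illustrative value
    )

    print(litellm.model_cost["gpt-3.5-turbo"]["input_cost_per_token"])   # 3e-06
    print(litellm.model_cost["gpt-3.5-turbo"]["output_cost_per_token"])  # 6e-06

Per the guard in the main.py hunk, both kwargs must be supplied together; if either is None, registration is skipped.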