From f10a4ce16b24c345970f83e2b7ac203e9ee9d2bc Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Thu, 19 Oct 2023 18:26:21 -0700
Subject: [PATCH] feat(utils.py): add register model helper function

---
 litellm/__init__.py                  |  1 +
 litellm/tests/test_register_model.py | 27 ++++++++++++
 litellm/utils.py                     | 62 +++++++++++++++++++++++++++-
 3 files changed, 89 insertions(+), 1 deletion(-)
 create mode 100644 litellm/tests/test_register_model.py

diff --git a/litellm/__init__.py b/litellm/__init__.py
index e6cdf7ea21..90426c1053 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -320,6 +320,7 @@ from .utils import (
     check_valid_key,
     get_llm_provider,
     completion_with_config,
+    register_model
 )
 from .llms.huggingface_restapi import HuggingfaceConfig
 from .llms.anthropic import AnthropicConfig
diff --git a/litellm/tests/test_register_model.py b/litellm/tests/test_register_model.py
new file mode 100644
index 0000000000..684173fee8
--- /dev/null
+++ b/litellm/tests/test_register_model.py
@@ -0,0 +1,27 @@
+#### What this tests ####
+# This tests registering a new model / overriding an existing model's cost via litellm.register_model
+
+import sys, os
+import traceback
+import pytest
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+
+def test_update_model_cost():
+    try:
+        litellm.register_model({
+            "gpt-4": {
+                "max_tokens": 8192,
+                "input_cost_per_token": 0.00002,
+                "output_cost_per_token": 0.00006,
+                "litellm_provider": "openai",
+                "mode": "chat"
+            },
+        })
+        assert litellm.model_cost["gpt-4"]["input_cost_per_token"] == 0.00002
+    except Exception as e:
+        pytest.fail(f"An error occurred: {e}")
+
+test_update_model_cost()
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 06a9f4f936..c9a4c6318f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -1029,6 +1029,66 @@ def completion_cost(
         return 0.0 # this should not block a user's execution path
 
 ####### HELPER FUNCTIONS ################
+def register_model(model_cost: dict):
+    """
+    Register new models, or override existing ones (and their pricing), for specific providers.
+
+    Example usage:
+    model_cost_dict = {
+        "gpt-4": {
+            "max_tokens": 8192,
+            "input_cost_per_token": 0.00003,
+            "output_cost_per_token": 0.00006,
+            "litellm_provider": "openai",
+            "mode": "chat"
+        },
+    }
+    """
+    for key, value in model_cost.items():
+        ## override / add new keys to the existing model cost dictionary
+        litellm.model_cost[key] = value
+
+        # add new model names to provider lists
+        if value.get('litellm_provider') == 'openai':
+            if key not in litellm.open_ai_chat_completion_models:
+                litellm.open_ai_chat_completion_models.append(key)
+        elif value.get('litellm_provider') == 'text-completion-openai':
+            if key not in litellm.open_ai_text_completion_models:
+                litellm.open_ai_text_completion_models.append(key)
+        elif value.get('litellm_provider') == 'cohere':
+            if key not in litellm.cohere_models:
+                litellm.cohere_models.append(key)
+        elif value.get('litellm_provider') == 'anthropic':
+            if key not in litellm.anthropic_models:
+                litellm.anthropic_models.append(key)
+        elif value.get('litellm_provider') == 'openrouter':
+            split_string = key.split('/', 1)  # strip the 'openrouter/' prefix before registering
+            if split_string[1] not in litellm.openrouter_models:
+                litellm.openrouter_models.append(split_string[1])
+        elif value.get('litellm_provider') == 'vertex_ai-text-models':
+            if key not in litellm.vertex_text_models:
+                litellm.vertex_text_models.append(key)
+        elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
+            if key not in litellm.vertex_code_text_models:
+                litellm.vertex_code_text_models.append(key)
+        elif value.get('litellm_provider') == 'vertex_ai-chat-models':
+            if key not in litellm.vertex_chat_models:
+                litellm.vertex_chat_models.append(key)
+        elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
+            if key not in litellm.vertex_code_chat_models:
+                litellm.vertex_code_chat_models.append(key)
+        elif value.get('litellm_provider') == 'ai21':
+            if key not in litellm.ai21_models:
+                litellm.ai21_models.append(key)
+        elif value.get('litellm_provider') == 'nlp_cloud':
+            if key not in litellm.nlp_cloud_models:
+                litellm.nlp_cloud_models.append(key)
+        elif value.get('litellm_provider') == 'aleph_alpha':
+            if key not in litellm.aleph_alpha_models:
+                litellm.aleph_alpha_models.append(key)
+        elif value.get('litellm_provider') == 'bedrock':
+            if key not in litellm.bedrock_models:
+                litellm.bedrock_models.append(key)
+
 def get_litellm_params(
     return_async=False,
     api_key=None,
@@ -2232,7 +2292,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
     pass
 
 
-# NOTE: DEPRECATING this in favor of using success_handler() in 
+# NOTE: DEPRECATING this in favor of using success_handler() in Logging:
 def handle_success(args, kwargs, result, start_time, end_time):
     global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
     try:
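
For reference, a minimal usage sketch of the new helper once the patch above is
applied. The model name and the 0.00003 pricing are taken from the docstring
example in the diff, purely as illustration, not as canonical prices.

    import litellm

    # Override the tracked pricing for an existing model (or register a
    # brand-new one); the dict schema mirrors litellm.model_cost entries.
    litellm.register_model({
        "gpt-4": {
            "max_tokens": 8192,
            "input_cost_per_token": 0.00003,
            "output_cost_per_token": 0.00006,
            "litellm_provider": "openai",
            "mode": "chat",
        },
    })

    # model_cost is updated in place, so the override is visible immediately.
    assert litellm.model_cost["gpt-4"]["input_cost_per_token"] == 0.00003

Because register_model also appends unrecognized names to the matching provider
list (e.g. litellm.open_ai_chat_completion_models for 'openai'), a custom model
registered this way should be visible to litellm's provider lookups as well.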