Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 03:34:10 +00:00)
feat(utils.py): add register model helper function
commit f10a4ce16b
parent 5a3cfff080
3 changed files with 89 additions and 1 deletion
@@ -320,6 +320,7 @@ from .utils import (
    check_valid_key,
    get_llm_provider,
    completion_with_config,
+   register_model
)
from .llms.huggingface_restapi import HuggingfaceConfig
from .llms.anthropic import AnthropicConfig
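Since the helper is re-exported at the package top level, callers can invoke it directly off the litellm module. A minimal sketch of the call shape (the model name and prices here are illustrative, not shipped defaults):

import litellm

# Register (or override) a model's metadata at runtime; the dict shape
# mirrors the entries in litellm.model_cost.
litellm.register_model({
    "my-org/custom-gpt": {              # hypothetical model name
        "max_tokens": 4096,
        "input_cost_per_token": 0.00001,
        "output_cost_per_token": 0.00002,
        "litellm_provider": "openai",
        "mode": "chat",
    }
})

# The entry lands in the global cost map and, because litellm_provider is
# "openai", in the OpenAI chat-model list as well.
assert "my-org/custom-gpt" in litellm.model_cost
assert "my-org/custom-gpt" in litellm.open_ai_chat_completion_models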
litellm/tests/test_register_model.py (new file, 27 lines)
@@ -0,0 +1,27 @@
#### What this tests ####
# This tests litellm.register_model(), which registers / overrides model cost entries

import sys, os
import traceback
import pytest

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import litellm


def test_update_model_cost():
    try:
        litellm.register_model({
            "gpt-4": {
                "max_tokens": 8192,
                "input_cost_per_token": 0.00002,
                "output_cost_per_token": 0.00006,
                "litellm_provider": "openai",
                "mode": "chat"
            },
        })
        assert litellm.model_cost["gpt-4"]["input_cost_per_token"] == 0.00002
    except Exception as e:
        pytest.fail(f"An error occurred: {e}")


test_update_model_cost()
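With the override in place, any cost math that reads litellm.model_cost sees the new gpt-4 prices. A hand-worked sketch of the per-token arithmetic (plain Python, no litellm call involved):

# Prices from the registered entry above.
input_cost_per_token = 0.00002
output_cost_per_token = 0.00006

# Cost of a hypothetical call with 10 prompt tokens and 20 completion tokens.
prompt_cost = 10 * input_cost_per_token    # 0.0002
output_cost = 20 * output_cost_per_token   # 0.0012
print(prompt_cost + output_cost)           # 0.0014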
@@ -1029,6 +1029,66 @@ def completion_cost(
        return 0.0  # this should not block a user's execution path

####### HELPER FUNCTIONS ################
def register_model(model_cost: dict):
    """
    Register new models / override existing models (and their pricing) for specific providers.

    Example usage:
    model_cost_dict = {
        "gpt-4": {
            "max_tokens": 8192,
            "input_cost_per_token": 0.00003,
            "output_cost_per_token": 0.00006,
            "litellm_provider": "openai",
            "mode": "chat"
        },
    }
    """
    for key, value in model_cost.items():
        ## override / add new keys in the existing model cost dictionary
        litellm.model_cost[key] = model_cost[key]

        # add new model names to the matching provider list
        if value.get('litellm_provider') == 'openai':
            if key not in litellm.open_ai_chat_completion_models:
                litellm.open_ai_chat_completion_models.append(key)
        elif value.get('litellm_provider') == 'text-completion-openai':
            if key not in litellm.open_ai_text_completion_models:
                litellm.open_ai_text_completion_models.append(key)
        elif value.get('litellm_provider') == 'cohere':
            if key not in litellm.cohere_models:
                litellm.cohere_models.append(key)
        elif value.get('litellm_provider') == 'anthropic':
            if key not in litellm.anthropic_models:
                litellm.anthropic_models.append(key)
        elif value.get('litellm_provider') == 'openrouter':
            # openrouter keys are prefixed with the provider, e.g. "openrouter/<model>";
            # strip the prefix before tracking the model name
            split_string = key.split('/', 1)
            if split_string[1] not in litellm.openrouter_models:
                litellm.openrouter_models.append(split_string[1])
        elif value.get('litellm_provider') == 'vertex_ai-text-models':
            if key not in litellm.vertex_text_models:
                litellm.vertex_text_models.append(key)
        elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
            if key not in litellm.vertex_code_text_models:
                litellm.vertex_code_text_models.append(key)
        elif value.get('litellm_provider') == 'vertex_ai-chat-models':
            if key not in litellm.vertex_chat_models:
                litellm.vertex_chat_models.append(key)
        elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
            if key not in litellm.vertex_code_chat_models:
                litellm.vertex_code_chat_models.append(key)
        elif value.get('litellm_provider') == 'ai21':
            if key not in litellm.ai21_models:
                litellm.ai21_models.append(key)
        elif value.get('litellm_provider') == 'nlp_cloud':
            if key not in litellm.nlp_cloud_models:
                litellm.nlp_cloud_models.append(key)
        elif value.get('litellm_provider') == 'aleph_alpha':
            if key not in litellm.aleph_alpha_models:
                litellm.aleph_alpha_models.append(key)
        elif value.get('litellm_provider') == 'bedrock':
            if key not in litellm.bedrock_models:
                litellm.bedrock_models.append(key)
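The provider dispatch above is a long if/elif chain over litellm_provider values. An equivalent, table-driven shape is sketched below; this is a hypothetical refactor, not the committed code, and _track_model is an invented name (the provider lists are the same module-level lists used above):

import litellm

# Hypothetical: map each litellm_provider value to its provider model list.
PROVIDER_LISTS = {
    'openai': litellm.open_ai_chat_completion_models,
    'text-completion-openai': litellm.open_ai_text_completion_models,
    'cohere': litellm.cohere_models,
    'anthropic': litellm.anthropic_models,
    'vertex_ai-text-models': litellm.vertex_text_models,
    'vertex_ai-code-text-models': litellm.vertex_code_text_models,
    'vertex_ai-chat-models': litellm.vertex_chat_models,
    'vertex_ai-code-chat-models': litellm.vertex_code_chat_models,
    'ai21': litellm.ai21_models,
    'nlp_cloud': litellm.nlp_cloud_models,
    'aleph_alpha': litellm.aleph_alpha_models,
    'bedrock': litellm.bedrock_models,
}

def _track_model(key: str, provider: str) -> None:
    # openrouter tracks the bare model name, without the "openrouter/" prefix
    if provider == 'openrouter':
        key = key.split('/', 1)[1]
        target = litellm.openrouter_models
    else:
        target = PROVIDER_LISTS.get(provider)
    if target is not None and key not in target:
        target.append(key)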

def get_litellm_params(
    return_async=False,
    api_key=None,
@@ -2232,7 +2292,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
        pass


-# NOTE: DEPRECATING this in favor of using success_handler() in
+# NOTE: DEPRECATING this in favor of using success_handler() in Logging:
def handle_success(args, kwargs, result, start_time, end_time):
    global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
    try:
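handle_success is on its way out in favor of success_handler() on the Logging class, but the user-facing wiring stays the same: callbacks are configured by name and litellm routes results to the matching logger. A sketch of that wiring (callback names as documented for litellm around this period; treat as illustrative):

import litellm

# Route successful and failed calls to named logging integrations; the
# loggers referenced in handle_success (helicone, supabase, llmonitor, ...)
# are selected this way.
litellm.success_callback = ["helicone"]
litellm.failure_callback = ["supabase"]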