mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
feat(utils.py): add register model helper function
This commit is contained in:
parent
5a3cfff080
commit
f10a4ce16b
3 changed files with 89 additions and 1 deletions
|
@ -320,6 +320,7 @@ from .utils import (
|
||||||
check_valid_key,
|
check_valid_key,
|
||||||
get_llm_provider,
|
get_llm_provider,
|
||||||
completion_with_config,
|
completion_with_config,
|
||||||
|
register_model
|
||||||
)
|
)
|
||||||
from .llms.huggingface_restapi import HuggingfaceConfig
|
from .llms.huggingface_restapi import HuggingfaceConfig
|
||||||
from .llms.anthropic import AnthropicConfig
|
from .llms.anthropic import AnthropicConfig
|
||||||
|
|
27
litellm/tests/test_register_model.py
Normal file
27
litellm/tests/test_register_model.py
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
#### What this tests ####
|
||||||
|
# This tests calling batch_completions by running 100 messages together
|
||||||
|
|
||||||
|
import sys, os
|
||||||
|
import traceback
|
||||||
|
import pytest
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
def test_update_model_cost():
|
||||||
|
try:
|
||||||
|
litellm.register_model({
|
||||||
|
"gpt-4": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0.00002,
|
||||||
|
"output_cost_per_token": 0.00006,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
})
|
||||||
|
assert litellm.model_cost["gpt-4"]["input_cost_per_token"] == 0.00002
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An error occurred: {e}")
|
||||||
|
|
||||||
|
test_update_model_cost()
|
|
@ -1029,6 +1029,66 @@ def completion_cost(
|
||||||
return 0.0 # this should not block a users execution path
|
return 0.0 # this should not block a users execution path
|
||||||
|
|
||||||
####### HELPER FUNCTIONS ################
|
####### HELPER FUNCTIONS ################
|
||||||
|
def register_model(model_cost: dict):
|
||||||
|
"""
|
||||||
|
Register new / Override existing models (and their pricing) to specific providers.
|
||||||
|
Example usage:
|
||||||
|
model_cost_dict = {
|
||||||
|
"gpt-4": {
|
||||||
|
"max_tokens": 8192,
|
||||||
|
"input_cost_per_token": 0.00003,
|
||||||
|
"output_cost_per_token": 0.00006,
|
||||||
|
"litellm_provider": "openai",
|
||||||
|
"mode": "chat"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
for key, value in model_cost.items():
|
||||||
|
## override / add new keys to the existing model cost dictionary
|
||||||
|
litellm.model_cost[key] = model_cost[key]
|
||||||
|
|
||||||
|
# add new model names to provider lists
|
||||||
|
if value.get('litellm_provider') == 'openai':
|
||||||
|
if key not in litellm.open_ai_chat_completion_models:
|
||||||
|
litellm.open_ai_chat_completion_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'text-completion-openai':
|
||||||
|
if key not in litellm.open_ai_text_completion_models:
|
||||||
|
litellm.open_ai_text_completion_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'cohere':
|
||||||
|
if key not in litellm.cohere_models:
|
||||||
|
litellm.cohere_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'anthropic':
|
||||||
|
if key not in litellm.anthropic_models:
|
||||||
|
litellm.anthropic_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'openrouter':
|
||||||
|
split_string = key.split('/', 1)
|
||||||
|
if key not in litellm.openrouter_models:
|
||||||
|
litellm.openrouter_models.append(split_string[1])
|
||||||
|
elif value.get('litellm_provider') == 'vertex_ai-text-models':
|
||||||
|
if key not in litellm.vertex_text_models:
|
||||||
|
litellm.vertex_text_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
|
||||||
|
if key not in litellm.vertex_code_text_models:
|
||||||
|
litellm.vertex_code_text_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
|
||||||
|
if key not in litellm.vertex_chat_models:
|
||||||
|
litellm.vertex_chat_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
|
||||||
|
if key not in litellm.vertex_code_chat_models:
|
||||||
|
litellm.vertex_code_chat_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'ai21':
|
||||||
|
if key not in litellm.ai21_models:
|
||||||
|
litellm.ai21_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'nlp_cloud':
|
||||||
|
if key not in litellm.nlp_cloud_models:
|
||||||
|
litellm.nlp_cloud_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'aleph_alpha':
|
||||||
|
if key not in litellm.aleph_alpha_models:
|
||||||
|
litellm.aleph_alpha_models.append(key)
|
||||||
|
elif value.get('litellm_provider') == 'bedrock':
|
||||||
|
if key not in litellm.bedrock_models:
|
||||||
|
litellm.bedrock_models.append(key)
|
||||||
|
|
||||||
def get_litellm_params(
|
def get_litellm_params(
|
||||||
return_async=False,
|
return_async=False,
|
||||||
api_key=None,
|
api_key=None,
|
||||||
|
@ -2232,7 +2292,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# NOTE: DEPRECATING this in favor of using success_handler() in
|
# NOTE: DEPRECATING this in favor of using success_handler() in Logging:
|
||||||
def handle_success(args, kwargs, result, start_time, end_time):
|
def handle_success(args, kwargs, result, start_time, end_time):
|
||||||
global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
|
global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue