feat(utils.py): add register model helper function

This commit is contained in:
Krrish Dholakia 2023-10-19 18:26:21 -07:00
parent 90ce8e4afd
commit 8dda69e216
3 changed files with 89 additions and 1 deletions

View file

@ -1029,6 +1029,66 @@ def completion_cost(
return 0.0 # this should not block a users execution path
####### HELPER FUNCTIONS ################
def register_model(model_cost: dict):
"""
Register new / Override existing models (and their pricing) to specific providers.
Example usage:
model_cost_dict = {
"gpt-4": {
"max_tokens": 8192,
"input_cost_per_token": 0.00003,
"output_cost_per_token": 0.00006,
"litellm_provider": "openai",
"mode": "chat"
},
}
"""
for key, value in model_cost.items():
## override / add new keys to the existing model cost dictionary
litellm.model_cost[key] = model_cost[key]
# add new model names to provider lists
if value.get('litellm_provider') == 'openai':
if key not in litellm.open_ai_chat_completion_models:
litellm.open_ai_chat_completion_models.append(key)
elif value.get('litellm_provider') == 'text-completion-openai':
if key not in litellm.open_ai_text_completion_models:
litellm.open_ai_text_completion_models.append(key)
elif value.get('litellm_provider') == 'cohere':
if key not in litellm.cohere_models:
litellm.cohere_models.append(key)
elif value.get('litellm_provider') == 'anthropic':
if key not in litellm.anthropic_models:
litellm.anthropic_models.append(key)
elif value.get('litellm_provider') == 'openrouter':
split_string = key.split('/', 1)
if key not in litellm.openrouter_models:
litellm.openrouter_models.append(split_string[1])
elif value.get('litellm_provider') == 'vertex_ai-text-models':
if key not in litellm.vertex_text_models:
litellm.vertex_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-text-models':
if key not in litellm.vertex_code_text_models:
litellm.vertex_code_text_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-chat-models':
if key not in litellm.vertex_chat_models:
litellm.vertex_chat_models.append(key)
elif value.get('litellm_provider') == 'vertex_ai-code-chat-models':
if key not in litellm.vertex_code_chat_models:
litellm.vertex_code_chat_models.append(key)
elif value.get('litellm_provider') == 'ai21':
if key not in litellm.ai21_models:
litellm.ai21_models.append(key)
elif value.get('litellm_provider') == 'nlp_cloud':
if key not in litellm.nlp_cloud_models:
litellm.nlp_cloud_models.append(key)
elif value.get('litellm_provider') == 'aleph_alpha':
if key not in litellm.aleph_alpha_models:
litellm.aleph_alpha_models.append(key)
elif value.get('litellm_provider') == 'bedrock':
if key not in litellm.bedrock_models:
litellm.bedrock_models.append(key)
def get_litellm_params(
return_async=False,
api_key=None,
@ -2232,7 +2292,7 @@ def handle_failure(exception, traceback_exception, start_time, end_time, args, k
pass
# NOTE: DEPRECATING this in favor of using success_handler() in
# NOTE: DEPRECATING this in favor of using success_handler() in Logging:
def handle_success(args, kwargs, result, start_time, end_time):
global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
try: