fix(__init__.py): add gemini models to all model list

Fixes https://github.com/BerriAI/litellm/issues/4240
Krrish Dholakia 2024-06-17 10:54:28 -07:00
parent fa6ddcde3c
commit 6482e57f56
2 changed files with 32 additions and 17 deletions


@@ -338,6 +338,7 @@ bedrock_models: List = []
 deepinfra_models: List = []
 perplexity_models: List = []
 watsonx_models: List = []
+gemini_models: List = []
 for key, value in model_cost.items():
     if value.get("litellm_provider") == "openai":
         open_ai_chat_completion_models.append(key)
@@ -384,7 +385,8 @@ for key, value in model_cost.items():
         perplexity_models.append(key)
     elif value.get("litellm_provider") == "watsonx":
         watsonx_models.append(key)
+    elif value.get("litellm_provider") == "gemini":
+        gemini_models.append(key)
 # known openai compatible endpoints - we'll eventually move this list to the model_prices_and_context_window.json dictionary
 openai_compatible_endpoints: List = [
     "api.perplexity.ai",
@@ -591,6 +593,7 @@ model_list = (
     + maritalk_models
     + vertex_language_models
     + watsonx_models
+    + gemini_models
 )

 provider_list: List = [
@@ -663,6 +666,7 @@ models_by_provider: dict = {
     "perplexity": perplexity_models,
     "maritalk": maritalk_models,
     "watsonx": watsonx_models,
+    "gemini": gemini_models,
 }
 # mapping for those models which have larger equivalents
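
That is the whole __init__.py change; the second changed file, a unit-test module, follows below. As a quick illustration of the pattern the fix extends, here is a minimal, self-contained sketch: entries in the model-cost map are bucketed into per-provider lists, and each per-provider list must also be folded into the aggregate model_list and the models_by_provider mapping, otherwise helpers that consume those (such as get_valid_models) never see the provider's models. The MODEL_COST entries and model names below are illustrative stand-ins, not litellm's real pricing data.

# Minimal sketch of the provider-bucketing pattern the diff extends.
# MODEL_COST and the model names are illustrative placeholders.
from typing import Dict, List

MODEL_COST: Dict[str, dict] = {
    "gpt-3.5-turbo": {"litellm_provider": "openai"},
    "gemini/gemini-pro": {"litellm_provider": "gemini"},
    "command-r": {"litellm_provider": "cohere"},
}

open_ai_chat_completion_models: List[str] = []
cohere_models: List[str] = []
gemini_models: List[str] = []

# Same shape as the loop in litellm/__init__.py: one branch per provider.
for key, value in MODEL_COST.items():
    if value.get("litellm_provider") == "openai":
        open_ai_chat_completion_models.append(key)
    elif value.get("litellm_provider") == "cohere":
        cohere_models.append(key)
    elif value.get("litellm_provider") == "gemini":
        gemini_models.append(key)

# The fix: the new per-provider list is included in both aggregates.
model_list = open_ai_chat_completion_models + cohere_models + gemini_models
models_by_provider = {
    "openai": open_ai_chat_completion_models,
    "cohere": cohere_models,
    "gemini": gemini_models,
}

print(model_list)                    # ['gpt-3.5-turbo', 'command-r', 'gemini/gemini-pro']
print(models_by_provider["gemini"])  # ['gemini/gemini-pro']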


@@ -1,9 +1,9 @@
+import copy
 import sys
+from datetime import datetime
 from unittest import mock
 from dotenv import load_dotenv
-import copy
-from datetime import datetime

 load_dotenv()
 import os
@@ -12,25 +12,26 @@ sys.path.insert(
     0, os.path.abspath("../..")
 ) # Adds the parent directory to the system path
 import pytest
 import litellm
-from litellm.utils import (
-    trim_messages,
-    get_token_count,
-    get_valid_models,
-    check_valid_key,
-    validate_environment,
-    function_to_dict,
-    token_counter,
-    create_pretrained_tokenizer,
-    create_tokenizer,
-    get_max_tokens,
-    get_supported_openai_params,
-)
 from litellm.proxy.utils import (
     _duration_in_seconds,
     _extract_from_regex,
     get_last_day_of_month,
 )
+from litellm.utils import (
+    check_valid_key,
+    create_pretrained_tokenizer,
+    create_tokenizer,
+    function_to_dict,
+    get_max_tokens,
+    get_supported_openai_params,
+    get_token_count,
+    get_valid_models,
+    token_counter,
+    trim_messages,
+    validate_environment,
+)

 # Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'
@@ -216,6 +217,16 @@ def test_get_valid_models():
     # reset replicate env key
     os.environ = old_environ

+    # GEMINI
+    expected_models = litellm.gemini_models
+    old_environ = os.environ
+    os.environ = {"GEMINI_API_KEY": "temp"}  # mock: set only the gemini key in the environ
+    valid_models = get_valid_models()
+    print(valid_models)
+    assert valid_models == expected_models

 # test_get_valid_models()
@@ -409,10 +420,10 @@ def test_redact_msgs_from_logs():
     On the proxy some users were seeing the redaction impact client side responses
     """
+    from litellm.litellm_core_utils.litellm_logging import Logging
     from litellm.litellm_core_utils.redact_messages import (
         redact_message_input_output_from_logging,
     )
-    from litellm.litellm_core_utils.litellm_logging import Logging

     litellm.turn_off_message_logging = True
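
For context, a rough usage sketch of what the new GEMINI branch of test_get_valid_models asserts: when GEMINI_API_KEY is the only provider key in the environment, get_valid_models() should return exactly litellm.gemini_models. This assumes a litellm build that includes this commit; the key value is a throwaway placeholder and no request is sent to Google.

import os

import litellm
from litellm.utils import get_valid_models

saved_environ = os.environ.copy()
os.environ.clear()
os.environ["GEMINI_API_KEY"] = "temp"  # placeholder key, never sent anywhere

try:
    valid_models = get_valid_models()
    print(valid_models)
    # With only the gemini key set, the helper should surface the gemini list.
    assert valid_models == litellm.gemini_models
finally:
    # Restore the original environment so other code is unaffected.
    os.environ.clear()
    os.environ.update(saved_environ)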