def get_provider_from_model(model_or_model_name):
    """
    Resolve the LLM provider for a model identifier.

    On first call this lazily builds the reverse index
    ``litellm.provider_by_model`` from ``litellm.models_by_provider``,
    mapping both the fully-qualified id ("openai/gpt-3.5-turbo") and the
    bare model name ("gpt-3.5-turbo") to their provider. Subsequent calls
    are O(1) dict lookups.

    Args:
        model_or_model_name (str): Either a fully-qualified
            "provider/model" id or a bare model name.

    Returns:
        str | None: The provider name, or None when the model is unknown.

    NOTE(review): bare model names are not guaranteed unique across
    providers; when two providers list the same bare name, the one
    iterated last wins (same behavior as the original implementation).
    """
    if not litellm.provider_by_model:
        # models_by_provider: provider -> list of "provider/model_name" ids.
        for provider, models in litellm.models_by_provider.items():
            prefix = provider + "/"
            for model in models:
                # Index the fully-qualified id.
                litellm.provider_by_model[model] = provider
                # Strip only a *leading* "provider/" prefix. (str.replace
                # would also clobber occurrences later in the id, e.g.
                # "openai/ft:openai/custom".)
                if model.startswith(prefix):
                    model_name = model[len(prefix):]
                else:
                    model_name = model
                if model_name:
                    # Index the bare name as well.
                    litellm.provider_by_model[model_name] = provider

    # Single lookup; returns None for unknown models (EAFP-friendly .get
    # instead of the original "in"-then-index double lookup).
    return litellm.provider_by_model.get(model_or_model_name)