This commit is contained in:
nobuo kawasaki 2025-04-24 01:06:59 -07:00 committed by GitHub
commit 4c7173824b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 30 additions and 0 deletions

View file

@ -723,6 +723,8 @@ models_by_provider: dict = {
"snowflake": snowflake_models,
}
provider_by_model = {}  # This will be initialized by get_provider_from_model()
# mapping for those models which have larger equivalents
longer_context_model_fallback_dict: dict = {
# openai chat completion models

View file

@ -2345,6 +2345,9 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915
elif value.get("litellm_provider") == "bedrock":
if key not in litellm.bedrock_models:
litellm.bedrock_models.append(key)
elif value.get("litellm_provider") == "gemini":
if key not in litellm.gemini_models:
litellm.gemini_models.append(key)
return model_cost
@ -6025,6 +6028,31 @@ def get_valid_models(
verbose_logger.debug(f"Error getting valid models: {e}")
return [] # NON-Blocking
def get_provider_from_model(model_or_model_name):
    """
    Return the provider (llm_provider) for a given model name.

    Lazily builds the reverse index ``litellm.provider_by_model`` from
    ``litellm.models_by_provider`` on first call, mapping both the
    fully-qualified name ("provider/model_name") and the bare model name
    to their provider.

    Args:
        model_or_model_name (str): Either "openai/gpt-3.5-turbo" or
            "gpt-3.5-turbo".

    Returns:
        Optional[str]: The provider name, or None if the model is unknown.
    """
    if len(litellm.provider_by_model) == 0:
        # models_by_provider: provider -> list["provider/model_name"]
        for provider, models in litellm.models_by_provider.items():
            prefix = provider + "/"
            for model in models:
                # Index the fully-qualified name.
                litellm.provider_by_model[model] = provider
                # Strip only a LEADING "provider/" prefix. The previous
                # str.replace(...) would also delete the token if it appeared
                # mid-string, corrupting such model names.
                if model.startswith(prefix):
                    model_name = model[len(prefix):]
                else:
                    model_name = model
                if model_name:
                    litellm.provider_by_model[model_name] = provider
    # dict.get returns None for unknown models, matching the documented contract.
    return litellm.provider_by_model.get(model_or_model_name)
def print_args_passed_to_litellm(original_function, args, kwargs):
if not _is_debugging_on():