fix(utils.py): add get_provider_from_model()

This commit is contained in:
jinno 2024-05-07 09:05:50 +09:00
parent 24164eccaa
commit 58349e624e

View file

@ -2071,6 +2071,9 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915
elif value.get("litellm_provider") == "bedrock":
if key not in litellm.bedrock_models:
litellm.bedrock_models.append(key)
elif value.get("litellm_provider") == "gemini":
if key not in litellm.gemini_models:
litellm.gemini_models.append(key)
return model_cost
@ -8390,6 +8393,31 @@ def get_valid_models() -> List[str]:
except Exception:
return [] # NON-Blocking
def get_provider_from_model(model_or_model_name):
"""
Returns a provider(llm_provider) from litellm.models_by_provider(provider_by_model)
Args:
model_or_model_name (str): The name. like "openai/gpt-3.5-turbo" or "gpt-3.5-turbo".
Returns:
llm_provider(str): provider or None
"""
if len(litellm.provider_by_model) == 0:
# models_by_provider: provider -> list[provider/model_name]
for provider, models in litellm.models_by_provider.items():
for model in models:
# provider_by_model : provider/model_name -> provider
# model_name -> provider
model_name = model.replace(provider + "/", "")
litellm.provider_by_model[model] = provider
if model_name:
litellm.provider_by_model[model_name] = provider
if model_or_model_name in litellm.provider_by_model:
return litellm.provider_by_model[model_or_model_name]
return None
def print_args_passed_to_litellm(original_function, args, kwargs):
try: