mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Merge 58349e624e
into b82af5b826
This commit is contained in:
commit
4c7173824b
2 changed files with 30 additions and 0 deletions
|
@ -723,6 +723,8 @@ models_by_provider: dict = {
|
|||
"snowflake": snowflake_models,
|
||||
}
|
||||
|
||||
provider_by_model = {}  # Reverse index (model id -> provider); lazily built by get_provider_from_model()
|
||||
|
||||
# mapping for those models which have larger equivalents
|
||||
longer_context_model_fallback_dict: dict = {
|
||||
# openai chat completion models
|
||||
|
|
|
@ -2345,6 +2345,9 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915
|
|||
elif value.get("litellm_provider") == "bedrock":
|
||||
if key not in litellm.bedrock_models:
|
||||
litellm.bedrock_models.append(key)
|
||||
elif value.get("litellm_provider") == "gemini":
|
||||
if key not in litellm.gemini_models:
|
||||
litellm.gemini_models.append(key)
|
||||
return model_cost
|
||||
|
||||
|
||||
|
@ -6025,6 +6028,31 @@ def get_valid_models(
|
|||
verbose_logger.debug(f"Error getting valid models: {e}")
|
||||
return [] # NON-Blocking
|
||||
|
||||
def get_provider_from_model(model_or_model_name):
    """
    Resolve the LLM provider for a model identifier.

    On first call, lazily builds ``litellm.provider_by_model`` as a
    reverse index of ``litellm.models_by_provider``: each fully
    qualified id ("openai/gpt-3.5-turbo") AND each bare model name
    ("gpt-3.5-turbo") maps to its provider.

    Args:
        model_or_model_name (str): Either "provider/model" (e.g.
            "openai/gpt-3.5-turbo") or a bare model name
            (e.g. "gpt-3.5-turbo").

    Returns:
        str or None: The provider name, or None when the model is not
        found in the index.
    """
    if len(litellm.provider_by_model) == 0:
        # models_by_provider: provider -> list of "provider/model_name" ids
        for provider, models in litellm.models_by_provider.items():
            prefix = provider + "/"
            for model in models:
                # Fully qualified id -> provider
                litellm.provider_by_model[model] = provider
                # Strip only a LEADING "provider/" prefix. The previous
                # str.replace() removed every occurrence of the substring,
                # which could mangle model ids containing it mid-string.
                if model.startswith(prefix):
                    model_name = model[len(prefix):]
                else:
                    model_name = model
                if model_name:
                    # Bare model name -> provider
                    litellm.provider_by_model[model_name] = provider

    # dict.get returns None for unknown models, matching the documented contract.
    return litellm.provider_by_model.get(model_or_model_name)
|
||||
|
||||
def print_args_passed_to_litellm(original_function, args, kwargs):
|
||||
if not _is_debugging_on():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue