(litellm sdk - perf improvement) - use O(1) set lookups for checking llm providers / models (#7672)

* fix get model info logic to use O(1) lookups

* perf - use O(1) lookup for get llm provider
This commit is contained in:
Ishaan Jaff 2025-01-10 14:16:30 -08:00 committed by GitHub
parent b3bd15e35a
commit c999b4efe1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 11 additions and 5 deletions

View file

@ -861,6 +861,7 @@ model_list = (
+ azure_text_models + azure_text_models
) )
model_list_set = set(model_list)
provider_list: List[Union[LlmProviders, str]] = list(LlmProviders) provider_list: List[Union[LlmProviders, str]] = list(LlmProviders)

View file

@ -141,7 +141,7 @@ def get_llm_provider( # noqa: PLR0915
# check if llm provider part of model name # check if llm provider part of model name
if ( if (
model.split("/", 1)[0] in litellm.provider_list model.split("/", 1)[0] in litellm.provider_list
and model.split("/", 1)[0] not in litellm.model_list and model.split("/", 1)[0] not in litellm.model_list_set
and len(model.split("/")) and len(model.split("/"))
> 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351 > 1 # handle edge case where user passes in `litellm --model mistral` https://github.com/BerriAI/litellm/issues/1351
): ):
@ -210,7 +210,9 @@ def get_llm_provider( # noqa: PLR0915
dynamic_api_key = get_secret_str("DEEPSEEK_API_KEY") dynamic_api_key = get_secret_str("DEEPSEEK_API_KEY")
elif endpoint == "https://api.friendli.ai/serverless/v1": elif endpoint == "https://api.friendli.ai/serverless/v1":
custom_llm_provider = "friendliai" custom_llm_provider = "friendliai"
dynamic_api_key = get_secret_str("FRIENDLIAI_API_KEY") or get_secret("FRIENDLI_TOKEN") dynamic_api_key = get_secret_str(
"FRIENDLIAI_API_KEY"
) or get_secret("FRIENDLI_TOKEN")
elif endpoint == "api.galadriel.com/v1": elif endpoint == "api.galadriel.com/v1":
custom_llm_provider = "galadriel" custom_llm_provider = "galadriel"
dynamic_api_key = get_secret_str("GALADRIEL_API_KEY") dynamic_api_key = get_secret_str("GALADRIEL_API_KEY")

View file

@ -1817,6 +1817,10 @@ class LlmProviders(str, Enum):
HUMANLOOP = "humanloop" HUMANLOOP = "humanloop"
# Create a set of all provider values for quick lookup
LlmProvidersSet = {provider.value for provider in LlmProviders}
class LiteLLMLoggingBaseClass: class LiteLLMLoggingBaseClass:
""" """
Base class for logging pre and post call Base class for logging pre and post call

View file

@ -133,6 +133,7 @@ from litellm.types.utils import (
Function, Function,
ImageResponse, ImageResponse,
LlmProviders, LlmProviders,
LlmProvidersSet,
Message, Message,
ModelInfo, ModelInfo,
ModelInfoBase, ModelInfoBase,
@ -4108,9 +4109,7 @@ def _get_model_info_helper( # noqa: PLR0915
): ):
_model_info = None _model_info = None
if custom_llm_provider and custom_llm_provider in [ if custom_llm_provider and custom_llm_provider in LlmProvidersSet:
provider.value for provider in LlmProviders
]:
# Check if the provider string exists in LlmProviders enum # Check if the provider string exists in LlmProviders enum
provider_config = ProviderConfigManager.get_provider_model_info( provider_config = ProviderConfigManager.get_provider_model_info(
model=model, provider=LlmProviders(custom_llm_provider) model=model, provider=LlmProviders(custom_llm_provider)