Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-24 18:24:20 +00:00)

Support checking provider /models endpoints on proxy /v1/models endpoint (#9958)
* feat(utils.py): support global flag for 'check_provider_endpoints' - enables setting this for `/models` on proxy
* feat(utils.py): add caching to 'get_valid_models' - prevents checking endpoint repeatedly
* fix(utils.py): ensure mutations don't impact cached results
* test(test_utils.py): add unit test to confirm cache invalidation logic
* feat(utils.py): get_valid_models - support passing litellm params dynamically, allows for checking endpoints based on received credentials
* test: update test
* feat(model_checks.py): pass router credentials to get_valid_models - ensures it checks correct credentials
* refactor(utils.py): refactor for simpler functions
* fix: fix linting errors
* fix(utils.py): fix test
* fix(utils.py): set valid providers to custom_llm_provider, if given
* test: update test
* fix: fix ruff check error
This commit is contained in:
parent e94eb4ec70
commit 33ead69c0a

6 changed files with 313 additions and 110 deletions
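The caller-facing change described in the commit message above: get_valid_models() can now take dynamic credentials and query the provider's /models endpoint directly, with results cached in-memory. A minimal sketch, using only names that appear in the litellm/utils.py and test diffs below; the api_key value is a placeholder, and an invalid key simply yields an empty list (as the new unit test asserts):

# Sketch only - names taken from the diffs in this commit; the key is a placeholder.
from litellm.types.router import CredentialLiteLLMParams
from litellm.utils import get_valid_models

# Static list, read from litellm.models_by_provider (no network call).
static_models = get_valid_models(custom_llm_provider="openai")

# Live list, fetched from the provider's /models endpoint with the supplied
# credentials and cached in-memory for subsequent calls.
creds = CredentialLiteLLMParams(api_key="sk-ant-placeholder")
live_models = get_valid_models(
    check_provider_endpoint=True,
    custom_llm_provider="anthropic",
    litellm_params=creds,
)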
@@ -128,19 +128,19 @@ prometheus_initialize_budget_metrics: Optional[bool] = False
require_auth_for_metrics_endpoint: Optional[bool] = False
argilla_batch_size: Optional[int] = None
datadog_use_v1: Optional[bool] = False  # if you want to use v1 datadog logged payload
gcs_pub_sub_use_v1: Optional[bool] = (
    False  # if you want to use v1 gcs pubsub logged payload
)
gcs_pub_sub_use_v1: Optional[
    bool
] = False  # if you want to use v1 gcs pubsub logged payload
argilla_transformation_object: Optional[Dict[str, Any]] = None
_async_input_callback: List[Union[str, Callable, CustomLogger]] = (
    []
)  # internal variable - async custom callbacks are routed here.
_async_success_callback: List[Union[str, Callable, CustomLogger]] = (
    []
)  # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[Union[str, Callable, CustomLogger]] = (
    []
)  # internal variable - async custom callbacks are routed here.
_async_input_callback: List[
    Union[str, Callable, CustomLogger]
] = []  # internal variable - async custom callbacks are routed here.
_async_success_callback: List[
    Union[str, Callable, CustomLogger]
] = []  # internal variable - async custom callbacks are routed here.
_async_failure_callback: List[
    Union[str, Callable, CustomLogger]
] = []  # internal variable - async custom callbacks are routed here.
pre_call_rules: List[Callable] = []
post_call_rules: List[Callable] = []
turn_off_message_logging: Optional[bool] = False

@@ -148,18 +148,18 @@ log_raw_request_response: bool = False
redact_messages_in_exceptions: Optional[bool] = False
redact_user_api_key_info: Optional[bool] = False
filter_invalid_headers: Optional[bool] = False
add_user_information_to_llm_headers: Optional[bool] = (
    None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
)
add_user_information_to_llm_headers: Optional[
    bool
] = None  # adds user_id, team_id, token hash (params from StandardLoggingMetadata) to request headers
store_audit_logs = False  # Enterprise feature, allow users to see audit logs
### end of callbacks #############

email: Optional[str] = (
    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
token: Optional[str] = (
    None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
email: Optional[
    str
] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
token: Optional[
    str
] = None  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
telemetry = True
max_tokens: int = DEFAULT_MAX_TOKENS  # OpenAI Defaults
drop_params = bool(os.getenv("LITELLM_DROP_PARAMS", False))

@@ -235,24 +235,20 @@ enable_loadbalancing_on_batch_endpoints: Optional[bool] = None
enable_caching_on_provider_specific_optional_params: bool = (
    False  # feature-flag for caching on optional params - e.g. 'top_k'
)
caching: bool = (
    False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
caching_with_models: bool = (
    False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
)
cache: Optional[Cache] = (
    None  # cache object <- use this - https://docs.litellm.ai/docs/caching
)
caching: bool = False  # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
caching_with_models: bool = False  # # Not used anymore, will be removed in next MAJOR release - https://github.com/BerriAI/litellm/discussions/648
cache: Optional[
    Cache
] = None  # cache object <- use this - https://docs.litellm.ai/docs/caching
default_in_memory_ttl: Optional[float] = None
default_redis_ttl: Optional[float] = None
default_redis_batch_cache_expiry: Optional[float] = None
model_alias_map: Dict[str, str] = {}
model_group_alias_map: Dict[str, str] = {}
max_budget: float = 0.0  # set the max budget across all providers
budget_duration: Optional[str] = (
    None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
)
budget_duration: Optional[
    str
] = None  # proxy only - resets budget after fixed duration. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
default_soft_budget: float = (
    DEFAULT_SOFT_BUDGET  # by default all litellm proxy keys have a soft budget of 50.0
)

@@ -261,15 +257,11 @@ forward_traceparent_to_llm_provider: bool = False

_current_cost = 0.0  # private variable, used if max budget is set
error_logs: Dict = {}
add_function_to_prompt: bool = (
    False  # if function calling not supported by api, append function call details to system prompt
)
add_function_to_prompt: bool = False  # if function calling not supported by api, append function call details to system prompt
client_session: Optional[httpx.Client] = None
aclient_session: Optional[httpx.AsyncClient] = None
model_fallbacks: Optional[List] = None  # Deprecated for 'litellm.fallbacks'
model_cost_map_url: str = (
    "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
)
model_cost_map_url: str = "https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
suppress_debug_info = False
dynamodb_table_name: Optional[str] = None
s3_callback_params: Optional[Dict] = None

@@ -292,9 +284,7 @@ disable_end_user_cost_tracking_prometheus_only: Optional[bool] = None
custom_prometheus_metadata_labels: List[str] = []
#### REQUEST PRIORITIZATION ####
priority_reservation: Optional[Dict[str, float]] = None
force_ipv4: bool = (
    False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
)
force_ipv4: bool = False  # when True, litellm will force ipv4 for all LLM requests. Some users have seen httpx ConnectionError when using ipv6.
module_level_aclient = AsyncHTTPHandler(
    timeout=request_timeout, client_alias="module level aclient"
)

@@ -308,13 +298,13 @@ fallbacks: Optional[List] = None
context_window_fallbacks: Optional[List] = None
content_policy_fallbacks: Optional[List] = None
allowed_fails: int = 3
num_retries_per_request: Optional[int] = (
    None  # for the request overall (incl. fallbacks + model retries)
)
num_retries_per_request: Optional[
    int
] = None  # for the request overall (incl. fallbacks + model retries)
####### SECRET MANAGERS #####################
secret_manager_client: Optional[Any] = (
    None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
)
secret_manager_client: Optional[
    Any
] = None  # list of instantiated key management clients - e.g. azure kv, infisical, etc.
_google_kms_resource_name: Optional[str] = None
_key_management_system: Optional[KeyManagementSystem] = None
_key_management_settings: KeyManagementSettings = KeyManagementSettings()

@@ -325,6 +315,7 @@ from litellm.litellm_core_utils.get_model_cost_map import get_model_cost_map

model_cost = get_model_cost_map(url=model_cost_map_url)
custom_prompt_dict: Dict[str, dict] = {}
check_provider_endpoint = False


####### THREAD-SPECIFIC DATA ####################

@@ -1064,10 +1055,10 @@ from .types.llms.custom_llm import CustomLLMItem
from .types.utils import GenericStreamingChunk

custom_provider_map: List[CustomLLMItem] = []
_custom_providers: List[str] = (
    []
)  # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[bool] = (
    None  # disable huggingface tokenizer download. Defaults to openai clk100
)
_custom_providers: List[
    str
] = []  # internal helper util, used to track names of custom providers
disable_hf_tokenizer_download: Optional[
    bool
] = None  # disable huggingface tokenizer download. Defaults to openai clk100
global_disable_no_log_param: bool = False
@@ -33,6 +33,7 @@ model_list:
litellm_settings:
  num_retries: 0
  callbacks: ["prometheus"]
  check_provider_endpoint: true

files_settings:
  - custom_llm_provider: gemini
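With check_provider_endpoint enabled in litellm_settings, the proxy's /v1/models endpoint expands wildcard routes using the models each provider actually advertises. A minimal sketch of calling it, assuming a locally running proxy on the default port 4000, a placeholder virtual key, and the usual OpenAI-style list response; none of these values come from this commit:

# Sketch only - URL, port and key are placeholder values.
import httpx

resp = httpx.get(
    "http://localhost:4000/v1/models",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()
# Wildcard deployments (e.g. "gemini/*") are expanded with the models returned
# by the provider's own /models endpoint, using the router's stored credentials.
print([m["id"] for m in resp.json()["data"]])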
@@ -1,11 +1,12 @@
# What is this?
## Common checks for /v1/models and `/model/info`
import copy
from typing import Dict, List, Optional, Set

import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.router import Router
from litellm.types.router import LiteLLM_Params
from litellm.utils import get_valid_models


@@ -23,15 +24,20 @@ def _check_wildcard_routing(model: str) -> bool:
    return False


def get_provider_models(provider: str) -> Optional[List[str]]:
def get_provider_models(
    provider: str, litellm_params: Optional[LiteLLM_Params] = None
) -> Optional[List[str]]:
    """
    Returns the list of known models by provider
    """
    if provider == "*":
        return get_valid_models()
        return get_valid_models(litellm_params=litellm_params)

    if provider in litellm.models_by_provider:
        provider_models = copy.deepcopy(litellm.models_by_provider[provider])
        provider_models = get_valid_models(
            custom_llm_provider=provider, litellm_params=litellm_params
        )
        # provider_models = copy.deepcopy(litellm.models_by_provider[provider])
        for idx, _model in enumerate(provider_models):
            if provider not in _model:
                provider_models[idx] = f"{provider}/{_model}"

@@ -118,6 +124,7 @@ def get_complete_model_list(
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    """Logic for returning complete model list for a given key + team pair"""

@@ -143,19 +150,25 @@
    unique_models.update(valid_models)

    all_wildcard_models = _get_wildcard_models(
        unique_models=unique_models, return_wildcard_routes=return_wildcard_routes
        unique_models=unique_models,
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )

    return list(unique_models) + all_wildcard_models


def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
def get_known_models_from_wildcard(
    wildcard_model: str, litellm_params: Optional[LiteLLM_Params] = None
) -> List[str]:
    try:
        provider, model = wildcard_model.split("/", 1)
    except ValueError:  # safely fail
        return []
    # get all known provider models
    wildcard_models = get_provider_models(provider=provider)
    wildcard_models = get_provider_models(
        provider=provider, litellm_params=litellm_params
    )
    if wildcard_models is None:
        return []
    if model == "*":

@@ -172,7 +185,9 @@ def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:


def _get_wildcard_models(
    unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
    unique_models: Set[str],
    return_wildcard_routes: Optional[bool] = False,
    llm_router: Optional[Router] = None,
) -> List[str]:
    models_to_remove = set()
    all_wildcard_models = []

@@ -183,12 +198,25 @@
        ):  # will add the wildcard route to the list eg: anthropic/*.
            all_wildcard_models.append(model)

        # get all known provider models
        wildcard_models = get_known_models_from_wildcard(wildcard_model=model)
        ## get litellm params from model
        if llm_router is not None:
            model_list = llm_router.get_model_list(model_name=model)
            if model_list is not None:
                for router_model in model_list:
                    wildcard_models = get_known_models_from_wildcard(
                        wildcard_model=model,
                        litellm_params=LiteLLM_Params(
                            **router_model["litellm_params"]  # type: ignore
                        ),
                    )
                    all_wildcard_models.extend(wildcard_models)
        else:
            # get all known provider models
            wildcard_models = get_known_models_from_wildcard(wildcard_model=model)

        if wildcard_models is not None:
            models_to_remove.add(model)
            all_wildcard_models.extend(wildcard_models)
            if wildcard_models is not None:
                models_to_remove.add(model)
                all_wildcard_models.extend(wildcard_models)

    for model in models_to_remove:
        unique_models.remove(model)
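A sketch of the wildcard-expansion path added above, called directly. The function and parameter names come from this diff; the import path assumes the usual litellm.proxy.auth.model_checks location of this file, and the credentials are placeholders:

# Sketch only - expands "anthropic/*" with router-supplied credentials instead of env vars.
from litellm.types.router import LiteLLM_Params
from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

params = LiteLLM_Params(model="anthropic/*", api_key="sk-ant-placeholder")
models = get_known_models_from_wildcard(
    wildcard_model="anthropic/*",
    litellm_params=params,
)
# Each entry comes back prefixed with the provider, e.g. "anthropic/claude-...".
print(models)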
@@ -803,9 +803,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
    dual_cache=user_api_key_cache
)
litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
redis_usage_cache: Optional[RedisCache] = (
    None  # redis cache used for tracking spend, tpm/rpm limits
)
redis_usage_cache: Optional[
    RedisCache
] = None  # redis cache used for tracking spend, tpm/rpm limits
user_custom_auth = None
user_custom_key_generate = None
user_custom_sso = None

@@ -1131,9 +1131,9 @@ async def update_cache(  # noqa: PLR0915
    _id = "team_id:{}".format(team_id)
    try:
        # Fetch the existing cost for the given user
        existing_spend_obj: Optional[LiteLLM_TeamTable] = (
            await user_api_key_cache.async_get_cache(key=_id)
        )
        existing_spend_obj: Optional[
            LiteLLM_TeamTable
        ] = await user_api_key_cache.async_get_cache(key=_id)
        if existing_spend_obj is None:
            # do nothing if team not in api key cache
            return

@@ -2812,9 +2812,9 @@ async def initialize(  # noqa: PLR0915
        user_api_base = api_base
        dynamic_config[user_model]["api_base"] = api_base
    if api_version:
        os.environ["AZURE_API_VERSION"] = (
            api_version  # set this for azure - litellm can read this from the env
        )
        os.environ[
            "AZURE_API_VERSION"
        ] = api_version  # set this for azure - litellm can read this from the env
    if max_tokens:  # model-specific param
        dynamic_config[user_model]["max_tokens"] = max_tokens
    if temperature:  # model-specific param

@@ -3316,6 +3316,7 @@ async def model_list(
        user_model=user_model,
        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
        return_wildcard_routes=return_wildcard_routes,
        llm_router=llm_router,
    )

    return dict(

@@ -7758,9 +7759,9 @@ async def get_config_list(
            hasattr(sub_field_info, "description")
            and sub_field_info.description is not None
        ):
            nested_fields[idx].field_description = (
                sub_field_info.description
            )
            nested_fields[
                idx
            ].field_description = sub_field_info.description
        idx += 1

    _stored_in_db = None
litellm/utils.py
@@ -5807,8 +5807,133 @@ def trim_messages(
    return messages


from litellm.caching.in_memory_cache import InMemoryCache


class AvailableModelsCache(InMemoryCache):
    def __init__(self, ttl_seconds: int = 300, max_size: int = 1000):
        super().__init__(ttl_seconds, max_size)
        self._env_hash: Optional[str] = None

    def _get_env_hash(self) -> str:
        """Create a hash of relevant environment variables"""
        env_vars = {
            k: v
            for k, v in os.environ.items()
            if k.startswith(("OPENAI", "ANTHROPIC", "AZURE", "AWS"))
        }
        return str(hash(frozenset(env_vars.items())))

    def _check_env_changed(self) -> bool:
        """Check if environment variables have changed"""
        current_hash = self._get_env_hash()
        if self._env_hash is None:
            self._env_hash = current_hash
            return True
        return current_hash != self._env_hash

    def _get_cache_key(
        self,
        custom_llm_provider: Optional[str],
        litellm_params: Optional[LiteLLM_Params],
    ) -> str:
        valid_str = ""

        if litellm_params is not None:
            valid_str = litellm_params.model_dump_json()
        if custom_llm_provider is not None:
            valid_str = f"{custom_llm_provider}:{valid_str}"
        return hashlib.sha256(valid_str.encode()).hexdigest()

    def get_cached_model_info(
        self,
        custom_llm_provider: Optional[str] = None,
        litellm_params: Optional[LiteLLM_Params] = None,
    ) -> Optional[List[str]]:
        """Get cached model info"""
        # Check if environment has changed
        if litellm_params is None and self._check_env_changed():
            self.cache_dict.clear()
            return None

        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)

        result = cast(Optional[List[str]], self.get_cache(cache_key))

        if result is not None:
            return copy.deepcopy(result)
        return result

    def set_cached_model_info(
        self,
        custom_llm_provider: str,
        litellm_params: Optional[LiteLLM_Params],
        available_models: List[str],
    ):
        """Set cached model info"""
        cache_key = self._get_cache_key(custom_llm_provider, litellm_params)
        self.set_cache(cache_key, copy.deepcopy(available_models))


# Global cache instance
_model_cache = AvailableModelsCache()


def _infer_valid_provider_from_env_vars(
    custom_llm_provider: Optional[str] = None,
) -> List[str]:
    valid_providers: List[str] = []
    environ_keys = os.environ.keys()
    for provider in litellm.provider_list:
        if custom_llm_provider and provider != custom_llm_provider:
            continue

        # edge case litellm has together_ai as a provider, it should be togetherai
        env_provider_1 = provider.replace("_", "")
        env_provider_2 = provider

        # litellm standardizes expected provider keys to
        # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
        expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
        expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
        if (
            expected_provider_key_1 in environ_keys
            or expected_provider_key_2 in environ_keys
        ):
            # key is set
            valid_providers.append(provider)

    return valid_providers


def _get_valid_models_from_provider_api(
    provider_config: BaseLLMModelInfo,
    custom_llm_provider: str,
    litellm_params: Optional[LiteLLM_Params] = None,
) -> List[str]:
    try:
        cached_result = _model_cache.get_cached_model_info(
            custom_llm_provider, litellm_params
        )

        if cached_result is not None:
            return cached_result
        models = provider_config.get_models(
            api_key=litellm_params.api_key if litellm_params is not None else None,
            api_base=litellm_params.api_base if litellm_params is not None else None,
        )

        _model_cache.set_cached_model_info(custom_llm_provider, litellm_params, models)
        return models
    except Exception as e:
        verbose_logger.debug(f"Error getting valid models: {e}")
        return []


def get_valid_models(
    check_provider_endpoint: bool = False, custom_llm_provider: Optional[str] = None
    check_provider_endpoint: Optional[bool] = None,
    custom_llm_provider: Optional[str] = None,
    litellm_params: Optional[LiteLLM_Params] = None,
) -> List[str]:
    """
    Returns a list of valid LLMs based on the set environment variables
@@ -5819,31 +5944,21 @@ def get_valid_models(
    Returns:
        A list of valid LLMs
    """

    try:
        check_provider_endpoint = (
            check_provider_endpoint or litellm.check_provider_endpoint
        )
        # get keys set in .env
        environ_keys = os.environ.keys()
        valid_providers = []

        valid_providers: List[str] = []
        valid_models: List[str] = []
        # for all valid providers, make a list of supported llms
        valid_models = []

        for provider in litellm.provider_list:
            if custom_llm_provider and provider != custom_llm_provider:
                continue

            # edge case litellm has together_ai as a provider, it should be togetherai
            env_provider_1 = provider.replace("_", "")
            env_provider_2 = provider

            # litellm standardizes expected provider keys to
            # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
            if (
                expected_provider_key_1 in environ_keys
                or expected_provider_key_2 in environ_keys
            ):
                # key is set
                valid_providers.append(provider)
        if custom_llm_provider:
            valid_providers = [custom_llm_provider]
        else:
            valid_providers = _infer_valid_provider_from_env_vars(custom_llm_provider)

        for provider in valid_providers:
            provider_config = ProviderConfigManager.get_provider_model_info(
@@ -5856,15 +5971,24 @@ def get_valid_models(

            if provider == "azure":
                valid_models.append("Azure-LLM")
            elif provider_config is not None and check_provider_endpoint:
                try:
                    models = provider_config.get_models()
                    valid_models.extend(models)
                except Exception as e:
                    verbose_logger.debug(f"Error getting valid models: {e}")
            elif (
                provider_config is not None
                and check_provider_endpoint
                and provider is not None
            ):
                valid_models.extend(
                    _get_valid_models_from_provider_api(
                        provider_config,
                        provider,
                        litellm_params,
                    )
                )
            else:
                models_for_provider = litellm.models_by_provider.get(provider, [])
                models_for_provider = copy.deepcopy(
                    litellm.models_by_provider.get(provider, [])
                )
                valid_models.extend(models_for_provider)

        return valid_models
    except Exception as e:
        verbose_logger.debug(f"Error getting valid models: {e}")
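A quick sketch of how the new AvailableModelsCache behaves at runtime, mirroring the cache-invalidation unit test added in this commit and using the module-level _model_cache instance from above. Assumes a fresh process; the model name and key value are illustrative placeholders:

# Sketch only - demonstrates env-hash based invalidation of the model cache.
import os
from litellm.utils import _model_cache

# The first lookup records a hash of provider-related env vars (OPENAI*/ANTHROPIC*/AZURE*/AWS*).
_model_cache.get_cached_model_info("openai")

_model_cache.set_cached_model_info(
    "openai", litellm_params=None, available_models=["gpt-4o-mini"]
)
assert _model_cache.get_cached_model_info("openai") == ["gpt-4o-mini"]

# Changing a provider-related environment variable changes the hash, so the next
# lookup without explicit litellm_params clears the cache and returns None.
os.environ["OPENAI_API_KEY"] = "placeholder-new-value"
assert _model_cache.get_cached_model_info("openai") is None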
@@ -41,8 +41,10 @@ from unittest.mock import AsyncMock, MagicMock, patch

# Assuming your trim_messages, shorten_message_to_fit_limit, and get_token_count functions are all in a module named 'message_utils'


@pytest.fixture(autouse=True)
def reset_mock_cache():
    from litellm.utils import _model_cache

    _model_cache.flush_cache()


# Test 1: Check trimming of normal message
def test_basic_trimming():
    messages = [

@@ -1539,6 +1541,7 @@ def test_get_valid_models_fireworks_ai(monkeypatch):
        litellm.module_level_client, "get", return_value=mock_response
    ) as mock_post:
        valid_models = get_valid_models(check_provider_endpoint=True)
        print("valid_models", valid_models)
        mock_post.assert_called_once()
        assert (
            "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"

@@ -2122,3 +2125,58 @@ def test_get_provider_audio_transcription_config():
        config = ProviderConfigManager.get_provider_audio_transcription_config(
            model="whisper-1", provider=provider
        )


def test_get_valid_models_from_provider():
    """
    Test that get_valid_models returns the correct models for a given provider
    """
    from litellm.utils import get_valid_models

    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models

    print("Valid models: ", valid_models)
    valid_models.remove("gpt-4o-mini")
    assert "gpt-4o-mini" not in valid_models

    valid_models = get_valid_models(custom_llm_provider="openai")
    assert len(valid_models) > 0
    assert "gpt-4o-mini" in valid_models


def test_get_valid_models_from_provider_cache_invalidation(monkeypatch):
    """
    Test that get_valid_models returns the correct models for a given provider
    """
    from litellm.utils import _model_cache

    monkeypatch.setenv("OPENAI_API_KEY", "123")

    _model_cache.set_cached_model_info("openai", litellm_params=None, available_models=["gpt-4o-mini"])
    monkeypatch.delenv("OPENAI_API_KEY")

    assert _model_cache.get_cached_model_info("openai") is None


def test_get_valid_models_from_dynamic_api_key():
    """
    Test that get_valid_models returns the correct models for a given provider
    """
    from litellm.utils import get_valid_models
    from litellm.types.router import CredentialLiteLLMParams

    creds = CredentialLiteLLMParams(api_key="123")

    valid_models = get_valid_models(custom_llm_provider="anthropic", litellm_params=creds, check_provider_endpoint=True)
    assert len(valid_models) == 0

    creds = CredentialLiteLLMParams(api_key=os.getenv("ANTHROPIC_API_KEY"))
    valid_models = get_valid_models(custom_llm_provider="anthropic", litellm_params=creds, check_provider_endpoint=True)
    assert len(valid_models) > 0
    assert "anthropic/claude-3-7-sonnet-20250219" in valid_models