Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
fix(proxy_server.py): pass llm router to get complete model list (#10176)

allows model auth to work

parent: e0a613f88a
commit: ce828408da
2 changed files with 31 additions and 28 deletions
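Most of the hunks below are formatting-only re-wraps; the functional change is that model_info_v1 and model_group_info in proxy_server.py now pass llm_router into the call that builds the complete model list. As a rough sketch of why that matters (the helper and class names below are made up for illustration, not litellm's actual code): models that exist only as router deployments never reach model-level auth unless the router is handed to the list builder.

# Illustrative sketch only: build_complete_model_list and FakeRouter are
# hypothetical names, not litellm's real helpers.
from typing import List, Optional


class FakeRouter:
    """Stand-in for a litellm Router that just exposes its model group names."""

    def __init__(self, model_names: List[str]) -> None:
        self.model_names = model_names

    def get_model_names(self) -> List[str]:
        return list(self.model_names)


def build_complete_model_list(
    proxy_model_list: List[str],
    user_model: Optional[str] = None,
    llm_router: Optional[FakeRouter] = None,
) -> List[str]:
    """Merge config-defined models, the CLI model, and router deployments."""
    models = set(proxy_model_list)
    if user_model:
        models.add(user_model)
    if llm_router is not None:
        # The gist of this commit: without the router, models that only exist
        # as router deployments never reach the model-level auth checks.
        models.update(llm_router.get_model_names())
    return sorted(models)


print(build_complete_model_list(["gpt-4o"], llm_router=FakeRouter(["azure-gpt-4o"])))
# ['azure-gpt-4o', 'gpt-4o']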
litellm/proxy/proxy_server.py

@@ -804,9 +804,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[RedisCache] = (
-    None  # redis cache used for tracking spend, tpm/rpm limits
-)
+redis_usage_cache: Optional[
+    RedisCache
+] = None  # redis cache used for tracking spend, tpm/rpm limits
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
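This hunk, like most of the re-wraps below, does not change behavior: both spellings of a wrapped annotated assignment bind the same value. A tiny self-contained illustration (names made up):

# Both forms bind the same value; only the line wrapping differs.
from typing import Optional


class RedisLikeCache: ...


usage_cache_a: Optional[RedisLikeCache] = (
    None  # wrapped by putting the value in parentheses
)

usage_cache_b: Optional[
    RedisLikeCache
] = None  # wrapped by splitting the subscripted annotation

assert usage_cache_a is usage_cache_b is None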
@@ -1132,9 +1132,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
-                await user_api_key_cache.async_get_cache(key=_id)
-            )
+            existing_spend_obj: Optional[
+                LiteLLM_TeamTable
+            ] = await user_api_key_cache.async_get_cache(key=_id)
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
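Again, only the wrapping changes here. For readers skimming the hunk, the underlying pattern is "read the team from the cache, bail out on a miss". A toy async equivalent with an in-memory stand-in (not the proxy's actual cache class):

# Toy in-memory stand-in for the async cache lookup pattern in the hunk above.
import asyncio
from typing import Any, Dict, Optional


class ToyAsyncCache:
    def __init__(self) -> None:
        self._store: Dict[str, Any] = {}

    async def async_get_cache(self, key: str) -> Optional[Any]:
        return self._store.get(key)

    async def async_set_cache(self, key: str, value: Any) -> None:
        self._store[key] = value


async def update_team_spend(cache: ToyAsyncCache, team_id: str, spend: float) -> None:
    _id = "team_id:{}".format(team_id)
    existing = await cache.async_get_cache(key=_id)
    if existing is None:
        return  # do nothing if the team is not in the cache
    existing["spend"] = existing.get("spend", 0.0) + spend
    await cache.async_set_cache(key=_id, value=existing)


asyncio.run(update_team_spend(ToyAsyncCache(), "team-123", 0.5))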
@@ -2806,9 +2806,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ["AZURE_API_VERSION"] = (
-            api_version  # set this for azure - litellm can read this from the env
-        )
+        os.environ[
+            "AZURE_API_VERSION"
+        ] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
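The assignment above is only re-wrapped, but the comment it carries ("litellm can read this from the env") is worth one concrete example. A minimal sketch with placeholder endpoint, key, and deployment name:

# Minimal sketch with placeholder values; the endpoint, key, and deployment
# name are assumptions, not real credentials.
import os

import litellm

os.environ["AZURE_API_VERSION"] = "2024-02-01"  # read by litellm for azure/* models
os.environ["AZURE_API_BASE"] = "https://example-resource.openai.azure.com"
os.environ["AZURE_API_KEY"] = "placeholder-key"

response = litellm.completion(
    model="azure/my-gpt-4o-deployment",  # hypothetical deployment name
    messages=[{"role": "user", "content": "ping"}],
)
print(response.choices[0].message.content)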
@@ -6160,6 +6160,7 @@ async def model_info_v1(  # noqa: PLR0915
         proxy_model_list=proxy_model_list,
         user_model=user_model,
         infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+        llm_router=llm_router,
     )

     if len(all_models_str) > 0:
@@ -6184,6 +6185,7 @@ def _get_model_group_info(
     llm_router: Router, all_models_str: List[str], model_group: Optional[str]
 ) -> List[ModelGroupInfo]:
     model_groups: List[ModelGroupInfo] = []
+
     for model in all_models_str:
         if model_group is not None and model_group != model:
             continue
@@ -6393,6 +6395,7 @@ async def model_group_info(
         proxy_model_list=proxy_model_list,
         user_model=user_model,
         infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+        llm_router=llm_router,
     )
     model_groups: List[ModelGroupInfo] = _get_model_group_info(
         llm_router=llm_router, all_models_str=all_models_str, model_group=model_group
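These two hunks, together with the one in model_info_v1 above, are the functional part of the commit: both endpoints now hand llm_router to the model-list helper. One way to eyeball the effect, assuming a proxy running locally on the default port with a placeholder key:

# Placeholder proxy URL and key; adjust to your deployment.
import requests

resp = requests.get(
    "http://localhost:4000/v1/model/info",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()
# Router-managed deployments should now show up alongside config-defined models.
print([entry["model_name"] for entry in resp.json()["data"]])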
@@ -7755,9 +7758,9 @@ async def get_config_list(
                     hasattr(sub_field_info, "description")
                     and sub_field_info.description is not None
                 ):
-                    nested_fields[idx].field_description = (
-                        sub_field_info.description
-                    )
+                    nested_fields[
+                        idx
+                    ].field_description = sub_field_info.description
                 idx += 1

            _stored_in_db = None
litellm/router.py

@@ -339,9 +339,9 @@ class Router:
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
         ### CACHING ###
-        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
-            "local"  # default to an in-memory cache
-        )
+        cache_type: Literal[
+            "local", "redis", "redis-semantic", "s3", "disk"
+        ] = "local"  # default to an in-memory cache
         redis_cache = None
         cache_config: Dict[str, Any] = {}

@@ -562,9 +562,9 @@ class Router:
                 )
             )

-        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
-            model_group_retry_policy
-        )
+        self.model_group_retry_policy: Optional[
+            Dict[str, RetryPolicy]
+        ] = model_group_retry_policy

         self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
         if allowed_fails_policy is not None:
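The setting re-wrapped above is typed Optional[Dict[str, RetryPolicy]], i.e. an optional per-model-group override. A toy illustration of that shape (the dataclass below is a simplified stand-in, not litellm's RetryPolicy):

# Toy illustration of a Dict[str, RetryPolicy]-shaped setting.
from dataclasses import dataclass
from typing import Dict, Optional


@dataclass
class ToyRetryPolicy:
    timeout_retries: int = 0
    rate_limit_retries: int = 3


model_group_retry_policy: Optional[Dict[str, ToyRetryPolicy]] = {
    "gpt-4o": ToyRetryPolicy(timeout_retries=2),
    "azure-gpt-4o": ToyRetryPolicy(rate_limit_retries=5),
}

for group, policy in (model_group_retry_policy or {}).items():
    print(group, policy)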
@@ -3247,12 +3247,12 @@ class Router:

             if isinstance(e, litellm.ContextWindowExceededError):
                 if context_window_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=context_window_fallbacks,
-                            model_group=model_group,
-                        )
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
+                        fallbacks=context_window_fallbacks,
+                        model_group=model_group,
                     )
                     if fallback_model_group is None:
                         raise original_exception

@@ -3283,12 +3283,12 @@ class Router:
                 e.message += "\n{}".format(error_message)
             elif isinstance(e, litellm.ContentPolicyViolationError):
                 if content_policy_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=content_policy_fallbacks,
-                            model_group=model_group,
-                        )
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
+                        fallbacks=content_policy_fallbacks,
+                        model_group=model_group,
                     )
                     if fallback_model_group is None:
                         raise original_exception

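For context on the last two hunks: litellm-style fallback settings are typically a list of single-key dicts mapping a model group to its fallback groups (treat the exact shape here as an assumption), and the lookup returns None when nothing is configured, in which case the original exception is re-raised. A purely illustrative re-implementation of that lookup, not the actual Router method:

# Illustrative stand-in for the fallback lookup; not Router._get_fallback_model_group_from_fallbacks.
from typing import Dict, List, Optional


def get_fallback_model_group(
    fallbacks: List[Dict[str, List[str]]], model_group: str
) -> Optional[List[str]]:
    """Return the fallback groups configured for model_group, else None."""
    for mapping in fallbacks:
        if model_group in mapping:
            return mapping[model_group]
    return None  # callers above re-raise the original exception in this case


context_window_fallbacks = [{"gpt-3.5-turbo": ["gpt-3.5-turbo-16k"]}]
print(get_fallback_model_group(context_window_fallbacks, "gpt-3.5-turbo"))
# ['gpt-3.5-turbo-16k']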