mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
fix: Minor LiteLLM Fixes + Improvements (29/08/2024) (#5436)
* fix(model_checks.py): support returning wildcard models on `/v1/models` Fixes https://github.com/BerriAI/litellm/issues/4903 * fix(bedrock_httpx.py): support calling bedrock via api_base Closes https://github.com/BerriAI/litellm/pull/4587 * fix(litellm_logging.py): only leave last 4 char of gemini key unmasked Fixes https://github.com/BerriAI/litellm/issues/5433 * feat(router.py): support setting 'weight' param for models on router Closes https://github.com/BerriAI/litellm/issues/5410 * test(test_bedrock_completion.py): add unit test for custom api base * fix(model_checks.py): handle no "/" in model
This commit is contained in:
parent
f70b7575d2
commit
dd7b008161
12 changed files with 219 additions and 25 deletions
|
@ -1,9 +1,41 @@
|
|||
# What is this?
|
||||
## Common checks for /v1/models and `/model/info`
|
||||
from typing import List, Optional
|
||||
from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
|
||||
from litellm.utils import get_valid_models
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
|
||||
from litellm.utils import get_valid_models
|
||||
|
||||
|
||||
def _check_wildcard_routing(model: str) -> bool:
|
||||
"""
|
||||
Returns True if a model is a provider wildcard.
|
||||
"""
|
||||
if model == "*":
|
||||
return True
|
||||
|
||||
if "/" in model:
|
||||
llm_provider, potential_wildcard = model.split("/", 1)
|
||||
if (
|
||||
llm_provider in litellm.provider_list and potential_wildcard == "*"
|
||||
): # e.g. anthropic/*
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_provider_models(provider: str) -> Optional[List[str]]:
|
||||
"""
|
||||
Returns the list of known models by provider
|
||||
"""
|
||||
if provider == "*":
|
||||
return get_valid_models()
|
||||
|
||||
if provider in litellm.models_by_provider:
|
||||
return litellm.models_by_provider[provider]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_key_models(
|
||||
|
@ -58,6 +90,8 @@ def get_complete_model_list(
|
|||
"""
|
||||
- If key list is empty -> defer to team list
|
||||
- If team list is empty -> defer to proxy model list
|
||||
|
||||
If list contains wildcard -> return known provider models
|
||||
"""
|
||||
|
||||
unique_models = set()
|
||||
|
@ -76,4 +110,18 @@ def get_complete_model_list(
|
|||
valid_models = get_valid_models()
|
||||
unique_models.update(valid_models)
|
||||
|
||||
return list(unique_models)
|
||||
models_to_remove = set()
|
||||
all_wildcard_models = []
|
||||
for model in unique_models:
|
||||
if _check_wildcard_routing(model=model):
|
||||
provider = model.split("/")[0]
|
||||
# get all known provider models
|
||||
wildcard_models = get_provider_models(provider=provider)
|
||||
if wildcard_models is not None:
|
||||
models_to_remove.add(model)
|
||||
all_wildcard_models.extend(wildcard_models)
|
||||
|
||||
for model in models_to_remove:
|
||||
unique_models.remove(model)
|
||||
|
||||
return list(unique_models) + all_wildcard_models
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue