fix: Minor LiteLLM Fixes + Improvements (29/08/2024) (#5436)

* fix(model_checks.py): support returning wildcard models on `/v1/models`

Fixes https://github.com/BerriAI/litellm/issues/4903

* fix(bedrock_httpx.py): support calling bedrock via api_base

Closes https://github.com/BerriAI/litellm/pull/4587

* fix(litellm_logging.py): only leave last 4 char of gemini key unmasked

Fixes https://github.com/BerriAI/litellm/issues/5433

* feat(router.py): support setting 'weight' param for models on router

Closes https://github.com/BerriAI/litellm/issues/5410

* test(test_bedrock_completion.py): add unit test for custom api base

* fix(model_checks.py): handle no "/" in model
Krish Dholakia 2024-08-29 22:40:25 -07:00 committed by GitHub
parent f70b7575d2
commit dd7b008161
12 changed files with 219 additions and 25 deletions

View file

@@ -279,9 +279,7 @@ class Logging:
                 # Find the position of "key=" in the string
                 key_index = api_base.find("key=") + 4
                 # Mask the last 5 characters after "key="
-                masked_api_base = (
-                    api_base[:key_index] + "*" * 5 + api_base[key_index + 5 :]
-                )
+                masked_api_base = api_base[:key_index] + "*" * 5 + api_base[-4:]
             else:
                 masked_api_base = api_base
             self.model_call_details["litellm_params"]["api_base"] = masked_api_base
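
A minimal standalone sketch of the new masking rule (the Gemini-style URL below is illustrative, not taken from this PR; the masking expression is the one-liner added above):

```python
# Illustrative only: a Gemini-style api_base that carries the key as a query param.
api_base = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:generateContent?key=AIza-EXAMPLE-NOT-REAL"

key_index = api_base.find("key=") + 4
# everything after "key=" is replaced, except the last 4 characters
masked_api_base = api_base[:key_index] + "*" * 5 + api_base[-4:]
print(masked_api_base)  # ...?key=*****REAL
```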

View file

@@ -48,13 +48,7 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import Choices
 from litellm.types.utils import GenericStreamingChunk as GChunk
 from litellm.types.utils import Message
-from litellm.utils import (
-    CustomStreamWrapper,
-    ModelResponse,
-    Usage,
-    get_secret,
-    print_verbose,
-)
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret

 from .base import BaseLLM
 from .base_aws_llm import BaseAWSLLM
@@ -654,6 +648,7 @@ class BedrockLLM(BaseAWSLLM):
         self,
         model: str,
         messages: list,
+        api_base: Optional[str],
         custom_prompt_dict: dict,
         model_response: ModelResponse,
         print_verbose: Callable,
@@ -734,7 +729,9 @@ class BedrockLLM(BaseAWSLLM):
         ### SET RUNTIME ENDPOINT ###
         endpoint_url = ""
         env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
-        if aws_bedrock_runtime_endpoint is not None and isinstance(
+        if api_base is not None:
+            endpoint_url = api_base
+        elif aws_bedrock_runtime_endpoint is not None and isinstance(
             aws_bedrock_runtime_endpoint, str
         ):
             endpoint_url = aws_bedrock_runtime_endpoint
@@ -1459,7 +1456,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             client = client  # type: ignore

         try:
-            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
+            response = await client.post(url=api_base, headers=headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
@@ -1485,6 +1482,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         self,
         model: str,
         messages: list,
+        api_base: Optional[str],
         custom_prompt_dict: dict,
         model_response: ModelResponse,
         print_verbose: Callable,
@@ -1565,7 +1563,9 @@ class BedrockConverseLLM(BaseAWSLLM):
         ### SET RUNTIME ENDPOINT ###
         endpoint_url = ""
         env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
-        if aws_bedrock_runtime_endpoint is not None and isinstance(
+        if api_base is not None:
+            endpoint_url = api_base
+        elif aws_bedrock_runtime_endpoint is not None and isinstance(
             aws_bedrock_runtime_endpoint, str
         ):
             endpoint_url = aws_bedrock_runtime_endpoint
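
A hedged usage sketch of the new `api_base` path (the gateway URL is a placeholder in the style of the unit test later in this commit; any Bedrock-compatible runtime endpoint should work, and AWS credentials are still read from the environment):

```python
# Sketch: send Bedrock traffic through a custom runtime endpoint via api_base.
import litellm

response = litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "What's AWS?"}],
    # placeholder gateway URL; replace with your own Bedrock runtime proxy
    api_base="https://gateway.ai.cloudflare.com/v1/<ACCOUNT_ID>/<GATEWAY>/aws-bedrock/bedrock-runtime/us-east-1",
)
print(response.choices[0].message.content)
```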

View file

@@ -1284,7 +1284,7 @@ class VertexLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         request_body = await async_transform_request_body(**data)  # type: ignore

-        if client is None:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
             _params = {}
             if timeout is not None:
                 if isinstance(timeout, float) or isinstance(timeout, int):
@@ -1293,6 +1293,16 @@ class VertexLLM(BaseLLM):
             client = AsyncHTTPHandler(**_params)  # type: ignore
         else:
             client = client  # type: ignore
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_body,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )

         try:
             response = await client.post(api_base, headers=headers, json=request_body)  # type: ignore

View file

@@ -2361,6 +2361,7 @@ def completion(
                     timeout=timeout,
                     acompletion=acompletion,
                     client=client,
+                    api_base=api_base,
                 )
             else:
                 response = bedrock_chat_completion.completion(
@@ -2378,6 +2379,7 @@ def completion(
                     timeout=timeout,
                     acompletion=acompletion,
                     client=client,
+                    api_base=api_base,
                 )

             if optional_params.get("stream", False):

View file

@@ -1,9 +1,4 @@
 model_list:
-  - model_name: my-fake-openai-endpoint
+  - model_name: "gemini/*"
     litellm_params:
-      model: gpt-3.5-turbo
-      api_key: "my-fake-key"
-      mock_response: "hello-world"
-
-litellm_settings:
-  ssl_verify: false
+      model: "gemini/*"
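
With the wildcard entry above, `/v1/models` on the proxy should now list the known Gemini models instead of the literal `gemini/*` string. A hedged sketch of checking that (the local port, virtual key, and OpenAI-compatible response shape are assumptions, not values from this PR):

```python
# Sketch: list models from a locally running proxy configured with "gemini/*".
# Assumes the proxy is on 0.0.0.0:4000 and sk-1234 is a valid virtual key.
import requests

resp = requests.get(
    "http://0.0.0.0:4000/v1/models",
    headers={"Authorization": "Bearer sk-1234"},
)
resp.raise_for_status()
print([m["id"] for m in resp.json()["data"]])  # e.g. gemini/gemini-1.5-pro, ...
```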

View file

@@ -1,9 +1,41 @@
 # What is this?
 ## Common checks for /v1/models and `/model/info`
 from typing import List, Optional
-from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
-from litellm.utils import get_valid_models
+
+import litellm
 from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
+from litellm.utils import get_valid_models
+
+
+def _check_wildcard_routing(model: str) -> bool:
+    """
+    Returns True if a model is a provider wildcard.
+    """
+    if model == "*":
+        return True
+
+    if "/" in model:
+        llm_provider, potential_wildcard = model.split("/", 1)
+        if (
+            llm_provider in litellm.provider_list and potential_wildcard == "*"
+        ):  # e.g. anthropic/*
+            return True
+
+    return False
+
+
+def get_provider_models(provider: str) -> Optional[List[str]]:
+    """
+    Returns the list of known models by provider
+    """
+    if provider == "*":
+        return get_valid_models()
+
+    if provider in litellm.models_by_provider:
+        return litellm.models_by_provider[provider]
+
+    return None
+
+
 def get_key_models(
@@ -58,6 +90,8 @@ def get_complete_model_list(
     """
     - If key list is empty -> defer to team list
     - If team list is empty -> defer to proxy model list
+
+    If list contains wildcard -> return known provider models
     """

     unique_models = set()
@@ -76,4 +110,18 @@ def get_complete_model_list(
         valid_models = get_valid_models()
         unique_models.update(valid_models)

-    return list(unique_models)
+    models_to_remove = set()
+    all_wildcard_models = []
+    for model in unique_models:
+        if _check_wildcard_routing(model=model):
+            provider = model.split("/")[0]
+            # get all known provider models
+            wildcard_models = get_provider_models(provider=provider)
+            if wildcard_models is not None:
+                models_to_remove.add(model)
+                all_wildcard_models.extend(wildcard_models)
+
+    for model in models_to_remove:
+        unique_models.remove(model)
+
+    return list(unique_models) + all_wildcard_models
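
Boiled down, the expansion above behaves roughly like this standalone helper (the function name is hypothetical; `provider_list`, `models_by_provider`, and `get_valid_models` are the litellm attributes the new code itself relies on):

```python
import litellm
from litellm.utils import get_valid_models


def expand_wildcard_models(models: list) -> list:
    """Replace '*' / 'provider/*' entries with that provider's known models."""
    expanded, kept = [], []
    for model in models:
        if model == "*":
            expanded.extend(get_valid_models())
            continue
        if "/" in model:
            provider, rest = model.split("/", 1)
            if rest == "*" and provider in litellm.provider_list:
                expanded.extend(litellm.models_by_provider.get(provider, []))
                continue
        kept.append(model)
    return kept + expanded


# e.g. expand_wildcard_models(["gemini/*", "gpt-4o"]) -> known gemini models + "gpt-4o"
```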

View file

@@ -4700,6 +4700,31 @@ class Router:
                 )
             elif self.routing_strategy == "simple-shuffle":
                 # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
+                ############## Check if 'weight' param set for a weighted pick #################
+                weight = (
+                    healthy_deployments[0].get("litellm_params").get("weight", None)
+                )
+                if weight is not None:
+                    # use weight-random pick if rpms provided
+                    weights = [
+                        m["litellm_params"].get("weight", 0)
+                        for m in healthy_deployments
+                    ]
+                    verbose_router_logger.debug(f"\nweight {weights}")
+                    total_weight = sum(weights)
+                    weights = [weight / total_weight for weight in weights]
+                    verbose_router_logger.debug(f"\n weights {weights}")
+                    # Perform weighted random pick
+                    selected_index = random.choices(
+                        range(len(weights)), weights=weights
+                    )[0]
+                    verbose_router_logger.debug(f"\n selected index, {selected_index}")
+                    deployment = healthy_deployments[selected_index]
+                    verbose_router_logger.info(
+                        f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
+                    )
+                    return deployment or deployment[0]
+
                 ############## Check if we can do a RPM/TPM based weighted pick #################
                 rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
                 if rpm is not None:
@@ -4847,6 +4872,25 @@ class Router:
             )
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
+            ############## Check 'weight' param set for weighted pick #################
+            weight = healthy_deployments[0].get("litellm_params").get("weight", None)
+            if weight is not None:
+                # use weight-random pick if rpms provided
+                weights = [
+                    m["litellm_params"].get("weight", 0) for m in healthy_deployments
+                ]
+                verbose_router_logger.debug(f"\nweight {weights}")
+                total_weight = sum(weights)
+                weights = [weight / total_weight for weight in weights]
+                verbose_router_logger.debug(f"\n weights {weights}")
+                # Perform weighted random pick
+                selected_index = random.choices(range(len(weights)), weights=weights)[0]
+                verbose_router_logger.debug(f"\n selected index, {selected_index}")
+                deployment = healthy_deployments[selected_index]
+                verbose_router_logger.info(
+                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
+                )
+                return deployment or deployment[0]
             ############## Check if we can do a RPM/TPM based weighted pick #################
             rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
             if rpm is not None:
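
A hedged usage sketch of the new `weight` param (it mirrors the router test added later in this commit; with the default simple-shuffle strategy the heavier deployment should be picked roughly twice as often, and `mock_response` keeps the sketch runnable without real keys):

```python
# Sketch: split traffic 2:1 across two deployments of the same model alias.
from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "weight": 2,
                "mock_response": "Hello world 1!",
            },
        },
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "gpt-3.5-turbo",
                "weight": 1,
                "mock_response": "Hello world 2!",
            },
        },
    ]
)

response = router.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hello world!"}],
)
print(response.choices[0].message.content)
```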

View file

@@ -960,11 +960,16 @@ async def test_bedrock_extra_headers():
                 messages=[{"role": "user", "content": "What's AWS?"}],
                 client=client,
                 extra_headers={"test": "hello world", "Authorization": "my-test-key"},
+                api_base="https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1",
             )
         except Exception as e:
             pass

-        print(f"mock_client_post.call_args: {mock_client_post.call_args}")
+        print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}")
+        assert (
+            mock_client_post.call_args.kwargs["url"]
+            == "https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1/model/anthropic.claude-3-sonnet-20240229-v1:0/converse"
+        )
         assert "test" in mock_client_post.call_args.kwargs["headers"]
         assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
         assert (

View file

@@ -1347,3 +1347,33 @@ def test_logging_async_cache_hit_sync_call():
     assert standard_logging_object["cache_hit"] is True
     assert standard_logging_object["response_cost"] == 0
     assert standard_logging_object["saved_cache_cost"] > 0
+
+
+def test_logging_key_masking_gemini():
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+    litellm.success_callback = []
+
+    with patch.object(
+        customHandler, "log_pre_api_call", new=MagicMock()
+    ) as mock_client:
+        try:
+            resp = litellm.completion(
+                model="gemini/gemini-1.5-pro",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                api_key="LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART",
+            )
+        except litellm.AuthenticationError:
+            pass
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
+        assert (
+            "LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART"
+            not in mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
+        )
+
+        key = mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
+        trimmed_key = key.split("key=")[1]
+        trimmed_key = trimmed_key.replace("*", "")
+        assert "PART" == trimmed_key

View file

@@ -2342,3 +2342,55 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
         assert e.cooldown_time == cooldown_time

     assert exception_raised
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio()
+async def test_router_weighted_pick(sync_mode):
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "weight": 2,
+                    "mock_response": "Hello world 1!",
+                },
+                "model_info": {"id": "1"},
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "weight": 1,
+                    "mock_response": "Hello world 2!",
+                },
+                "model_info": {"id": "2"},
+            },
+        ]
+    )
+
+    model_id_1_count = 0
+    model_id_2_count = 0
+    for _ in range(50):
+        # make 50 calls. expect model id 1 to be picked more than model id 2
+        if sync_mode:
+            response = router.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello world!"}],
+            )
+        else:
+            response = await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello world!"}],
+            )
+        model_id = int(response._hidden_params["model_id"])
+
+        if model_id == 1:
+            model_id_1_count += 1
+        elif model_id == 2:
+            model_id_2_count += 1
+        else:
+            raise Exception("invalid model id returned!")
+
+    assert model_id_1_count > model_id_2_count

View file

@@ -299,6 +299,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     custom_llm_provider: Optional[str]
     tpm: Optional[int]
     rpm: Optional[int]
+    order: Optional[int]
+    weight: Optional[int]
     api_key: Optional[str]
     api_base: Optional[str]
     api_version: Optional[str]

View file

@@ -7475,6 +7475,14 @@ def exception_type(
                         ),
                         litellm_debug_info=extra_information,
                     )
+                elif "API key not valid." in error_str:
+                    exception_mapping_worked = True
+                    raise AuthenticationError(
+                        message=f"{custom_llm_provider}Exception - {error_str}",
+                        model=model,
+                        llm_provider=custom_llm_provider,
+                        litellm_debug_info=extra_information,
+                    )
                 elif "403" in error_str:
                     exception_mapping_worked = True
                     raise BadRequestError(
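
A hedged sketch of what this mapping means for callers (the key value is a placeholder; catching `litellm.AuthenticationError` on an invalid Gemini key is what the new logging test in this commit relies on, and the `llm_provider`/`model` attributes are assumed from litellm's exception classes):

```python
# Sketch: an invalid Gemini key now surfaces as a typed AuthenticationError.
import litellm

try:
    litellm.completion(
        model="gemini/gemini-1.5-pro",
        messages=[{"role": "user", "content": "hi"}],
        api_key="not-a-real-key",  # placeholder
    )
except litellm.AuthenticationError as e:
    print(f"auth failed for {e.llm_provider}/{e.model}: {e}")
```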