fix: Minor LiteLLM Fixes + Improvements (29/08/2024) (#5436)

* fix(model_checks.py): support returning wildcard models on `/v1/models`

Fixes https://github.com/BerriAI/litellm/issues/4903

* fix(bedrock_httpx.py): support calling bedrock via api_base

Closes https://github.com/BerriAI/litellm/pull/4587

* fix(litellm_logging.py): leave only the last 4 characters of the Gemini API key unmasked

Fixes https://github.com/BerriAI/litellm/issues/5433

* feat(router.py): support setting 'weight' param for models on router

Closes https://github.com/BerriAI/litellm/issues/5410

* test(test_bedrock_completion.py): add unit test for custom api_base

* fix(model_checks.py): handle model names without a "/"
Krish Dholakia, 2024-08-29 22:40:25 -07:00 (committed by GitHub)
parent f70b7575d2
commit dd7b008161
12 changed files with 219 additions and 25 deletions

View file

@@ -279,9 +279,7 @@ class Logging:
# Find the position of "key=" in the string
key_index = api_base.find("key=") + 4
# Mask the last 5 characters after "key="
masked_api_base = (
api_base[:key_index] + "*" * 5 + api_base[key_index + 5 :]
)
masked_api_base = api_base[:key_index] + "*" * 5 + api_base[-4:]
else:
masked_api_base = api_base
self.model_call_details["litellm_params"]["api_base"] = masked_api_base
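For illustration, a minimal standalone sketch of the masking behavior this hunk produces; the URL and key value below are made up:

    api_base = "https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-pro:generateContent?key=LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART"
    key_index = api_base.find("key=") + 4
    # everything after "key=" becomes five asterisks plus the last 4 characters of the original string
    masked_api_base = api_base[:key_index] + "*" * 5 + api_base[-4:]
    print(masked_api_base)  # ...generateContent?key=*****PART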

View file

@@ -48,13 +48,7 @@ from litellm.types.llms.openai import (
from litellm.types.utils import Choices
from litellm.types.utils import GenericStreamingChunk as GChunk
from litellm.types.utils import Message
from litellm.utils import (
CustomStreamWrapper,
ModelResponse,
Usage,
get_secret,
print_verbose,
)
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
from .base import BaseLLM
from .base_aws_llm import BaseAWSLLM
@@ -654,6 +648,7 @@ class BedrockLLM(BaseAWSLLM):
self,
model: str,
messages: list,
api_base: Optional[str],
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
@@ -734,7 +729,9 @@
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
if api_base is not None:
endpoint_url = api_base
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
@@ -1459,7 +1456,7 @@ class BedrockConverseLLM(BaseAWSLLM):
client = client # type: ignore
try:
response = await client.post(api_base, headers=headers, data=data) # type: ignore
response = await client.post(url=api_base, headers=headers, data=data) # type: ignore
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
@@ -1485,6 +1482,7 @@
self,
model: str,
messages: list,
api_base: Optional[str],
custom_prompt_dict: dict,
model_response: ModelResponse,
print_verbose: Callable,
@@ -1565,7 +1563,9 @@
### SET RUNTIME ENDPOINT ###
endpoint_url = ""
env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
if aws_bedrock_runtime_endpoint is not None and isinstance(
if api_base is not None:
endpoint_url = api_base
elif aws_bedrock_runtime_endpoint is not None and isinstance(
aws_bedrock_runtime_endpoint, str
):
endpoint_url = aws_bedrock_runtime_endpoint
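Both Bedrock code paths now resolve the runtime endpoint with the same precedence. A rough standalone sketch of that order, re-implemented for illustration (the function name and the default-URL fallback are mine, not the library's API):

    from typing import Optional

    def resolve_bedrock_endpoint(
        api_base: Optional[str],
        aws_bedrock_runtime_endpoint: Optional[str],
        aws_region_name: str,
    ) -> str:
        # 1. an explicit api_base always wins
        if api_base is not None:
            return api_base
        # 2. otherwise use aws_bedrock_runtime_endpoint (param or AWS_BEDROCK_RUNTIME_ENDPOINT env var)
        if aws_bedrock_runtime_endpoint is not None and isinstance(aws_bedrock_runtime_endpoint, str):
            return aws_bedrock_runtime_endpoint
        # 3. fall back to the default regional Bedrock runtime endpoint
        return f"https://bedrock-runtime.{aws_region_name}.amazonaws.com"

    print(resolve_bedrock_endpoint("https://my-gateway.example.com/bedrock", None, "us-east-1"))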

View file

@@ -1284,7 +1284,7 @@ class VertexLLM(BaseLLM):
) -> Union[ModelResponse, CustomStreamWrapper]:
request_body = await async_transform_request_body(**data) # type: ignore
if client is None:
if client is None or not isinstance(client, AsyncHTTPHandler):
_params = {}
if timeout is not None:
if isinstance(timeout, float) or isinstance(timeout, int):
@@ -1293,6 +1293,16 @@
client = AsyncHTTPHandler(**_params) # type: ignore
else:
client = client # type: ignore
## LOGGING
logging_obj.pre_call(
input=messages,
api_key="",
additional_args={
"complete_input_dict": request_body,
"api_base": api_base,
"headers": headers,
},
)
try:
response = await client.post(api_base, headers=headers, json=request_body) # type: ignore

View file

@@ -2361,6 +2361,7 @@ def completion(
timeout=timeout,
acompletion=acompletion,
client=client,
api_base=api_base,
)
else:
response = bedrock_chat_completion.completion(
@@ -2378,6 +2379,7 @@
timeout=timeout,
acompletion=acompletion,
client=client,
api_base=api_base,
)
if optional_params.get("stream", False):
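With api_base now passed through to both Bedrock code paths, calls can be routed through a gateway. A hedged usage sketch (the gateway URL is a placeholder; running it for real still requires valid AWS credentials):

    import litellm

    response = litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "user", "content": "What's AWS?"}],
        # requests go to this base URL instead of the default Bedrock runtime endpoint
        api_base="https://my-gateway.example.com/aws-bedrock/bedrock-runtime/us-east-1",
    )
    print(response.choices[0].message.content)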

View file

@@ -1,9 +1,4 @@
model_list:
- model_name: my-fake-openai-endpoint
- model_name: "gemini/*"
litellm_params:
model: gpt-3.5-turbo
api_key: "my-fake-key"
mock_response: "hello-world"
litellm_settings:
ssl_verify: false
model: "gemini/*"

View file

@@ -1,9 +1,41 @@
# What is this?
## Common checks for /v1/models and `/model/info`
from typing import List, Optional
from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
from litellm.utils import get_valid_models
import litellm
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
from litellm.utils import get_valid_models
def _check_wildcard_routing(model: str) -> bool:
"""
Returns True if a model is a provider wildcard.
"""
if model == "*":
return True
if "/" in model:
llm_provider, potential_wildcard = model.split("/", 1)
if (
llm_provider in litellm.provider_list and potential_wildcard == "*"
): # e.g. anthropic/*
return True
return False
def get_provider_models(provider: str) -> Optional[List[str]]:
"""
Returns the list of known models by provider
"""
if provider == "*":
return get_valid_models()
if provider in litellm.models_by_provider:
return litellm.models_by_provider[provider]
return None
def get_key_models(
@@ -58,6 +90,8 @@ def get_complete_model_list(
"""
- If key list is empty -> defer to team list
- If team list is empty -> defer to proxy model list
If list contains wildcard -> return known provider models
"""
unique_models = set()
@@ -76,4 +110,18 @@
valid_models = get_valid_models()
unique_models.update(valid_models)
return list(unique_models)
models_to_remove = set()
all_wildcard_models = []
for model in unique_models:
if _check_wildcard_routing(model=model):
provider = model.split("/")[0]
# get all known provider models
wildcard_models = get_provider_models(provider=provider)
if wildcard_models is not None:
models_to_remove.add(model)
all_wildcard_models.extend(wildcard_models)
for model in models_to_remove:
unique_models.remove(model)
return list(unique_models) + all_wildcard_models
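A self-contained sketch of the wildcard expansion added above, with litellm's provider_list and models_by_provider registries replaced by small stand-ins and made-up model names:

    from typing import Dict, List, Set

    # stand-ins for litellm.provider_list and litellm.models_by_provider
    provider_list: List[str] = ["openai", "anthropic", "gemini"]
    models_by_provider: Dict[str, List[str]] = {
        "gemini": ["gemini/gemini-1.5-pro", "gemini/gemini-1.5-flash"],
    }

    def is_provider_wildcard(model: str) -> bool:
        # mirrors _check_wildcard_routing: "*" or "<known-provider>/*"
        if model == "*":
            return True
        if "/" in model:
            provider, suffix = model.split("/", 1)
            return provider in provider_list and suffix == "*"
        return False

    unique_models: Set[str] = {"gpt-4o", "gemini/*"}
    expanded: List[str] = []
    to_remove: Set[str] = set()
    for m in unique_models:
        if is_provider_wildcard(m):
            known = models_by_provider.get(m.split("/")[0])
            if known is not None:
                to_remove.add(m)
                expanded.extend(known)

    print(list(unique_models - to_remove) + expanded)
    # e.g. ['gpt-4o', 'gemini/gemini-1.5-pro', 'gemini/gemini-1.5-flash']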

View file

@@ -4700,6 +4700,31 @@ class Router:
)
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check if 'weight' param set for a weighted pick #################
weight = (
healthy_deployments[0].get("litellm_params").get("weight", None)
)
if weight is not None:
# use weight-random pick if rpms provided
weights = [
m["litellm_params"].get("weight", 0)
for m in healthy_deployments
]
verbose_router_logger.debug(f"\nweight {weights}")
total_weight = sum(weights)
weights = [weight / total_weight for weight in weights]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(
range(len(weights)), weights=weights
)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## Check if we can do a RPM/TPM based weighted pick #################
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
if rpm is not None:
@@ -4847,6 +4872,25 @@
)
elif self.routing_strategy == "simple-shuffle":
# if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
############## Check 'weight' param set for weighted pick #################
weight = healthy_deployments[0].get("litellm_params").get("weight", None)
if weight is not None:
# use weight-random pick if rpms provided
weights = [
m["litellm_params"].get("weight", 0) for m in healthy_deployments
]
verbose_router_logger.debug(f"\nweight {weights}")
total_weight = sum(weights)
weights = [weight / total_weight for weight in weights]
verbose_router_logger.debug(f"\n weights {weights}")
# Perform weighted random pick
selected_index = random.choices(range(len(weights)), weights=weights)[0]
verbose_router_logger.debug(f"\n selected index, {selected_index}")
deployment = healthy_deployments[selected_index]
verbose_router_logger.info(
f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
)
return deployment or deployment[0]
############## Check if we can do a RPM/TPM based weighted pick #################
rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
if rpm is not None:
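The weighted pick itself is a normalized random.choices call. A minimal standalone sketch with two illustrative deployments (ids and weights are arbitrary):

    import random

    healthy_deployments = [
        {"litellm_params": {"model": "gpt-3.5-turbo", "weight": 2}, "model_info": {"id": "1"}},
        {"litellm_params": {"model": "gpt-3.5-turbo", "weight": 1}, "model_info": {"id": "2"}},
    ]

    weights = [d["litellm_params"].get("weight", 0) for d in healthy_deployments]
    total_weight = sum(weights)
    weights = [w / total_weight for w in weights]  # [0.666..., 0.333...]

    # deployment "1" is selected roughly twice as often as deployment "2"
    selected_index = random.choices(range(len(weights)), weights=weights)[0]
    deployment = healthy_deployments[selected_index]
    print(deployment["model_info"]["id"])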

View file

@@ -960,11 +960,16 @@ async def test_bedrock_extra_headers():
messages=[{"role": "user", "content": "What's AWS?"}],
client=client,
extra_headers={"test": "hello world", "Authorization": "my-test-key"},
api_base="https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1",
)
except Exception as e:
pass
print(f"mock_client_post.call_args: {mock_client_post.call_args}")
print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}")
assert (
mock_client_post.call_args.kwargs["url"]
== "https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1/model/anthropic.claude-3-sonnet-20240229-v1:0/converse"
)
assert "test" in mock_client_post.call_args.kwargs["headers"]
assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
assert (

View file

@@ -1347,3 +1347,33 @@ def test_logging_async_cache_hit_sync_call():
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0
def test_logging_key_masking_gemini():
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.success_callback = []
with patch.object(
customHandler, "log_pre_api_call", new=MagicMock()
) as mock_client:
try:
resp = litellm.completion(
model="gemini/gemini-1.5-pro",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
api_key="LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART",
)
except litellm.AuthenticationError:
pass
mock_client.assert_called()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
assert (
"LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART"
not in mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
)
key = mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
trimmed_key = key.split("key=")[1]
trimmed_key = trimmed_key.replace("*", "")
assert "PART" == trimmed_key

View file

@@ -2342,3 +2342,55 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
assert e.cooldown_time == cooldown_time
assert exception_raised
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio()
async def test_router_weighted_pick(sync_mode):
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"weight": 2,
"mock_response": "Hello world 1!",
},
"model_info": {"id": "1"},
},
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "gpt-3.5-turbo",
"weight": 1,
"mock_response": "Hello world 2!",
},
"model_info": {"id": "2"},
},
]
)
model_id_1_count = 0
model_id_2_count = 0
for _ in range(50):
# make 50 calls. expect model id 1 to be picked more than model id 2
if sync_mode:
response = router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world!"}],
)
else:
response = await router.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hello world!"}],
)
model_id = int(response._hidden_params["model_id"])
if model_id == 1:
model_id_1_count += 1
elif model_id == 2:
model_id_2_count += 1
else:
raise Exception("invalid model id returned!")
assert model_id_1_count > model_id_2_count

View file

@@ -299,6 +299,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
custom_llm_provider: Optional[str]
tpm: Optional[int]
rpm: Optional[int]
order: Optional[int]
weight: Optional[int]
api_key: Optional[str]
api_base: Optional[str]
api_version: Optional[str]

View file

@@ -7475,6 +7475,14 @@ def exception_type(
),
litellm_debug_info=extra_information,
)
elif "API key not valid." in error_str:
exception_mapping_worked = True
raise AuthenticationError(
message=f"{custom_llm_provider}Exception - {error_str}",
model=model,
llm_provider=custom_llm_provider,
litellm_debug_info=extra_information,
)
elif "403" in error_str:
exception_mapping_worked = True
raise BadRequestError(
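From the caller's side, this mapping means Google's "API key not valid." response now surfaces as litellm.AuthenticationError, which is what the new Gemini key-masking test relies on. A hedged sketch (the key is intentionally bogus, and the call needs network access to reach the provider):

    import litellm

    try:
        litellm.completion(
            model="gemini/gemini-1.5-pro",
            messages=[{"role": "user", "content": "hi"}],
            api_key="not-a-real-key",
        )
    except litellm.AuthenticationError as e:
        # before this hunk, the "API key not valid." error string was not mapped to AuthenticationError
        print(f"mapped to: {type(e).__name__}")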