diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py
index 14a4c68b3a..682d1a4b57 100644
--- a/litellm/litellm_core_utils/litellm_logging.py
+++ b/litellm/litellm_core_utils/litellm_logging.py
@@ -279,9 +279,7 @@ class Logging:
                 # Find the position of "key=" in the string
                 key_index = api_base.find("key=") + 4
                 # Mask the last 5 characters after "key="
-                masked_api_base = (
-                    api_base[:key_index] + "*" * 5 + api_base[key_index + 5 :]
-                )
+                masked_api_base = api_base[:key_index] + "*" * 5 + api_base[-4:]
             else:
                 masked_api_base = api_base
             self.model_call_details["litellm_params"]["api_base"] = masked_api_base
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index e60f106943..5980463e5c 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -48,13 +48,7 @@ from litellm.types.llms.openai import (
 from litellm.types.utils import Choices
 from litellm.types.utils import GenericStreamingChunk as GChunk
 from litellm.types.utils import Message
-from litellm.utils import (
-    CustomStreamWrapper,
-    ModelResponse,
-    Usage,
-    get_secret,
-    print_verbose,
-)
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage, get_secret
 
 from .base import BaseLLM
 from .base_aws_llm import BaseAWSLLM
@@ -654,6 +648,7 @@ class BedrockLLM(BaseAWSLLM):
         self,
         model: str,
         messages: list,
+        api_base: Optional[str],
         custom_prompt_dict: dict,
         model_response: ModelResponse,
         print_verbose: Callable,
@@ -734,7 +729,9 @@ class BedrockLLM(BaseAWSLLM):
         ### SET RUNTIME ENDPOINT ###
         endpoint_url = ""
         env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
-        if aws_bedrock_runtime_endpoint is not None and isinstance(
+        if api_base is not None:
+            endpoint_url = api_base
+        elif aws_bedrock_runtime_endpoint is not None and isinstance(
             aws_bedrock_runtime_endpoint, str
         ):
             endpoint_url = aws_bedrock_runtime_endpoint
@@ -1459,7 +1456,7 @@ class BedrockConverseLLM(BaseAWSLLM):
             client = client  # type: ignore
 
         try:
-            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
+            response = await client.post(url=api_base, headers=headers, data=data)  # type: ignore
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
@@ -1485,6 +1482,7 @@ class BedrockConverseLLM(BaseAWSLLM):
         self,
         model: str,
         messages: list,
+        api_base: Optional[str],
         custom_prompt_dict: dict,
         model_response: ModelResponse,
         print_verbose: Callable,
@@ -1565,7 +1563,9 @@ class BedrockConverseLLM(BaseAWSLLM):
         ### SET RUNTIME ENDPOINT ###
         endpoint_url = ""
         env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT")
-        if aws_bedrock_runtime_endpoint is not None and isinstance(
+        if api_base is not None:
+            endpoint_url = api_base
+        elif aws_bedrock_runtime_endpoint is not None and isinstance(
             aws_bedrock_runtime_endpoint, str
         ):
             endpoint_url = aws_bedrock_runtime_endpoint
diff --git a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
index ed7cd8dbab..05a2ffc3fa 100644
--- a/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
+++ b/litellm/llms/vertex_ai_and_google_ai_studio/gemini/vertex_and_google_ai_studio_gemini.py
@@ -1284,7 +1284,7 @@ class VertexLLM(BaseLLM):
     ) -> Union[ModelResponse, CustomStreamWrapper]:
         request_body = await async_transform_request_body(**data)  # type: ignore
 
-        if client is None:
+        if client is None or not isinstance(client, AsyncHTTPHandler):
             _params = {}
             if timeout is not None:
                 if isinstance(timeout, float) or isinstance(timeout, int):
@@ -1293,6 +1293,16 @@ class VertexLLM(BaseLLM):
             client = AsyncHTTPHandler(**_params)  # type: ignore
         else:
             client = client  # type: ignore
+        ## LOGGING
+        logging_obj.pre_call(
+            input=messages,
+            api_key="",
+            additional_args={
+                "complete_input_dict": request_body,
+                "api_base": api_base,
+                "headers": headers,
+            },
+        )
 
         try:
             response = await client.post(api_base, headers=headers, json=request_body)  # type: ignore
diff --git a/litellm/main.py b/litellm/main.py
index ca9d145f1b..4aa0815ca5 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -2361,6 +2361,7 @@ def completion(
                     timeout=timeout,
                     acompletion=acompletion,
                     client=client,
+                    api_base=api_base,
                 )
             else:
                 response = bedrock_chat_completion.completion(
@@ -2378,6 +2379,7 @@ def completion(
                     timeout=timeout,
                     acompletion=acompletion,
                     client=client,
+                    api_base=api_base,
                 )
 
             if optional_params.get("stream", False):
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index 8192098de8..b8f964ab34 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -1,9 +1,4 @@
 model_list:
-  - model_name: my-fake-openai-endpoint
+  - model_name: "gemini/*"
     litellm_params:
-      model: gpt-3.5-turbo
-      api_key: "my-fake-key"
-      mock_response: "hello-world"
-
-litellm_settings:
-  ssl_verify: false
\ No newline at end of file
+      model: "gemini/*"
\ No newline at end of file
diff --git a/litellm/proxy/auth/model_checks.py b/litellm/proxy/auth/model_checks.py
index ccbf907864..dd22a7f47d 100644
--- a/litellm/proxy/auth/model_checks.py
+++ b/litellm/proxy/auth/model_checks.py
@@ -1,9 +1,41 @@
 # What is this?
 ## Common checks for /v1/models and `/model/info`
 from typing import List, Optional
-from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
-from litellm.utils import get_valid_models
+
+import litellm
 from litellm._logging import verbose_proxy_logger
+from litellm.proxy._types import SpecialModelNames, UserAPIKeyAuth
+from litellm.utils import get_valid_models
+
+
+def _check_wildcard_routing(model: str) -> bool:
+    """
+    Returns True if a model is a provider wildcard.
+    """
+    if model == "*":
+        return True
+
+    if "/" in model:
+        llm_provider, potential_wildcard = model.split("/", 1)
+        if (
+            llm_provider in litellm.provider_list and potential_wildcard == "*"
+        ):  # e.g. anthropic/*
+            return True
+
+    return False
+
+
+def get_provider_models(provider: str) -> Optional[List[str]]:
+    """
+    Returns the list of known models by provider
+    """
+    if provider == "*":
+        return get_valid_models()
+
+    if provider in litellm.models_by_provider:
+        return litellm.models_by_provider[provider]
+
+    return None
 
 
 def get_key_models(
@@ -58,6 +90,8 @@ def get_complete_model_list(
     """
     - If key list is empty -> defer to team list
     - If team list is empty -> defer to proxy model list
+
+    If list contains wildcard -> return known provider models
     """
 
     unique_models = set()
@@ -76,4 +110,18 @@ def get_complete_model_list(
         valid_models = get_valid_models()
         unique_models.update(valid_models)
 
-    return list(unique_models)
+    models_to_remove = set()
+    all_wildcard_models = []
+    for model in unique_models:
+        if _check_wildcard_routing(model=model):
+            provider = model.split("/")[0]
+            # get all known provider models
+            wildcard_models = get_provider_models(provider=provider)
+            if wildcard_models is not None:
+                models_to_remove.add(model)
+                all_wildcard_models.extend(wildcard_models)
+
+    for model in models_to_remove:
+        unique_models.remove(model)
+
+    return list(unique_models) + all_wildcard_models
diff --git a/litellm/router.py b/litellm/router.py
index 37ca435d7c..3d371a3e60 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -4700,6 +4700,31 @@ class Router:
                 )
             elif self.routing_strategy == "simple-shuffle":
                 # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
+
+                ############## Check if 'weight' param set for a weighted pick #################
+                weight = (
+                    healthy_deployments[0].get("litellm_params").get("weight", None)
+                )
+                if weight is not None:
+                    # use weight-random pick if rpms provided
+                    weights = [
+                        m["litellm_params"].get("weight", 0)
+                        for m in healthy_deployments
+                    ]
+                    verbose_router_logger.debug(f"\nweight {weights}")
+                    total_weight = sum(weights)
+                    weights = [weight / total_weight for weight in weights]
+                    verbose_router_logger.debug(f"\n weights {weights}")
+                    # Perform weighted random pick
+                    selected_index = random.choices(
+                        range(len(weights)), weights=weights
+                    )[0]
+                    verbose_router_logger.debug(f"\n selected index, {selected_index}")
+                    deployment = healthy_deployments[selected_index]
+                    verbose_router_logger.info(
+                        f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
+                    )
+                    return deployment or deployment[0]
                 ############## Check if we can do a RPM/TPM based weighted pick #################
                 rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
                 if rpm is not None:
@@ -4847,6 +4872,25 @@ class Router:
             )
         elif self.routing_strategy == "simple-shuffle":
             # if users pass rpm or tpm, we do a random weighted pick - based on rpm/tpm
+            ############## Check 'weight' param set for weighted pick #################
+            weight = healthy_deployments[0].get("litellm_params").get("weight", None)
+            if weight is not None:
+                # use weight-random pick if rpms provided
+                weights = [
+                    m["litellm_params"].get("weight", 0) for m in healthy_deployments
+                ]
+                verbose_router_logger.debug(f"\nweight {weights}")
+                total_weight = sum(weights)
+                weights = [weight / total_weight for weight in weights]
+                verbose_router_logger.debug(f"\n weights {weights}")
+                # Perform weighted random pick
+                selected_index = random.choices(range(len(weights)), weights=weights)[0]
+                verbose_router_logger.debug(f"\n selected index, {selected_index}")
+                deployment = healthy_deployments[selected_index]
+                verbose_router_logger.info(
+                    f"get_available_deployment for model: {model}, Selected deployment: {self.print_deployment(deployment) or deployment[0]} for model: {model}"
+                )
+                return deployment or deployment[0]
             ############## Check if we can do a RPM/TPM based weighted pick #################
             rpm = healthy_deployments[0].get("litellm_params").get("rpm", None)
             if rpm is not None:
diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py
index 129e0fc625..9a830d3535 100644
--- a/litellm/tests/test_bedrock_completion.py
+++ b/litellm/tests/test_bedrock_completion.py
@@ -960,11 +960,16 @@ async def test_bedrock_extra_headers():
             messages=[{"role": "user", "content": "What's AWS?"}],
             client=client,
             extra_headers={"test": "hello world", "Authorization": "my-test-key"},
+            api_base="https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1",
         )
     except Exception as e:
         pass
 
-    print(f"mock_client_post.call_args: {mock_client_post.call_args}")
+    print(f"mock_client_post.call_args.kwargs: {mock_client_post.call_args.kwargs}")
+    assert (
+        mock_client_post.call_args.kwargs["url"]
+        == "https://gateway.ai.cloudflare.com/v1/fa4cdcab1f32b95ca3b53fd36043d691/test/aws-bedrock/bedrock-runtime/us-east-1/model/anthropic.claude-3-sonnet-20240229-v1:0/converse"
+    )
     assert "test" in mock_client_post.call_args.kwargs["headers"]
     assert mock_client_post.call_args.kwargs["headers"]["test"] == "hello world"
     assert (
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index b60d5a98c0..739365a702 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -1347,3 +1347,33 @@ def test_logging_async_cache_hit_sync_call():
     assert standard_logging_object["cache_hit"] is True
     assert standard_logging_object["response_cost"] == 0
     assert standard_logging_object["saved_cache_cost"] > 0
+
+
+def test_logging_key_masking_gemini():
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+    litellm.success_callback = []
+
+    with patch.object(
+        customHandler, "log_pre_api_call", new=MagicMock()
+    ) as mock_client:
+        try:
+            resp = litellm.completion(
+                model="gemini/gemini-1.5-pro",
+                messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                api_key="LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART",
+            )
+        except litellm.AuthenticationError:
+            pass
+
+        mock_client.assert_called()
+
+        print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
+        assert (
+            "LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART"
+            not in mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
+        )
+        key = mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
+        trimmed_key = key.split("key=")[1]
+        trimmed_key = trimmed_key.replace("*", "")
+        assert "PART" == trimmed_key
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index fd4f79657c..a0a96fa67f 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -2342,3 +2342,55 @@ async def test_aaarouter_dynamic_cooldown_message_retry_time(sync_mode):
         assert e.cooldown_time == cooldown_time
 
     assert exception_raised
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.asyncio()
+async def test_router_weighted_pick(sync_mode):
+    router = Router(
+        model_list=[
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "weight": 2,
+                    "mock_response": "Hello world 1!",
+                },
+                "model_info": {"id": "1"},
+            },
+            {
+                "model_name": "gpt-3.5-turbo",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "weight": 1,
+                    "mock_response": "Hello world 2!",
+                },
+                "model_info": {"id": "2"},
+            },
+        ]
+    )
+
+    model_id_1_count = 0
+    model_id_2_count = 0
+    for _ in range(50):
+        # make 50 calls. expect model id 1 to be picked more than model id 2
+        if sync_mode:
+            response = router.completion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello world!"}],
+            )
+        else:
+            response = await router.acompletion(
+                model="gpt-3.5-turbo",
+                messages=[{"role": "user", "content": "Hello world!"}],
+            )
+
+        model_id = int(response._hidden_params["model_id"])
+
+        if model_id == 1:
+            model_id_1_count += 1
+        elif model_id == 2:
+            model_id_2_count += 1
+        else:
+            raise Exception("invalid model id returned!")
+    assert model_id_1_count > model_id_2_count
diff --git a/litellm/types/router.py b/litellm/types/router.py
index b05e9be4a6..f959b9682e 100644
--- a/litellm/types/router.py
+++ b/litellm/types/router.py
@@ -299,6 +299,8 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
     custom_llm_provider: Optional[str]
     tpm: Optional[int]
     rpm: Optional[int]
+    order: Optional[int]
+    weight: Optional[int]
     api_key: Optional[str]
     api_base: Optional[str]
     api_version: Optional[str]
diff --git a/litellm/utils.py b/litellm/utils.py
index aeb48cb7f5..b22faae044 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7475,6 +7475,14 @@ def exception_type(
                     ),
                     litellm_debug_info=extra_information,
                 )
+            elif "API key not valid." in error_str:
+                exception_mapping_worked = True
+                raise AuthenticationError(
+                    message=f"{custom_llm_provider}Exception - {error_str}",
+                    model=model,
+                    llm_provider=custom_llm_provider,
+                    litellm_debug_info=extra_information,
+                )
             elif "403" in error_str:
                 exception_mapping_worked = True
                 raise BadRequestError(