From e0d81434ed28a672dbbc541a5b533b4ddb6b2c9f Mon Sep 17 00:00:00 2001 From: Krish Dholakia Date: Sun, 1 Sep 2024 13:31:42 -0700 Subject: [PATCH] LiteLLM minor fixes + improvements (31/08/2024) (#5464) * fix(vertex_endpoints.py): fix vertex ai pass through endpoints * test(test_streaming.py): skip model due to end of life * feat(custom_logger.py): add special callback for model hitting tpm/rpm limits Closes https://github.com/BerriAI/litellm/issues/4096 --- litellm/integrations/custom_logger.py | 5 + litellm/litellm_core_utils/litellm_logging.py | 27 ++++++ litellm/proxy/_new_secret_config.yaml | 2 +- .../vertex_ai_endpoints/vertex_endpoints.py | 51 +++++++++- litellm/router.py | 2 + litellm/tests/test_custom_callback_router.py | 92 ++++++++++++++++++- litellm/types/router.py | 7 +- litellm/utils.py | 1 + 8 files changed, 174 insertions(+), 13 deletions(-) diff --git a/litellm/integrations/custom_logger.py b/litellm/integrations/custom_logger.py index 01fd35990..ce0caf32b 100644 --- a/litellm/integrations/custom_logger.py +++ b/litellm/integrations/custom_logger.py @@ -59,6 +59,11 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac pass #### Fallback Events - router/proxy only #### + async def log_model_group_rate_limit_error( + self, exception: Exception, original_model_group: Optional[str], kwargs: dict + ): + pass + async def log_success_fallback_event(self, original_model_group: str, kwargs: dict): pass diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 025c0e9a3..537ca15a4 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -1552,6 +1552,32 @@ class Logging: metadata.update(exception.headers) return start_time, end_time + async def special_failure_handlers(self, exception: Exception): + """ + Custom events, emitted for specific failures. + + Currently just for router model group rate limit error + """ + from litellm.types.router import RouterErrors + + ## check if special error ## + if RouterErrors.no_deployments_available.value not in str(exception): + return + + ## get original model group ## + + litellm_params: dict = self.model_call_details.get("litellm_params") or {} + metadata = litellm_params.get("metadata") or {} + + model_group = metadata.get("model_group") or None + for callback in litellm._async_failure_callback: + if isinstance(callback, CustomLogger): # custom logger class + await callback.log_model_group_rate_limit_error( + exception=exception, + original_model_group=model_group, + kwargs=self.model_call_details, + ) # type: ignore + def failure_handler( self, exception, traceback_exception, start_time=None, end_time=None ): @@ -1799,6 +1825,7 @@ class Logging: """ Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions. 
""" + await self.special_failure_handlers(exception=exception) start_time, end_time = self._failure_handler_helper_fn( exception=exception, traceback_exception=traceback_exception, diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index b84ef7453..220da4932 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,4 +1,4 @@ model_list: - model_name: "gpt-3.5-turbo" litellm_params: - model: "gpt-3.5-turbo" \ No newline at end of file + model: "gpt-3.5-turbo" diff --git a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py index 9b87823e3..4a3a97e67 100644 --- a/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py +++ b/litellm/proxy/vertex_ai_endpoints/vertex_endpoints.py @@ -73,6 +73,45 @@ def exception_handler(e: Exception): ) +def construct_target_url( + base_url: str, + requested_route: str, + default_vertex_location: Optional[str], + default_vertex_project: Optional[str], +) -> httpx.URL: + """ + Allow user to specify their own project id / location. + + If missing, use defaults + + Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460 + + Constructed Url: + POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents + """ + new_base_url = httpx.URL(base_url) + if "locations" in requested_route: # contains the target project id + location + updated_url = new_base_url.copy_with(path=requested_route) + return updated_url + """ + - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest) + - Add default project id + - Add default location + """ + vertex_version: Literal["v1", "v1beta1"] = "v1" + if "cachedContent" in requested_route: + vertex_version = "v1beta1" + + base_requested_route = "{}/projects/{}/locations/{}".format( + vertex_version, default_vertex_project, default_vertex_location + ) + + updated_requested_route = "/" + base_requested_route + requested_route + + updated_url = new_base_url.copy_with(path=updated_requested_route) + return updated_url + + @router.api_route( "/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"] ) @@ -86,8 +125,6 @@ async def vertex_proxy_route( import re - from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance - verbose_proxy_logger.debug("requested endpoint %s", endpoint) headers: dict = {} # Use headers from the incoming request if default_vertex_config is not set @@ -133,8 +170,14 @@ async def vertex_proxy_route( encoded_endpoint = "/" + encoded_endpoint # Construct the full target URL using httpx - base_url = httpx.URL(base_target_url) - updated_url = base_url.copy_with(path=encoded_endpoint) + updated_url = construct_target_url( + base_url=base_target_url, + requested_route=encoded_endpoint, + default_vertex_location=vertex_location, + default_vertex_project=vertex_project, + ) + # base_url = httpx.URL(base_target_url) + # updated_url = base_url.copy_with(path=encoded_endpoint) verbose_proxy_logger.debug("updated url %s", updated_url) diff --git a/litellm/router.py b/litellm/router.py index 3d371a3e6..1a433858d 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4792,10 +4792,12 @@ class Router: return deployment except Exception as e: + traceback_exception = traceback.format_exc() # if router rejects call -> log to langfuse/otel/etc. 
if request_kwargs is not None: logging_obj = request_kwargs.get("litellm_logging_obj", None) + if logging_obj is not None: ## LOGGING threading.Thread( diff --git a/litellm/tests/test_custom_callback_router.py b/litellm/tests/test_custom_callback_router.py index 071d4529d..544af78ee 100644 --- a/litellm/tests/test_custom_callback_router.py +++ b/litellm/tests/test_custom_callback_router.py @@ -1,13 +1,21 @@ ### What this tests #### ## This test asserts the type of data passed into each method of the custom callback handler -import sys, os, time, inspect, asyncio, traceback +import asyncio +import inspect +import os +import sys +import time +import traceback from datetime import datetime + import pytest sys.path.insert(0, os.path.abspath("../..")) -from typing import Optional, Literal, List -from litellm import Router, Cache +from typing import List, Literal, Optional +from unittest.mock import AsyncMock, MagicMock, patch + import litellm +from litellm import Cache, Router from litellm.integrations.custom_logger import CustomLogger # Test Scenarios (test across completion, streaming, embedding) @@ -602,14 +610,18 @@ async def test_async_completion_azure_caching(): router = Router(model_list=model_list) # type: ignore response1 = await router.acompletion( model="gpt-3.5-turbo", - messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], + messages=[ + {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"} + ], caching=True, ) await asyncio.sleep(1) print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}") response2 = await router.acompletion( model="gpt-3.5-turbo", - messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], + messages=[ + {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"} + ], caching=True, ) await asyncio.sleep(1) # success callbacks are done in parallel @@ -618,3 +630,73 @@ async def test_async_completion_azure_caching(): ) assert len(customHandler_caching.errors) == 0 assert len(customHandler_caching.states) == 4 # pre, post, success, success + + +@pytest.mark.asyncio +async def test_rate_limit_error_callback(): + """ + Assert a callback is hit, if a model group starts hitting rate limit errors + + Relevant issue: https://github.com/BerriAI/litellm/issues/4096 + """ + from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging + + customHandler = CompletionCustomHandler() + litellm.callbacks = [customHandler] + litellm.success_callback = [] + + router = Router( + model_list=[ + { + "model_name": "my-test-gpt", + "litellm_params": { + "model": "gpt-3.5-turbo", + "mock_response": "litellm.RateLimitError", + }, + } + ], + allowed_fails=2, + num_retries=0, + ) + + litellm_logging_obj = LiteLLMLogging( + model="my-test-gpt", + messages=[{"role": "user", "content": "hi"}], + stream=False, + call_type="acompletion", + litellm_call_id="1234", + start_time=datetime.now(), + function_id="1234", + ) + + try: + _ = await router.acompletion( + model="my-test-gpt", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + ) + except Exception: + pass + + with patch.object( + customHandler, "log_model_group_rate_limit_error", new=MagicMock() + ) as mock_client: + + print( + f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}" + ) + + for _ in range(3): + try: + _ = await router.acompletion( + model="my-test-gpt", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + 
litellm_logging_obj=litellm_logging_obj, + ) + except litellm.RateLimitError: + pass + + await asyncio.sleep(3) + mock_client.assert_called_once() + + assert "original_model_group" in mock_client.call_args.kwargs + assert mock_client.call_args.kwargs["original_model_group"] == "my-test-gpt" diff --git a/litellm/types/router.py b/litellm/types/router.py index f959b9682..ba1a08901 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -5,11 +5,12 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc import datetime import enum import uuid -from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union import httpx from pydantic import BaseModel, ConfigDict, Field +from ..exceptions import RateLimitError from .completion import CompletionRequest from .embedding import EmbeddingRequest from .utils import ModelResponse @@ -567,7 +568,7 @@ class RouterRateLimitErrorBasic(ValueError): super().__init__(_message) -class RouterRateLimitError(ValueError): +class RouterRateLimitError(RateLimitError): def __init__( self, model: str, @@ -580,4 +581,4 @@ class RouterRateLimitError(ValueError): self.enable_pre_call_checks = enable_pre_call_checks self.cooldown_list = cooldown_list _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}" - super().__init__(_message) + super().__init__(_message, llm_provider="", model=model) diff --git a/litellm/utils.py b/litellm/utils.py index 5b8229d68..efd48e8ab 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -394,6 +394,7 @@ def function_setup( print_verbose( f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}" ) + if ( len(litellm.input_callback) > 0 or len(litellm.success_callback) > 0
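
Example usage (not part of the applied diff): a minimal sketch of a CustomLogger subclass consuming the new log_model_group_rate_limit_error hook added above. It follows the setup in litellm/tests/test_custom_callback_router.py::test_rate_limit_error_callback; the model name, mock_response value, allowed_fails, and num_retries below are illustrative assumptions borrowed from that test, not a prescribed configuration.

import asyncio
from typing import Optional

import litellm
from litellm import Router
from litellm.integrations.custom_logger import CustomLogger


class RateLimitAlerter(CustomLogger):
    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        # Fired when the router has no deployments left for a model group,
        # i.e. the exception text contains RouterErrors.no_deployments_available.
        print(f"model group={original_model_group} hit its rate limit: {exception}")


async def main():
    litellm.callbacks = [RateLimitAlerter()]
    router = Router(
        model_list=[
            {
                "model_name": "my-test-gpt",  # assumed name, mirrors the test
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "litellm.RateLimitError",
                },
            }
        ],
        allowed_fails=2,
        num_retries=0,
    )
    # The first failures put the deployment into cooldown; later calls then
    # raise the router's "no deployments available" error, which is what
    # triggers the new callback via special_failure_handlers().
    for _ in range(4):
        try:
            await router.acompletion(
                model="my-test-gpt",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
            )
        except Exception:
            pass


if __name__ == "__main__":
    asyncio.run(main())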