LiteLLM minor fixes + improvements (31/08/2024) (#5464)

* fix(vertex_endpoints.py): fix vertex ai pass through endpoints

* test(test_streaming.py): skip model due to end of life

* feat(custom_logger.py): add special callback for model hitting tpm/rpm limits

Closes https://github.com/BerriAI/litellm/issues/4096
Krish Dholakia authored 2024-09-01 13:31:42 -07:00, committed by GitHub
parent 7778fa0146
commit e0d81434ed
8 changed files with 174 additions and 13 deletions


@@ -59,6 +59,11 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
         pass
 
     #### Fallback Events - router/proxy only ####
+    async def log_model_group_rate_limit_error(
+        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
+    ):
+        pass
+
     async def log_success_fallback_event(self, original_model_group: str, kwargs: dict):
         pass

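As a usage note for the new hook: a proxy/router consumer would override it on a `CustomLogger` subclass and register the instance via `litellm.callbacks` (the new test further down does exactly this). A minimal sketch, assuming a hypothetical `RateLimitMonitor` class with a plain print standing in for real alerting:

```python
from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger


class RateLimitMonitor(CustomLogger):  # hypothetical subclass, for illustration only
    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        # Called by the logging layer when the router reports that no
        # deployments are available for a model group (tpm/rpm cooldown).
        print(f"model group '{original_model_group}' hit its rate limits: {exception}")


litellm.callbacks = [RateLimitMonitor()]
```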

@@ -1552,6 +1552,32 @@ class Logging:
             metadata.update(exception.headers)
         return start_time, end_time
 
+    async def special_failure_handlers(self, exception: Exception):
+        """
+        Custom events, emitted for specific failures.
+
+        Currently just for router model group rate limit error
+        """
+        from litellm.types.router import RouterErrors
+
+        ## check if special error ##
+        if RouterErrors.no_deployments_available.value not in str(exception):
+            return
+
+        ## get original model group ##
+        litellm_params: dict = self.model_call_details.get("litellm_params") or {}
+        metadata = litellm_params.get("metadata") or {}
+        model_group = metadata.get("model_group") or None
+
+        for callback in litellm._async_failure_callback:
+            if isinstance(callback, CustomLogger):  # custom logger class
+                await callback.log_model_group_rate_limit_error(
+                    exception=exception,
+                    original_model_group=model_group,
+                    kwargs=self.model_call_details,
+                )  # type: ignore
+
     def failure_handler(
         self, exception, traceback_exception, start_time=None, end_time=None
     ):
@@ -1799,6 +1825,7 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
+        await self.special_failure_handlers(exception=exception)
         start_time, end_time = self._failure_handler_helper_fn(
             exception=exception,
             traceback_exception=traceback_exception,


@@ -1,4 +1,4 @@
 model_list:
   - model_name: "gpt-3.5-turbo"
     litellm_params:
       model: "gpt-3.5-turbo"


@@ -73,6 +73,45 @@ def exception_handler(e: Exception):
         )
 
 
+def construct_target_url(
+    base_url: str,
+    requested_route: str,
+    default_vertex_location: Optional[str],
+    default_vertex_project: Optional[str],
+) -> httpx.URL:
+    """
+    Allow user to specify their own project id / location.
+
+    If missing, use defaults
+
+    Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460
+
+    Constructed Url:
+    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
+    """
+    new_base_url = httpx.URL(base_url)
+    if "locations" in requested_route:  # contains the target project id + location
+        updated_url = new_base_url.copy_with(path=requested_route)
+        return updated_url
+    """
+    - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
+    - Add default project id
+    - Add default location
+    """
+    vertex_version: Literal["v1", "v1beta1"] = "v1"
+    if "cachedContent" in requested_route:
+        vertex_version = "v1beta1"
+
+    base_requested_route = "{}/projects/{}/locations/{}".format(
+        vertex_version, default_vertex_project, default_vertex_location
+    )
+
+    updated_requested_route = "/" + base_requested_route + requested_route
+
+    updated_url = new_base_url.copy_with(path=updated_requested_route)
+    return updated_url
+
+
 @router.api_route(
     "/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
 )
@@ -86,8 +125,6 @@ async def vertex_proxy_route(
     import re
 
-    from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance
-
     verbose_proxy_logger.debug("requested endpoint %s", endpoint)
     headers: dict = {}
     # Use headers from the incoming request if default_vertex_config is not set
@@ -133,8 +170,14 @@ async def vertex_proxy_route(
     encoded_endpoint = "/" + encoded_endpoint
 
     # Construct the full target URL using httpx
-    base_url = httpx.URL(base_target_url)
-    updated_url = base_url.copy_with(path=encoded_endpoint)
+    updated_url = construct_target_url(
+        base_url=base_target_url,
+        requested_route=encoded_endpoint,
+        default_vertex_location=vertex_location,
+        default_vertex_project=vertex_project,
+    )
+    # base_url = httpx.URL(base_target_url)
+    # updated_url = base_url.copy_with(path=encoded_endpoint)
 
     verbose_proxy_logger.debug("updated url %s", updated_url)

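For clarity on what `construct_target_url` produces, here is a rough sketch of the two branches described in its docstring. The project and location values are made up, and the import path is only assumed from where this commit places the helper:

```python
# Assumed import path for the helper added above; adjust to your install.
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import construct_target_url

base = "https://us-central1-aiplatform.googleapis.com/"

# Route already names a project/location -> passed through unchanged.
explicit = construct_target_url(
    base_url=base,
    requested_route="/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent",
    default_vertex_location="us-central1",
    default_vertex_project="my-project",
)

# Route without project/location -> defaults injected; cachedContents goes to v1beta1.
cached = construct_target_url(
    base_url=base,
    requested_route="/cachedContents",
    default_vertex_location="us-central1",
    default_vertex_project="my-project",
)

print(explicit)  # .../v1/projects/my-project/locations/us-central1/...:generateContent
print(cached)    # .../v1beta1/projects/my-project/locations/us-central1/cachedContents
```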

@@ -4792,10 +4792,12 @@ class Router:
             return deployment
         except Exception as e:
             traceback_exception = traceback.format_exc()
             # if router rejects call -> log to langfuse/otel/etc.
             if request_kwargs is not None:
                 logging_obj = request_kwargs.get("litellm_logging_obj", None)
 
                 if logging_obj is not None:
                     ## LOGGING
                     threading.Thread(


@ -1,13 +1,21 @@
### What this tests #### ### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler ## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback import asyncio
import inspect
import os
import sys
import time
import traceback
from datetime import datetime from datetime import datetime
import pytest import pytest
sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List from typing import List, Literal, Optional
from litellm import Router, Cache from unittest.mock import AsyncMock, MagicMock, patch
import litellm import litellm
from litellm import Cache, Router
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding) # Test Scenarios (test across completion, streaming, embedding)
@ -602,14 +610,18 @@ async def test_async_completion_azure_caching():
router = Router(model_list=model_list) # type: ignore router = Router(model_list=model_list) # type: ignore
response1 = await router.acompletion( response1 = await router.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True, caching=True,
) )
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}") print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await router.acompletion( response2 = await router.acompletion(
model="gpt-3.5-turbo", model="gpt-3.5-turbo",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True, caching=True,
) )
await asyncio.sleep(1) # success callbacks are done in parallel await asyncio.sleep(1) # success callbacks are done in parallel
@ -618,3 +630,73 @@ async def test_async_completion_azure_caching():
) )
assert len(customHandler_caching.errors) == 0 assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success assert len(customHandler_caching.states) == 4 # pre, post, success, success
@pytest.mark.asyncio
async def test_rate_limit_error_callback():
"""
Assert a callback is hit, if a model group starts hitting rate limit errors
Relevant issue: https://github.com/BerriAI/litellm/issues/4096
"""
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.success_callback = []
router = Router(
model_list=[
{
"model_name": "my-test-gpt",
"litellm_params": {
"model": "gpt-3.5-turbo",
"mock_response": "litellm.RateLimitError",
},
}
],
allowed_fails=2,
num_retries=0,
)
litellm_logging_obj = LiteLLMLogging(
model="my-test-gpt",
messages=[{"role": "user", "content": "hi"}],
stream=False,
call_type="acompletion",
litellm_call_id="1234",
start_time=datetime.now(),
function_id="1234",
)
try:
_ = await router.acompletion(
model="my-test-gpt",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
except Exception:
pass
with patch.object(
customHandler, "log_model_group_rate_limit_error", new=MagicMock()
) as mock_client:
print(
f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
)
for _ in range(3):
try:
_ = await router.acompletion(
model="my-test-gpt",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
litellm_logging_obj=litellm_logging_obj,
)
except litellm.RateLimitError:
pass
await asyncio.sleep(3)
mock_client.assert_called_once()
assert "original_model_group" in mock_client.call_args.kwargs
assert mock_client.call_args.kwargs["original_model_group"] == "my-test-gpt"


@@ -5,11 +5,12 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
 import datetime
 import enum
 import uuid
-from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
 
 import httpx
 from pydantic import BaseModel, ConfigDict, Field
 
+from ..exceptions import RateLimitError
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
 from .utils import ModelResponse
@@ -567,7 +568,7 @@ class RouterRateLimitErrorBasic(ValueError):
         super().__init__(_message)
 
 
-class RouterRateLimitError(ValueError):
+class RouterRateLimitError(RateLimitError):
     def __init__(
         self,
         model: str,
@@ -580,4 +581,4 @@ class RouterRateLimitError(ValueError):
         self.enable_pre_call_checks = enable_pre_call_checks
         self.cooldown_list = cooldown_list
         _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
-        super().__init__(_message)
+        super().__init__(_message, llm_provider="", model=model)

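One practical consequence of re-parenting `RouterRateLimitError` onto `RateLimitError` (rather than a bare `ValueError`) is that callers can handle the router's "no deployments available" error and ordinary provider 429s with a single `except litellm.RateLimitError` clause, which is what the new test above relies on. A minimal sketch along the same lines as that test (the `mock_response` and `allowed_fails` values mirror it; the script itself is not part of this commit):

```python
import asyncio

import litellm
from litellm import Router


async def main():
    router = Router(
        model_list=[
            {
                "model_name": "my-test-gpt",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    # every call is mocked as a provider rate limit error
                    "mock_response": "litellm.RateLimitError",
                },
            }
        ],
        allowed_fails=2,
        num_retries=0,
    )
    for _ in range(3):
        try:
            await router.acompletion(
                model="my-test-gpt",
                messages=[{"role": "user", "content": "hi"}],
            )
        except litellm.RateLimitError as err:
            # Early calls raise the mocked provider 429; once the deployment is
            # cooling down, the router raises RouterRateLimitError instead, and
            # this same except clause now catches it too.
            print(f"rate limited: {err}")


asyncio.run(main())
```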

@@ -394,6 +394,7 @@ def function_setup(
         print_verbose(
             f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
         )
+
         if (
             len(litellm.input_callback) > 0
             or len(litellm.success_callback) > 0