LiteLLM minor fixes + improvements (31/08/2024) (#5464)

* fix(vertex_endpoints.py): fix vertex ai pass through endpoints

* test(test_streaming.py): skip model due to end of life

* feat(custom_logger.py): add special callback for model hitting tpm/rpm limits

Closes https://github.com/BerriAI/litellm/issues/4096
Krish Dholakia 2024-09-01 13:31:42 -07:00 committed by GitHub
parent 7778fa0146
commit e0d81434ed
8 changed files with 174 additions and 13 deletions


@@ -59,6 +59,11 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
        pass

    #### Fallback Events - router/proxy only ####

    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        pass

    async def log_success_fallback_event(self, original_model_group: str, kwargs: dict):
        pass
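
A minimal sketch of wiring up the new hook from the consumer side; the RateLimitAlerter class and its print-based alert are illustrative only, while the callback signature and the litellm.callbacks registration are taken from this commit:

from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger


class RateLimitAlerter(CustomLogger):  # hypothetical subclass, not part of this commit
    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        # fired when the router has no deployments left for a model group (tpm/rpm cooldown);
        # swap the print for alerting/metrics as needed
        print(f"model group {original_model_group} is rate limited: {exception}")


litellm.callbacks = [RateLimitAlerter()]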


@@ -1552,6 +1552,32 @@ class Logging:
            metadata.update(exception.headers)

        return start_time, end_time

    async def special_failure_handlers(self, exception: Exception):
        """
        Custom events, emitted for specific failures.

        Currently just for router model group rate limit error
        """
        from litellm.types.router import RouterErrors

        ## check if special error ##
        if RouterErrors.no_deployments_available.value not in str(exception):
            return

        ## get original model group ##
        litellm_params: dict = self.model_call_details.get("litellm_params") or {}
        metadata = litellm_params.get("metadata") or {}
        model_group = metadata.get("model_group") or None

        for callback in litellm._async_failure_callback:
            if isinstance(callback, CustomLogger):  # custom logger class
                await callback.log_model_group_rate_limit_error(
                    exception=exception,
                    original_model_group=model_group,
                    kwargs=self.model_call_details,
                )  # type: ignore

    def failure_handler(
        self, exception, traceback_exception, start_time=None, end_time=None
    ):
@@ -1799,6 +1825,7 @@ class Logging:
        """
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
        await self.special_failure_handlers(exception=exception)
        start_time, end_time = self._failure_handler_helper_fn(
            exception=exception,
            traceback_exception=traceback_exception,


@@ -73,6 +73,45 @@ def exception_handler(e: Exception):
        )


def construct_target_url(
    base_url: str,
    requested_route: str,
    default_vertex_location: Optional[str],
    default_vertex_project: Optional[str],
) -> httpx.URL:
    """
    Allow user to specify their own project id / location.

    If missing, use defaults

    Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460

    Constructed Url:
    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
    """
    new_base_url = httpx.URL(base_url)
    if "locations" in requested_route:  # contains the target project id + location
        updated_url = new_base_url.copy_with(path=requested_route)
        return updated_url

    """
    - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
    - Add default project id
    - Add default location
    """
    vertex_version: Literal["v1", "v1beta1"] = "v1"
    if "cachedContent" in requested_route:
        vertex_version = "v1beta1"

    base_requested_route = "{}/projects/{}/locations/{}".format(
        vertex_version, default_vertex_project, default_vertex_location
    )

    updated_requested_route = "/" + base_requested_route + requested_route

    updated_url = new_base_url.copy_with(path=updated_requested_route)
    return updated_url


@router.api_route(
    "/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
)
@@ -86,8 +125,6 @@ async def vertex_proxy_route(
    import re

    from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance

    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}
    # Use headers from the incoming request if default_vertex_config is not set
@@ -133,8 +170,14 @@ async def vertex_proxy_route(
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
    base_url = httpx.URL(base_target_url)
    updated_url = base_url.copy_with(path=encoded_endpoint)
    updated_url = construct_target_url(
        base_url=base_target_url,
        requested_route=encoded_endpoint,
        default_vertex_location=vertex_location,
        default_vertex_project=vertex_project,
    )
    # base_url = httpx.URL(base_target_url)
    # updated_url = base_url.copy_with(path=encoded_endpoint)

    verbose_proxy_logger.debug("updated url %s", updated_url)


@@ -4792,10 +4792,12 @@ class Router:
            return deployment
        except Exception as e:
            traceback_exception = traceback.format_exc()
            # if router rejects call -> log to langfuse/otel/etc.
            if request_kwargs is not None:
                logging_obj = request_kwargs.get("litellm_logging_obj", None)

                if logging_obj is not None:
                    ## LOGGING
                    threading.Thread(


@@ -1,13 +1,21 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
import asyncio
import inspect
import os
import sys
import time
import traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List
from litellm import Router, Cache
from typing import List, Literal, Optional
from unittest.mock import AsyncMock, MagicMock, patch
import litellm
from litellm import Cache, Router
from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding)
@@ -602,14 +610,18 @@ async def test_async_completion_azure_caching():
    router = Router(model_list=model_list)  # type: ignore
    response1 = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
        messages=[
            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
        ],
        caching=True,
    )
    await asyncio.sleep(1)
    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
    response2 = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
        messages=[
            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
        ],
        caching=True,
    )
    await asyncio.sleep(1)  # success callbacks are done in parallel
@@ -618,3 +630,73 @@ async def test_async_completion_azure_caching():
    )
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success


@pytest.mark.asyncio
async def test_rate_limit_error_callback():
    """
    Assert a callback is hit, if a model group starts hitting rate limit errors

    Relevant issue: https://github.com/BerriAI/litellm/issues/4096
    """
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging

    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]
    litellm.success_callback = []

    router = Router(
        model_list=[
            {
                "model_name": "my-test-gpt",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "litellm.RateLimitError",
                },
            }
        ],
        allowed_fails=2,
        num_retries=0,
    )

    litellm_logging_obj = LiteLLMLogging(
        model="my-test-gpt",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id="1234",
        start_time=datetime.now(),
        function_id="1234",
    )

    try:
        _ = await router.acompletion(
            model="my-test-gpt",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
    except Exception:
        pass

    with patch.object(
        customHandler, "log_model_group_rate_limit_error", new=MagicMock()
    ) as mock_client:
        print(
            f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
        )

        for _ in range(3):
            try:
                _ = await router.acompletion(
                    model="my-test-gpt",
                    messages=[{"role": "user", "content": "Hey, how's it going?"}],
                    litellm_logging_obj=litellm_logging_obj,
                )
            except litellm.RateLimitError:
                pass

        await asyncio.sleep(3)
        mock_client.assert_called_once()

        assert "original_model_group" in mock_client.call_args.kwargs
        assert mock_client.call_args.kwargs["original_model_group"] == "my-test-gpt"


@@ -5,11 +5,12 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
import datetime
import enum
import uuid
from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union
import httpx
from pydantic import BaseModel, ConfigDict, Field
from ..exceptions import RateLimitError
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
from .utils import ModelResponse
@@ -567,7 +568,7 @@ class RouterRateLimitErrorBasic(ValueError):
        super().__init__(_message)


class RouterRateLimitError(ValueError):
class RouterRateLimitError(RateLimitError):
    def __init__(
        self,
        model: str,
@@ -580,4 +581,4 @@ class RouterRateLimitError(ValueError):
        self.enable_pre_call_checks = enable_pre_call_checks
        self.cooldown_list = cooldown_list
        _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
        super().__init__(_message)
        super().__init__(_message, llm_provider="", model=model)
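
Because RouterRateLimitError now derives from litellm's RateLimitError, router cooldown errors can be handled with the same except litellm.RateLimitError clause used for provider rate limits, which is what the new test above relies on. A small sketch; the keyword arguments are inferred from the attributes shown in this hunk and the model name is a placeholder:

import litellm
from litellm.types.router import RouterRateLimitError

err = RouterRateLimitError(
    model="my-test-gpt",  # placeholder model group name
    cooldown_time=5.0,
    enable_pre_call_checks=False,
    cooldown_list=[],
)

# previously only a ValueError; now also caught by RateLimitError handlers
assert isinstance(err, litellm.RateLimitError)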


@@ -394,6 +394,7 @@ def function_setup(
        print_verbose(
            f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
        )

        if (
            len(litellm.input_callback) > 0
            or len(litellm.success_callback) > 0