LiteLLM minor fixes + improvements (31/08/2024) (#5464)
* fix(vertex_endpoints.py): fix vertex ai pass through endpoints
* test(test_streaming.py): skip model due to end of life
* feat(custom_logger.py): add special callback for model hitting tpm/rpm limits

Closes https://github.com/BerriAI/litellm/issues/4096
parent 7778fa0146
commit e0d81434ed
8 changed files with 174 additions and 13 deletions
@@ -59,6 +59,11 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
        pass

    #### Fallback Events - router/proxy only ####
    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        pass

    async def log_success_fallback_event(self, original_model_group: str, kwargs: dict):
        pass
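
The new hook can be picked up by any registered custom logger. A minimal sketch (the subclass name and the alerting logic are hypothetical, not part of this commit) of overriding it and registering the handler:

from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger


class RateLimitAlerter(CustomLogger):
    # Hypothetical subclass: fires whenever a model group starts hitting tpm/rpm limits.
    async def log_model_group_rate_limit_error(
        self, exception: Exception, original_model_group: Optional[str], kwargs: dict
    ):
        print(f"model group {original_model_group} is rate limited: {exception}")


# Register it like any other callback; the router/proxy invokes the hook on
# "no deployments available" (tpm/rpm cooldown) failures.
litellm.callbacks = [RateLimitAlerter()]
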
@@ -1552,6 +1552,32 @@ class Logging:
            metadata.update(exception.headers)
        return start_time, end_time

    async def special_failure_handlers(self, exception: Exception):
        """
        Custom events, emitted for specific failures.

        Currently just for router model group rate limit error
        """
        from litellm.types.router import RouterErrors

        ## check if special error ##
        if RouterErrors.no_deployments_available.value not in str(exception):
            return

        ## get original model group ##

        litellm_params: dict = self.model_call_details.get("litellm_params") or {}
        metadata = litellm_params.get("metadata") or {}

        model_group = metadata.get("model_group") or None
        for callback in litellm._async_failure_callback:
            if isinstance(callback, CustomLogger):  # custom logger class
                await callback.log_model_group_rate_limit_error(
                    exception=exception,
                    original_model_group=model_group,
                    kwargs=self.model_call_details,
                )  # type: ignore

    def failure_handler(
        self, exception, traceback_exception, start_time=None, end_time=None
    ):
@@ -1799,6 +1825,7 @@ class Logging:
        """
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
        await self.special_failure_handlers(exception=exception)
        start_time, end_time = self._failure_handler_helper_fn(
            exception=exception,
            traceback_exception=traceback_exception,
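
The gate above is a plain substring check against the router's "no deployments available" message, and the model group is recovered from the call's litellm_params metadata. A minimal standalone sketch of that lookup, with assumed exception text and call details:

from litellm.types.router import RouterErrors

# Assumed stand-ins for the exception message and model_call_details the logger sees.
exception_text = (
    f"{RouterErrors.no_deployments_available.value}, Try again in 5 seconds. "
    "Passed model=my-test-gpt."
)
model_call_details = {"litellm_params": {"metadata": {"model_group": "my-test-gpt"}}}

if RouterErrors.no_deployments_available.value in exception_text:
    litellm_params = model_call_details.get("litellm_params") or {}
    metadata = litellm_params.get("metadata") or {}
    print("rate limited model group:", metadata.get("model_group"))  # -> my-test-gpt
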
@@ -73,6 +73,45 @@ def exception_handler(e: Exception):
    )


def construct_target_url(
    base_url: str,
    requested_route: str,
    default_vertex_location: Optional[str],
    default_vertex_project: Optional[str],
) -> httpx.URL:
    """
    Allow user to specify their own project id / location.

    If missing, use defaults

    Handle cachedContent scenario - https://github.com/BerriAI/litellm/issues/5460

    Constructed Url:
    POST https://LOCATION-aiplatform.googleapis.com/{version}/projects/PROJECT_ID/locations/LOCATION/cachedContents
    """
    new_base_url = httpx.URL(base_url)
    if "locations" in requested_route:  # contains the target project id + location
        updated_url = new_base_url.copy_with(path=requested_route)
        return updated_url
    """
    - Add endpoint version (e.g. v1beta for cachedContent, v1 for rest)
    - Add default project id
    - Add default location
    """
    vertex_version: Literal["v1", "v1beta1"] = "v1"
    if "cachedContent" in requested_route:
        vertex_version = "v1beta1"

    base_requested_route = "{}/projects/{}/locations/{}".format(
        vertex_version, default_vertex_project, default_vertex_location
    )

    updated_requested_route = "/" + base_requested_route + requested_route

    updated_url = new_base_url.copy_with(path=updated_requested_route)
    return updated_url
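
For reference, a minimal sketch of the two URL shapes the helper is meant to produce, built directly with httpx rather than by importing the proxy module (project id, location, and model path are assumed values):

import httpx

base = httpx.URL("https://us-central1-aiplatform.googleapis.com/")

# Case 1: the requested route already names a project + location -> passed through as-is.
explicit = base.copy_with(
    path="/v1/projects/my-project/locations/us-central1/publishers/google/models/gemini-pro:generateContent"
)

# Case 2: no project/location in the route -> version + defaults are prepended;
# cachedContents requests are pinned to the v1beta1 API surface.
cached = base.copy_with(
    path="/v1beta1/projects/my-project/locations/us-central1/cachedContents"
)

print(explicit)
print(cached)
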

@router.api_route(
    "/vertex-ai/{endpoint:path}", methods=["GET", "POST", "PUT", "DELETE"]
)
@@ -86,8 +125,6 @@ async def vertex_proxy_route(
    import re

    from litellm.fine_tuning.main import vertex_fine_tuning_apis_instance

    verbose_proxy_logger.debug("requested endpoint %s", endpoint)
    headers: dict = {}
    # Use headers from the incoming request if default_vertex_config is not set
@@ -133,8 +170,14 @@ async def vertex_proxy_route(
        encoded_endpoint = "/" + encoded_endpoint

    # Construct the full target URL using httpx
-    base_url = httpx.URL(base_target_url)
-    updated_url = base_url.copy_with(path=encoded_endpoint)
+    updated_url = construct_target_url(
+        base_url=base_target_url,
+        requested_route=encoded_endpoint,
+        default_vertex_location=vertex_location,
+        default_vertex_project=vertex_project,
+    )
+    # base_url = httpx.URL(base_target_url)
+    # updated_url = base_url.copy_with(path=encoded_endpoint)

    verbose_proxy_logger.debug("updated url %s", updated_url)
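
With the route registered as /vertex-ai/{endpoint:path}, a request to the proxy is forwarded to the URL constructed above. A sketch of exercising the pass-through (proxy address, virtual key, and request body are assumed values, not taken from this commit):

import httpx

# Assumed local proxy address and virtual key.
PROXY_BASE = "http://0.0.0.0:4000"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}

# The path after /vertex-ai/ is the Vertex route; project/location may be omitted
# and are filled in from the proxy's defaults by construct_target_url.
resp = httpx.post(
    f"{PROXY_BASE}/vertex-ai/publishers/google/models/gemini-pro:generateContent",
    headers=headers,
    json={"contents": [{"role": "user", "parts": [{"text": "hi"}]}]},
)
print(resp.status_code, resp.text[:200])
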
@@ -4792,10 +4792,12 @@ class Router:
            return deployment
        except Exception as e:

            traceback_exception = traceback.format_exc()
            # if router rejects call -> log to langfuse/otel/etc.
            if request_kwargs is not None:
                logging_obj = request_kwargs.get("litellm_logging_obj", None)

                if logging_obj is not None:
                    ## LOGGING
                    threading.Thread(
@@ -1,13 +1,21 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
-import sys, os, time, inspect, asyncio, traceback
+import asyncio
+import inspect
+import os
+import sys
+import time
+import traceback
from datetime import datetime

import pytest

sys.path.insert(0, os.path.abspath("../.."))
-from typing import Optional, Literal, List
-from litellm import Router, Cache
+from typing import List, Literal, Optional
+from unittest.mock import AsyncMock, MagicMock, patch

+import litellm
+from litellm import Cache, Router
+from litellm.integrations.custom_logger import CustomLogger

# Test Scenarios (test across completion, streaming, embedding)
@@ -602,14 +610,18 @@ async def test_async_completion_azure_caching():
    router = Router(model_list=model_list)  # type: ignore
    response1 = await router.acompletion(
        model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
        caching=True,
    )
    await asyncio.sleep(1)
    print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
    response2 = await router.acompletion(
        model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
        caching=True,
    )
    await asyncio.sleep(1)  # success callbacks are done in parallel
@@ -618,3 +630,73 @@ async def test_async_completion_azure_caching():
    )
    assert len(customHandler_caching.errors) == 0
    assert len(customHandler_caching.states) == 4  # pre, post, success, success


@pytest.mark.asyncio
async def test_rate_limit_error_callback():
    """
    Assert a callback is hit, if a model group starts hitting rate limit errors

    Relevant issue: https://github.com/BerriAI/litellm/issues/4096
    """
    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging

    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]
    litellm.success_callback = []

    router = Router(
        model_list=[
            {
                "model_name": "my-test-gpt",
                "litellm_params": {
                    "model": "gpt-3.5-turbo",
                    "mock_response": "litellm.RateLimitError",
                },
            }
        ],
        allowed_fails=2,
        num_retries=0,
    )

    litellm_logging_obj = LiteLLMLogging(
        model="my-test-gpt",
        messages=[{"role": "user", "content": "hi"}],
        stream=False,
        call_type="acompletion",
        litellm_call_id="1234",
        start_time=datetime.now(),
        function_id="1234",
    )

    try:
        _ = await router.acompletion(
            model="my-test-gpt",
            messages=[{"role": "user", "content": "Hey, how's it going?"}],
        )
    except Exception:
        pass

    with patch.object(
        customHandler, "log_model_group_rate_limit_error", new=MagicMock()
    ) as mock_client:

        print(
            f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
        )

        for _ in range(3):
            try:
                _ = await router.acompletion(
                    model="my-test-gpt",
                    messages=[{"role": "user", "content": "Hey, how's it going?"}],
                    litellm_logging_obj=litellm_logging_obj,
                )
            except litellm.RateLimitError:
                pass

        await asyncio.sleep(3)
        mock_client.assert_called_once()

        assert "original_model_group" in mock_client.call_args.kwargs
        assert mock_client.call_args.kwargs["original_model_group"] == "my-test-gpt"
@@ -5,11 +5,12 @@ litellm.Router Types - includes RouterConfig, UpdateRouterConfig, ModelInfo etc
import datetime
import enum
import uuid
-from typing import Dict, List, Literal, Optional, Tuple, TypedDict, Union
+from typing import Any, Dict, List, Literal, Optional, Tuple, TypedDict, Union

import httpx
from pydantic import BaseModel, ConfigDict, Field

+from ..exceptions import RateLimitError
from .completion import CompletionRequest
from .embedding import EmbeddingRequest
from .utils import ModelResponse
@@ -567,7 +568,7 @@ class RouterRateLimitErrorBasic(ValueError):
        super().__init__(_message)


-class RouterRateLimitError(ValueError):
+class RouterRateLimitError(RateLimitError):
    def __init__(
        self,
        model: str,
@@ -580,4 +581,4 @@ class RouterRateLimitError(ValueError):
        self.enable_pre_call_checks = enable_pre_call_checks
        self.cooldown_list = cooldown_list
        _message = f"{RouterErrors.no_deployments_available.value}, Try again in {cooldown_time} seconds. Passed model={model}. pre-call-checks={enable_pre_call_checks}, cooldown_list={cooldown_list}"
-        super().__init__(_message)
+        super().__init__(_message, llm_provider="", model=model)
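
Because RouterRateLimitError now derives from litellm's RateLimitError, callers can handle both with a single except clause, as the new test above does. A minimal sketch (the constructor keyword arguments are assumed from the attributes shown in this diff):

import litellm
from litellm.types.router import RouterRateLimitError

try:
    # Assumed keyword arguments, mirroring the attributes set in __init__ above.
    raise RouterRateLimitError(
        model="my-test-gpt",
        cooldown_time=5.0,
        enable_pre_call_checks=False,
        cooldown_list=[],
    )
except litellm.RateLimitError as err:
    # RouterRateLimitError is caught here now that it subclasses RateLimitError.
    print(type(err).__name__, "->", err)
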
@@ -394,6 +394,7 @@ def function_setup(
        print_verbose(
            f"Initialized litellm callbacks, Async Success Callbacks: {litellm._async_success_callback}"
        )

    if (
        len(litellm.input_callback) > 0
        or len(litellm.success_callback) > 0