LiteLLM minor fixes + improvements (31/08/2024) (#5464)

* fix(vertex_endpoints.py): fix vertex ai pass through endpoints

* test(test_streaming.py): skip model due to end of life

* feat(custom_logger.py): add special callback for model hitting tpm/rpm limits

Closes https://github.com/BerriAI/litellm/issues/4096
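
The new callback is exercised in the test added below by patching a log_model_group_rate_limit_error method on the handler and asserting it receives an original_model_group keyword. As a rough sketch of how a handler could consume the hook (the hook name and the original_model_group keyword are taken from that test; the exact signature is an assumption, not documented API):

# Sketch only: hook name and the original_model_group keyword come from the
# test added in this commit; the *args/**kwargs signature is an assumption.
import litellm
from litellm.integrations.custom_logger import CustomLogger


class RateLimitAlertLogger(CustomLogger):
    def log_model_group_rate_limit_error(self, *args, **kwargs):
        # Fired when a model group starts hitting tpm/rpm (rate limit) errors.
        original_model_group = kwargs.get("original_model_group")
        print(f"rate limit errors observed for model group: {original_model_group}")


# Register the handler the same way the test does.
litellm.callbacks = [RateLimitAlertLogger()]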
Author: Krish Dholakia · 2024-09-01 13:31:42 -07:00 · committed by GitHub
parent 7778fa0146
commit e0d81434ed
8 changed files with 174 additions and 13 deletions


@@ -1,13 +1,21 @@
 ### What this tests ####
 ## This test asserts the type of data passed into each method of the custom callback handler
-import sys, os, time, inspect, asyncio, traceback
+import asyncio
+import inspect
+import os
+import sys
+import time
+import traceback
 from datetime import datetime
+
 import pytest
 
 sys.path.insert(0, os.path.abspath("../.."))
-from typing import Optional, Literal, List
-from litellm import Router, Cache
+from typing import List, Literal, Optional
+from unittest.mock import AsyncMock, MagicMock, patch
+
 import litellm
+from litellm import Cache, Router
 from litellm.integrations.custom_logger import CustomLogger
 
 # Test Scenarios (test across completion, streaming, embedding)
@@ -602,14 +610,18 @@ async def test_async_completion_azure_caching():
     router = Router(model_list=model_list)  # type: ignore
     response1 = await router.acompletion(
         model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)
     print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
     response2 = await router.acompletion(
         model="gpt-3.5-turbo",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
@@ -618,3 +630,73 @@ async def test_async_completion_azure_caching():
     )
     assert len(customHandler_caching.errors) == 0
     assert len(customHandler_caching.states) == 4  # pre, post, success, success
+
+
+@pytest.mark.asyncio
+async def test_rate_limit_error_callback():
+    """
+    Assert a callback is hit, if a model group starts hitting rate limit errors
+
+    Relevant issue: https://github.com/BerriAI/litellm/issues/4096
+    """
+    from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
+
+    customHandler = CompletionCustomHandler()
+    litellm.callbacks = [customHandler]
+    litellm.success_callback = []
+
+    router = Router(
+        model_list=[
+            {
+                "model_name": "my-test-gpt",
+                "litellm_params": {
+                    "model": "gpt-3.5-turbo",
+                    "mock_response": "litellm.RateLimitError",
+                },
+            }
+        ],
+        allowed_fails=2,
+        num_retries=0,
+    )
+
+    litellm_logging_obj = LiteLLMLogging(
+        model="my-test-gpt",
+        messages=[{"role": "user", "content": "hi"}],
+        stream=False,
+        call_type="acompletion",
+        litellm_call_id="1234",
+        start_time=datetime.now(),
+        function_id="1234",
+    )
+
+    try:
+        _ = await router.acompletion(
+            model="my-test-gpt",
+            messages=[{"role": "user", "content": "Hey, how's it going?"}],
+        )
+    except Exception:
+        pass
+
+    with patch.object(
+        customHandler, "log_model_group_rate_limit_error", new=MagicMock()
+    ) as mock_client:
+        print(
+            f"customHandler.log_model_group_rate_limit_error: {customHandler.log_model_group_rate_limit_error}"
+        )
+
+        for _ in range(3):
+            try:
+                _ = await router.acompletion(
+                    model="my-test-gpt",
+                    messages=[{"role": "user", "content": "Hey, how's it going?"}],
+                    litellm_logging_obj=litellm_logging_obj,
+                )
+            except litellm.RateLimitError:
+                pass
+
+        await asyncio.sleep(3)
+
+        mock_client.assert_called_once()
+
+        assert "original_model_group" in mock_client.call_args.kwargs
+        assert mock_client.call_args.kwargs["original_model_group"] == "my-test-gpt"