litellm/tests/local_testing/test_custom_callback_input.py
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import asyncio
import inspect
import json
import os
import sys
import time
import traceback
import uuid
from datetime import datetime
import pytest
from pydantic import BaseModel
sys.path.insert(0, os.path.abspath("../.."))
from typing import List, Literal, Optional, Union
from unittest.mock import AsyncMock, MagicMock, patch
import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## 5. Caching
# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock
# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router + proxy tests
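# A minimal sketch (illustrative only, not a collected test) of the wiring pattern the tests
# below rely on: subclass CustomLogger, register it on litellm.callbacks, fire a call, then
# inspect what the handler saw. `mock_response` is used here so no real API key is needed.
def _example_minimal_custom_logger_usage():
    class _PrintLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            print(f"success callback fired for model={kwargs['model']}")

    litellm.callbacks = [_PrintLogger()]
    litellm.completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        mock_response="Hi!",
    )
    litellm.callbacks = []  # reset global state, mirroring the tests below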
class CompletionCustomHandler(
CustomLogger
): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
    The set of expected inputs to a custom handler for a LiteLLM call (completion, streaming, embedding, etc.).
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: List[
Literal[
"sync_pre_api_call",
"async_pre_api_call",
"post_api_call",
"sync_stream",
"async_stream",
"sync_success",
"async_success",
"sync_failure",
"async_failure",
]
] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
### METADATA
metadata_value = kwargs["litellm_params"].get("metadata")
assert metadata_value is None or isinstance(metadata_value, dict)
if metadata_value is not None:
if litellm.turn_off_message_logging is True:
assert (
metadata_value["raw_request"]
is LiteLLMCommonStrings.redacted_by_litellm.value
)
else:
assert "raw_request" not in metadata_value or isinstance(
metadata_value["raw_request"], str
)
except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
            assert end_time is None
            ## RESPONSE OBJECT
            assert response_obj is None
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"],
(str, litellm.CustomStreamWrapper, BaseModel),
)
or inspect.iscoroutine(kwargs["original_response"])
or inspect.isasyncgen(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_stream")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
print(f"\n\nkwargs={kwargs}\n\n")
print(
json.dumps(kwargs, default=str)
) # this is a test to confirm no circular references are in the logging object
self.states.append("sync_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(
response_obj,
(
litellm.ModelResponse,
litellm.EmbeddingResponse,
litellm.ImageResponse,
),
)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["litellm_params"]["api_base"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and (
isinstance(kwargs["input"][0], dict)
or isinstance(kwargs["input"][0], str)
)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance(
kwargs["original_response"],
(str, litellm.CustomStreamWrapper, BaseModel),
), "Original Response={}. Allowed types=[str, litellm.CustomStreamWrapper, BaseModel]".format(
kwargs["original_response"]
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert isinstance(kwargs["response_cost"], (float, type(None)))
except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
print(f"kwargs: {kwargs}")
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
            assert response_obj is None
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"], Optional[dict])
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert (
isinstance(kwargs["input"], list)
and isinstance(kwargs["input"][0], dict)
) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("async_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list) and isinstance(messages[0], dict)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list) and isinstance(
kwargs["messages"][0], dict
)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
        except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
print(
"in async_log_success_event", kwargs, response_obj, start_time, end_time
)
self.states.append("async_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(
response_obj,
(
litellm.ModelResponse,
litellm.EmbeddingResponse,
litellm.TextCompletionResponse,
),
)
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["litellm_params"]["api_base"], str)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["completion_start_time"], datetime)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, dict, str))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
assert isinstance(kwargs["response_cost"], (float, type(None)))
except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
            assert response_obj is None
## KWARGS
assert isinstance(kwargs["model"], str)
assert isinstance(kwargs["messages"], list)
assert isinstance(kwargs["optional_params"], dict)
assert isinstance(kwargs["litellm_params"], dict)
assert isinstance(kwargs["start_time"], (datetime, type(None)))
assert isinstance(kwargs["stream"], bool)
assert isinstance(kwargs["user"], (str, type(None)))
assert isinstance(kwargs["input"], (list, str, dict))
assert isinstance(kwargs["api_key"], (str, type(None)))
assert (
isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper)
)
or inspect.isasyncgen(kwargs["original_response"])
or inspect.iscoroutine(kwargs["original_response"])
or kwargs["original_response"] == None
)
assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str)
except Exception:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
# COMPLETION
## Test OpenAI + sync
def test_chat_openai_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync openai"}],
)
## test streaming
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
for chunk in response:
continue
## test failure callback
try:
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
api_key="my-bad-key",
stream=True,
)
for chunk in response:
continue
except Exception:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_openai_stream()
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
## test streaming
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
stream=True,
)
async for chunk in response:
continue
## test failure callback
try:
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
api_key="my-bad-key",
stream=True,
)
async for chunk in response:
continue
except Exception:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_openai_stream())
## Test Azure + sync
def test_chat_azure_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
)
# test streaming
response = litellm.completion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
stream=True,
)
for chunk in response:
continue
# test failure callback
try:
response = litellm.completion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
api_key="my-bad-key",
stream=True,
)
for chunk in response:
continue
except Exception:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_azure_stream()
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
)
## test streaming
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
stream=True,
)
async for chunk in response:
continue
# test failure callback
try:
response = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
api_key="my-bad-key",
stream=True,
)
async for chunk in response:
continue
except Exception:
pass
await asyncio.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_azure_stream())
@pytest.mark.asyncio
async def test_async_chat_openai_stream_options():
try:
litellm.set_verbose = True
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
with patch.object(
customHandler, "async_log_success_event", new=AsyncMock()
) as mock_client:
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hi 👋 - i'm async openai"}],
stream=True,
stream_options={"include_usage": True},
)
async for chunk in response:
continue
print("mock client args list=", mock_client.await_args_list)
mock_client.assert_awaited_once()
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
## Test Bedrock + sync
def test_chat_bedrock_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = litellm.completion(
model="bedrock/anthropic.claude-v2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
)
# test streaming
response = litellm.completion(
model="bedrock/anthropic.claude-v2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
stream=True,
)
for chunk in response:
continue
# test failure callback
try:
response = litellm.completion(
model="bedrock/anthropic.claude-v2",
messages=[{"role": "user", "content": "Hi 👋 - i'm sync bedrock"}],
aws_region_name="my-bad-region",
stream=True,
)
for chunk in response:
continue
except Exception:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# test_chat_bedrock_stream()
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_chat_bedrock_stream():
try:
litellm.set_verbose = True
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(
model="bedrock/anthropic.claude-v2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
)
# test streaming
response = await litellm.acompletion(
model="bedrock/anthropic.claude-v2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
stream=True,
)
print(f"response: {response}")
async for chunk in response:
print(f"chunk: {chunk}")
continue
## test failure callback
try:
response = await litellm.acompletion(
model="bedrock/anthropic.claude-v2",
messages=[{"role": "user", "content": "Hi 👋 - i'm async bedrock"}],
aws_region_name="my-bad-key",
stream=True,
)
async for chunk in response:
continue
except Exception:
pass
await asyncio.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_bedrock_stream())
## Test Sagemaker + Async
@pytest.mark.skip(reason="AWS Suspended Account")
@pytest.mark.asyncio
async def test_async_chat_sagemaker_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
)
# test streaming
response = await litellm.acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
stream=True,
)
print(f"response: {response}")
async for chunk in response:
print(f"chunk: {chunk}")
continue
## test failure callback
try:
response = await litellm.acompletion(
model="sagemaker/berri-benchmarking-Llama-2-70b-chat-hf-4",
messages=[{"role": "user", "content": "Hi 👋 - i'm async sagemaker"}],
aws_region_name="my-bad-key",
stream=True,
)
async for chunk in response:
continue
except Exception:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
## Test Vertex AI + Async
import json
import tempfile
def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + "/vertex_key.json"
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
# If the file is empty or not valid JSON, create an empty dictionary
if not content or not content.strip():
service_account_key_data = {}
else:
# Attempt to load the existing JSON content
file.seek(0)
service_account_key_data = json.load(file)
except FileNotFoundError:
# If the file doesn't exist, create an empty dictionary
service_account_key_data = {}
# Update the service_account_key_data with environment variables
private_key_id = os.environ.get("VERTEX_AI_PRIVATE_KEY_ID", "")
private_key = os.environ.get("VERTEX_AI_PRIVATE_KEY", "")
private_key = private_key.replace("\\n", "\n")
service_account_key_data["private_key_id"] = private_key_id
service_account_key_data["private_key"] = private_key
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary file
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
@pytest.mark.skip(reason="Vertex AI Hanging")
@pytest.mark.asyncio
async def test_async_chat_vertex_ai_stream():
try:
load_vertex_ai_credentials()
customHandler = CompletionCustomHandler()
litellm.set_verbose = True
litellm.callbacks = [customHandler]
# test streaming
response = await litellm.acompletion(
model="gemini-pro",
messages=[
{
"role": "user",
"content": f"Hi 👋 - i'm async vertex_ai {uuid.uuid4()}",
}
],
stream=True,
)
print(f"response: {response}")
async for chunk in response:
print(f"chunk: {chunk}")
continue
await asyncio.sleep(10)
print(f"customHandler.states: {customHandler.states}")
        assert (
            customHandler.states.count("async_success") == 1
        )  # the streaming success event should fire exactly once
assert len(customHandler.states) >= 3 # pre, post, success
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# Text Completion
@pytest.mark.asyncio
async def test_async_text_completion_bedrock():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.atext_completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
prompt=["Hi 👋 - i'm async text completion bedrock"],
)
# test streaming
response = await litellm.atext_completion(
model="bedrock/anthropic.claude-3-haiku-20240307-v1:0",
prompt=["Hi 👋 - i'm async text completion bedrock"],
stream=True,
)
async for chunk in response:
print(f"chunk: {chunk}")
continue
## test failure callback
try:
response = await litellm.atext_completion(
model="bedrock/",
prompt=["Hi 👋 - i'm async text completion bedrock"],
stream=True,
api_key="my-bad-key",
)
async for chunk in response:
continue
except Exception:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
## Test OpenAI text completion + Async
@pytest.mark.asyncio
async def test_async_text_completion_openai_stream():
try:
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
response = await litellm.atext_completion(
model="gpt-3.5-turbo",
prompt="Hi 👋 - i'm async text completion openai",
)
# test streaming
response = await litellm.atext_completion(
model="gpt-3.5-turbo",
prompt="Hi 👋 - i'm async text completion openai",
stream=True,
)
async for chunk in response:
print(f"chunk: {chunk}")
continue
## test failure callback
try:
response = await litellm.atext_completion(
model="gpt-3.5-turbo",
prompt="Hi 👋 - i'm async text completion openai",
stream=True,
api_key="my-bad-key",
)
async for chunk in response:
continue
except Exception:
pass
time.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_embedding_openai():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(
model="azure/azure-embedding-model", input=["good morning from litellm"]
)
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(
model="text-embedding-ada-002",
input=["good morning from litellm"],
api_key="my-bad-key",
)
except Exception:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_openai())
## Test Azure + Sync
def test_amazing_sync_embedding():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = litellm.embedding(
model="azure/azure-embedding-model", input=["good morning from litellm"]
)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
time.sleep(2)
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = litellm.embedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
api_key="my-bad-key",
)
except Exception:
pass
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
time.sleep(2)
assert len(customHandler_failure.errors) == 1
assert len(customHandler_failure.states) == 3 # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(
model="azure/azure-embedding-model", input=["good morning from litellm"]
)
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(
model="azure/azure-embedding-model",
input=["good morning from litellm"],
api_key="my-bad-key",
)
except Exception:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_azure())
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
litellm.set_verbose = True
response = await litellm.aembedding(
model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"],
aws_region_name="us-east-1",
)
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(
model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"],
aws_region_name="my-bad-region",
)
except Exception:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_bedrock())
# CACHING
## Test Azure - completion, embedding
@pytest.mark.asyncio
@pytest.mark.flaky(retries=3, delay=1)
async def test_async_completion_azure_caching():
litellm.set_verbose = True
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
response1 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
)
await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
)
await asyncio.sleep(1) # success callbacks are done in parallel
print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
@pytest.mark.asyncio
async def test_async_completion_azure_caching_streaming():
import copy
litellm.set_verbose = True
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = uuid.uuid4()
response1 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
stream=True,
)
async for chunk in response1:
print(f"chunk in response1: {chunk}")
await asyncio.sleep(1)
initial_customhandler_caching_states = len(customHandler_caching.states)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion(
model="azure/chatgpt-v-2",
messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True,
stream=True,
)
async for chunk in response2:
print(f"chunk in response2: {chunk}")
await asyncio.sleep(1) # success callbacks are done in parallel
print(
f"customHandler_caching.states post-cache hit: {customHandler_caching.states}"
)
assert len(customHandler_caching.errors) == 0
assert (
len(customHandler_caching.states) > initial_customhandler_caching_states
) # pre, post, streaming .., success, success
@pytest.mark.asyncio
async def test_async_embedding_azure_caching():
print("Testing custom callback input - Azure Caching")
customHandler_caching = CompletionCustomHandler()
litellm.cache = Cache(
type="redis",
host=os.environ["REDIS_HOST"],
port=os.environ["REDIS_PORT"],
password=os.environ["REDIS_PASSWORD"],
)
litellm.callbacks = [customHandler_caching]
unique_time = time.time()
response1 = await litellm.aembedding(
model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True,
)
await asyncio.sleep(1) # set cache is async for aembedding()
response2 = await litellm.aembedding(
model="azure/azure-embedding-model",
input=[f"good morning from litellm1 {unique_time}"],
caching=True,
)
await asyncio.sleep(1) # success callbacks are done in parallel
print(customHandler_caching.states)
print(customHandler_caching.errors)
assert len(customHandler_caching.errors) == 0
assert len(customHandler_caching.states) == 4 # pre, post, success, success
# Image Generation
## Test OpenAI + Sync
@pytest.mark.flaky(retries=3, delay=1)
def test_image_generation_openai():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
litellm.set_verbose = True
response = litellm.image_generation(
prompt="A cute baby sea otter",
model="azure/",
api_base=os.getenv("AZURE_API_BASE"),
api_key=os.getenv("AZURE_API_KEY"),
api_version="2023-06-01-preview",
)
print(f"response: {response}")
assert len(response.data) > 0
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
time.sleep(2)
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = litellm.image_generation(
prompt="A cute baby sea otter",
model="dall-e-2",
api_key="my-bad-api-key",
)
except Exception:
pass
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
except litellm.RateLimitError as e:
pass
except litellm.ContentPolicyViolationError:
pass # OpenAI randomly raises these errors - skip when they occur
except Exception as e:
pytest.fail(f"An exception occurred - {str(e)}")
# test_image_generation_openai()
## Test OpenAI + Async
## Test Azure + Sync
## Test Azure + Async
##### PII REDACTION ######
def test_turn_off_message_logging():
"""
If 'turn_off_message_logging' is true, assert no user request information is logged.
"""
litellm.turn_off_message_logging = True
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_response="Going well!",
)
time.sleep(2)
assert len(customHandler.errors) == 0
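# A minimal sketch (illustrative, not a collected test) of what the redaction above looks like
# from a logger's perspective: with turn_off_message_logging enabled, the raw request stored in
# metadata is replaced with the LiteLLMCommonStrings.redacted_by_litellm sentinel, matching the
# check in CompletionCustomHandler.log_pre_api_call.
class _RedactionChecker(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
        metadata = kwargs["litellm_params"].get("metadata") or {}
        if litellm.turn_off_message_logging and "raw_request" in metadata:
            assert (
                metadata["raw_request"]
                == LiteLLMCommonStrings.redacted_by_litellm.value
            )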
##### VALID JSON ######
@pytest.mark.parametrize(
"model",
[
"ft:gpt-3.5-turbo:my-org:custom_suffix:id"
], # "gpt-3.5-turbo", "azure/chatgpt-v-2",
)
@pytest.mark.parametrize(
"turn_off_message_logging",
[
True,
],
) # False
def test_standard_logging_payload(model, turn_off_message_logging):
"""
Ensure valid standard_logging_payload is passed for logging calls to s3
Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
"""
from litellm.types.utils import StandardLoggingPayload
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.turn_off_message_logging = turn_off_message_logging
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
_ = litellm.completion(
model=model,
messages=[{"role": "user", "content": "Hey, how's it going?"}],
mock_response="Going well!",
)
time.sleep(2)
mock_client.assert_called_once()
print(
f"mock_client_post.call_args: {mock_client.call_args.kwargs['kwargs'].keys()}"
)
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
print(
"Standard Logging Object - {}".format(
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
)
keys_list = list(StandardLoggingPayload.__annotations__.keys())
for k in keys_list:
assert (
k in mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
## json serializable
json_str_payload = json.dumps(
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
json.loads(json_str_payload)
## response cost
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
"response_cost"
]
> 0
)
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
"model_map_information"
]["model_map_value"]
is not None
)
## turn off message logging
slobject: StandardLoggingPayload = mock_client.call_args.kwargs["kwargs"][
"standard_logging_object"
]
if turn_off_message_logging:
print("checks redacted-by-litellm")
assert "redacted-by-litellm" == slobject["messages"][0]["content"]
assert "redacted-by-litellm" == slobject["response"]
@pytest.mark.parametrize(
"stream",
[True, False],
)
@pytest.mark.parametrize(
"turn_off_message_logging",
[
True,
],
) # False
def test_standard_logging_payload_audio(turn_off_message_logging, stream):
"""
Ensure valid standard_logging_payload is passed for logging calls to s3
Motivation: provide a standard set of things that are logged to s3/gcs/future integrations across all llm calls
"""
from litellm.types.utils import StandardLoggingPayload
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.turn_off_message_logging = turn_off_message_logging
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
response = litellm.completion(
model="gpt-4o-audio-preview",
modalities=["text", "audio"],
audio={"voice": "alloy", "format": "pcm16"},
messages=[{"role": "user", "content": "response in 1 word - yes or no"}],
stream=stream,
)
if stream:
for chunk in response:
continue
time.sleep(2)
mock_client.assert_called_once()
print(
f"mock_client_post.call_args: {mock_client.call_args.kwargs['kwargs'].keys()}"
)
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
print(
"Standard Logging Object - {}".format(
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
)
keys_list = list(StandardLoggingPayload.__annotations__.keys())
for k in keys_list:
assert (
k in mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
## json serializable
json_str_payload = json.dumps(
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
json.loads(json_str_payload)
## response cost
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
"response_cost"
]
> 0
)
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"][
"model_map_information"
]["model_map_value"]
is not None
)
## turn off message logging
slobject: StandardLoggingPayload = mock_client.call_args.kwargs["kwargs"][
"standard_logging_object"
]
if turn_off_message_logging:
print("checks redacted-by-litellm")
assert "redacted-by-litellm" == slobject["messages"][0]["content"]
assert "redacted-by-litellm" == slobject["response"]
@pytest.mark.skip(reason="Works locally. Flaky on ci/cd")
def test_aaastandard_logging_payload_cache_hit():
from litellm.types.utils import StandardLoggingPayload
# sync completion
litellm.cache = Cache()
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.success_callback = []
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
_ = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
)
time.sleep(2)
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0
@pytest.mark.parametrize(
"turn_off_message_logging",
[False, True],
) # False
def test_logging_async_cache_hit_sync_call(turn_off_message_logging):
from litellm.types.utils import StandardLoggingPayload
litellm.turn_off_message_logging = turn_off_message_logging
litellm.cache = Cache()
response = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
stream=True,
)
for chunk in response:
print(chunk)
time.sleep(3)
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.success_callback = []
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
caching=True,
stream=True,
)
for chunk in resp:
print(chunk)
time.sleep(2)
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
assert standard_logging_object["cache_hit"] is True
assert standard_logging_object["response_cost"] == 0
assert standard_logging_object["saved_cache_cost"] > 0
if turn_off_message_logging:
print("checks redacted-by-litellm")
assert (
"redacted-by-litellm"
== standard_logging_object["messages"][0]["content"]
)
assert "redacted-by-litellm" == standard_logging_object["response"]
def test_logging_standard_payload_failure_call():
from litellm.types.utils import StandardLoggingPayload
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
with patch.object(
customHandler, "log_failure_event", new=MagicMock()
) as mock_client:
try:
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
api_key="my-bad-api-key",
)
except litellm.AuthenticationError:
pass
mock_client.assert_called_once()
assert "standard_logging_object" in mock_client.call_args.kwargs["kwargs"]
assert (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
is not None
)
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
assert "additional_headers" in standard_logging_object["hidden_params"]
@pytest.mark.parametrize("stream", [True, False])
def test_logging_standard_payload_llm_headers(stream):
from litellm.types.utils import StandardLoggingPayload
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
with patch.object(
customHandler, "log_success_event", new=MagicMock()
) as mock_client:
resp = litellm.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
stream=stream,
)
if stream:
for chunk in resp:
continue
time.sleep(2)
mock_client.assert_called_once()
standard_logging_object: StandardLoggingPayload = mock_client.call_args.kwargs[
"kwargs"
]["standard_logging_object"]
print(standard_logging_object["hidden_params"]["additional_headers"])
def test_logging_key_masking_gemini():
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
litellm.success_callback = []
with patch.object(
customHandler, "log_pre_api_call", new=MagicMock()
) as mock_client:
try:
resp = litellm.completion(
model="gemini/gemini-1.5-pro",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
api_key="LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART",
)
except litellm.AuthenticationError:
pass
mock_client.assert_called()
print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
assert (
"LEAVE_ONLY_LAST_4_CHAR_UNMASKED_THIS_PART"
not in mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
)
key = mock_client.call_args.kwargs["kwargs"]["litellm_params"]["api_base"]
trimmed_key = key.split("key=")[1]
trimmed_key = trimmed_key.replace("*", "")
assert "PART" == trimmed_key
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_standard_logging_payload_stream_usage(sync_mode):
"""
Even if stream_options is not provided, correct usage should be logged
"""
from litellm.types.utils import StandardLoggingPayload
from litellm.main import stream_chunk_builder
stream = True
try:
# sync completion
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
if sync_mode:
patch_event = "log_success_event"
return_val = MagicMock()
else:
patch_event = "async_log_success_event"
return_val = AsyncMock()
with patch.object(customHandler, patch_event, new=return_val) as mock_client:
if sync_mode:
resp = litellm.completion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
stream=stream,
)
chunks = []
for chunk in resp:
chunks.append(chunk)
time.sleep(2)
else:
resp = await litellm.acompletion(
model="anthropic/claude-3-5-sonnet-20240620",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
stream=stream,
)
chunks = []
async for chunk in resp:
chunks.append(chunk)
await asyncio.sleep(2)
mock_client.assert_called_once()
standard_logging_object: StandardLoggingPayload = (
mock_client.call_args.kwargs["kwargs"]["standard_logging_object"]
)
built_response = stream_chunk_builder(chunks=chunks)
assert (
built_response.usage.total_tokens
!= standard_logging_object["total_tokens"]
)
print(f"standard_logging_object usage: {built_response.usage}")
except litellm.InternalServerError:
pass
def test_standard_logging_retries():
"""
    Ensure the standard logging payload makes it possible to know if a request was retried: all attempts of the same request share one trace_id.
"""
from litellm.types.utils import StandardLoggingPayload
from litellm.router import Router
customHandler = CompletionCustomHandler()
litellm.callbacks = [customHandler]
router = Router(
model_list=[
{
"model_name": "gpt-3.5-turbo",
"litellm_params": {
"model": "openai/gpt-3.5-turbo",
"api_key": "test-api-key",
},
}
]
)
with patch.object(
customHandler, "log_failure_event", new=MagicMock()
) as mock_client:
try:
router.completion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Hey, how's it going?"}],
num_retries=1,
mock_response="litellm.RateLimitError",
)
except litellm.RateLimitError:
pass
assert mock_client.call_count == 2
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
is not None
)
assert (
mock_client.call_args_list[0].kwargs["kwargs"]["standard_logging_object"][
"trace_id"
]
== mock_client.call_args_list[1].kwargs["kwargs"][
"standard_logging_object"
]["trace_id"]
)
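# Illustrative sketch (not a collected test): because retries share a trace_id (asserted above),
# a failure logger can group every attempt of the same logical request. `_RetryGroupingLogger`
# is a hypothetical helper, not part of litellm.
from collections import defaultdict

class _RetryGroupingLogger(CustomLogger):
    def __init__(self):
        self.attempts_by_trace_id = defaultdict(list)

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        self.attempts_by_trace_id[payload.get("trace_id")].append(payload)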