forked from phoenix/litellm-mirror
* fix(pattern_match_deployments.py): default to user input if unable to map based on wildcards * test: fix test * test: reset test name * test: update conftest to reload proxy server module between tests * ci(config.yml): move langfuse out of local_testing reduce ci/cd time * ci(config.yml): cleanup langfuse ci/cd tests * fix: update test to not use global proxy_server app module * ci: move caching to a separate test pipeline speed up ci pipeline * test: update conftest to check if proxy_server attr exists before reloading * build(conftest.py): don't block on inability to reload proxy_server * ci(config.yml): update caching unit test filter to work on 'cache' keyword as well * fix(encrypt_decrypt_utils.py): use function to get salt key * test: mark flaky test * test: handle anthropic overloaded errors * refactor: create separate ci/cd pipeline for proxy unit tests make ci/cd faster * ci(config.yml): add litellm_proxy_unit_testing to build_and_test jobs * ci(config.yml): generate prisma binaries for proxy unit tests * test: readd vertex_key.json * ci(config.yml): remove `-s` from proxy_unit_test cmd speed up test * ci: remove any 'debug' logging flag speed up ci pipeline * test: fix test * test(test_braintrust.py): rerun * test: add delay for braintrust test
359 lines
15 KiB
Python
359 lines
15 KiB
Python
### What this tests ####
|
|
## This test asserts the type of data passed into each method of the custom callback handler
|
|
import asyncio
import inspect
import json
import os
import sys
import time
import traceback
import uuid
from datetime import datetime
from typing import List, Literal, Optional, Union
from unittest.mock import AsyncMock, MagicMock, patch

import pytest
from pydantic import BaseModel

# Prepend the directory two levels up to the import path so `litellm` is
# imported from there (presumably the repository root).
sys.path.insert(0, os.path.abspath("../.."))

import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings
|
|
|
|
# Test Scenarios (test across completion, streaming, embedding)
|
|
## 1: Pre-API-Call
|
|
## 2: Post-API-Call
|
|
## 3: On LiteLLM Call success
|
|
## 4: On LiteLLM Call failure
|
|
## 5. Caching
|
|
|
|
# Test models
|
|
## 1. OpenAI
|
|
## 2. Azure OpenAI
|
|
## 3. Non-OpenAI/Azure - e.g. Bedrock
|
|
|
|
# Test interfaces
|
|
## 1. litellm.completion() + litellm.embedding()
|
|
## refer to test_custom_callback_input_router.py for the router + proxy tests
|
|
|
|
|
|
class CompletionCustomHandler(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    Custom callback handler that type-checks everything litellm passes to its hooks.

    Each hook first appends its own name to ``self.states`` and then asserts
    the types of its arguments. Any exception raised while checking (including
    a failed assertion) is caught. The traceback is printed and appended to
    ``self.errors`` instead of propagating, so a test can run a call and then
    inspect both lists afterwards.
    """

    def __init__(self):
        # NOTE(review): CustomLogger.__init__ is not called here; confirm the
        # base class needs no initialisation of its own.

        # Formatted tracebacks of every check that failed, in order.
        self.errors = []
        # Names of the hooks that fired, in the order they were invoked.
        self.states: List[
            Literal[
                "sync_pre_api_call",
                "async_pre_api_call",
                "post_api_call",
                "sync_stream",
                "async_stream",
                "sync_success",
                "async_success",
                "sync_failure",
                "async_failure",
            ]
        ] = []

    def _record_error(self):
        """Print and store the traceback of the exception currently being handled.

        Only meaningful inside an ``except`` block, because it relies on
        ``traceback.format_exc()``.
        """
        error_trace = traceback.format_exc()
        print(f"Assertion Error: {error_trace}")
        self.errors.append(error_trace)

    def log_pre_api_call(self, model, messages, kwargs):
        """Check the arguments of the synchronous pre-API-call hook."""
        try:
            self.states.append("sync_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            ### METADATA
            metadata_value = kwargs["litellm_params"].get("metadata")
            assert metadata_value is None or isinstance(metadata_value, dict)
            if metadata_value is not None:
                if litellm.turn_off_message_logging is True:
                    # With message logging turned off, the raw request must be
                    # the redaction placeholder. Compare by value: `is` on two
                    # strings only succeeds when they are the same object.
                    assert (
                        metadata_value["raw_request"]
                        == LiteLLMCommonStrings.redacted_by_litellm.value
                    )
                else:
                    assert "raw_request" not in metadata_value or isinstance(
                        metadata_value["raw_request"], str
                    )
        except Exception:
            self._record_error()

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        """Check the post-API-call hook.

        At this stage the hook is expected to receive no end time and no
        response object.
        """
        try:
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time is None
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            # The raw provider response may be text, a stream wrapper, a
            # pydantic model, or a coroutine / async-generator object.
            assert (
                isinstance(
                    kwargs["original_response"],
                    (str, litellm.CustomStreamWrapper, BaseModel),
                )
                or inspect.iscoroutine(kwargs["original_response"])
                or inspect.isasyncgen(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_error()

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Check the async per-chunk streaming hook."""
        try:
            self.states.append("async_stream")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_error()

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Check the synchronous success hook (chat, embedding or image responses)."""
        try:
            print(f"\n\nkwargs={kwargs}\n\n")
            # Serialising kwargs confirms that the logging object contains no
            # circular references: json.dumps raises ValueError on a cycle.
            print(json.dumps(kwargs, default=str))

            self.states.append("sync_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.ImageResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            # Input is a list of message dicts / strings, or a single dict / string.
            assert (
                isinstance(kwargs["input"], list)
                and (
                    isinstance(kwargs["input"][0], dict)
                    or isinstance(kwargs["input"][0], str)
                )
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert isinstance(
                kwargs["original_response"],
                (str, litellm.CustomStreamWrapper, BaseModel),
            ), "Original Response={}. Allowed types=[str, litellm.CustomStreamWrapper, BaseModel]".format(
                kwargs["original_response"]
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            self._record_error()

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Check the synchronous failure hook; no response object is expected."""
        try:
            print(f"kwargs: {kwargs}")
            self.states.append("sync_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            # The "metadata" key must be present; its value is a dict or None.
            # A plain tuple of types is used because isinstance() rejects
            # typing.Optional[...] before Python 3.10.
            assert isinstance(
                kwargs["litellm_params"]["metadata"], (dict, type(None))
            )
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or kwargs["original_response"] is None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_error()

    async def async_log_pre_api_call(self, model, messages, kwargs):
        """Check the arguments of the asynchronous pre-API-call hook."""
        try:
            self.states.append("async_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
        except Exception:
            self._record_error()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Check the asynchronous success hook (chat, embedding or text completion)."""
        try:
            print(
                "in async_log_success_event", kwargs, response_obj, start_time, end_time
            )
            self.states.append("async_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.TextCompletionResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["completion_start_time"], datetime)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            self._record_error()

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Check the asynchronous failure hook; no response object is expected."""
        try:
            self.states.append("async_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, str, dict))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
                or kwargs["original_response"] is None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            self._record_error()
|