fix(pattern_match_deployments.py): default to user input if unable to… (#6632)
* fix(pattern_match_deployments.py): default to user input if unable to map based on wildcards (see the sketch after this list)
* test: fix test
* test: reset test name
* test: update conftest to reload proxy server module between tests
* ci(config.yml): move langfuse out of local_testing to reduce ci/cd time
* ci(config.yml): cleanup langfuse ci/cd tests
* fix: update test to not use global proxy_server app module
* ci: move caching to a separate test pipeline to speed up the ci pipeline
* test: update conftest to check if proxy_server attr exists before reloading
* build(conftest.py): don't block on inability to reload proxy_server
* ci(config.yml): update caching unit test filter to work on 'cache' keyword as well
* fix(encrypt_decrypt_utils.py): use function to get salt key
* test: mark flaky test
* test: handle anthropic overloaded errors
* refactor: create separate ci/cd pipeline for proxy unit tests to make ci/cd faster
* ci(config.yml): add litellm_proxy_unit_testing to build_and_test jobs
* ci(config.yml): generate prisma binaries for proxy unit tests
* test: readd vertex_key.json
* ci(config.yml): remove `-s` from proxy_unit_test cmd to speed up test
* ci: remove any 'debug' logging flag to speed up ci pipeline
* test: fix test
* test(test_braintrust.py): rerun
* test: add delay for braintrust test
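The first bullet describes a routing fallback: when a requested model name cannot be mapped to a configured deployment via wildcard patterns, the router defaults to the user's original input instead of erroring. Below is a minimal sketch of that behavior; the function name, parameters, and pattern format are illustrative only, not the actual pattern_match_deployments.py API.

import re
from typing import Dict


def map_model_to_deployment(requested_model: str, deployment_patterns: Dict[str, str]) -> str:
    """Illustrative only: map `requested_model` to a deployment via wildcard
    patterns; if nothing matches, default to the user's original input."""
    for pattern, deployment in deployment_patterns.items():
        # turn a wildcard pattern like "openai/*" into an anchored regex
        regex = "^" + re.escape(pattern).replace(r"\*", ".*") + "$"
        if re.match(regex, requested_model):
            return deployment
    return requested_model  # no wildcard matched -> fall back to user input


# usage sketch
assert map_model_to_deployment("openai/gpt-4o", {"openai/*": "openai-group"}) == "openai-group"
assert map_model_to_deployment("claude-3-haiku", {"openai/*": "openai-group"}) == "claude-3-haiku"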
parent 44840d615d
commit 27e18358ab
77 changed files with 2861 additions and 76 deletions
tests/proxy_unit_tests/test_custom_callback_input.py (new file, 359 lines)
@@ -0,0 +1,359 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import asyncio
import inspect
import json  # used by json.dumps in log_success_event below
import os
import sys
import time
import traceback
import uuid
from datetime import datetime

import pytest
from pydantic import BaseModel

sys.path.insert(0, os.path.abspath("../.."))
from typing import List, Literal, Optional, Union
from unittest.mock import AsyncMock, MagicMock, patch

import litellm
from litellm import Cache, completion, embedding
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import LiteLLMCommonStrings

# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## 5: Caching

# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock

# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## refer to test_custom_callback_input_router.py for the router + proxy tests


class CompletionCustomHandler(
    CustomLogger
):  # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    The set of expected inputs to a custom handler for a litellm call.
    """

    # Class variables or attributes
    def __init__(self):
        self.errors = []
        self.states: List[
            Literal[
                "sync_pre_api_call",
                "async_pre_api_call",
                "post_api_call",
                "sync_stream",
                "async_stream",
                "sync_success",
                "async_success",
                "sync_failure",
                "async_failure",
            ]
        ] = []

    def log_pre_api_call(self, model, messages, kwargs):
        try:
            self.states.append("sync_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            ### METADATA
            metadata_value = kwargs["litellm_params"].get("metadata")
            assert metadata_value is None or isinstance(metadata_value, dict)
            if metadata_value is not None:
                if litellm.turn_off_message_logging is True:
                    assert (
                        metadata_value["raw_request"]
                        is LiteLLMCommonStrings.redacted_by_litellm.value
                    )
                else:
                    assert "raw_request" not in metadata_value or isinstance(
                        metadata_value["raw_request"], str
                    )
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("post_api_call")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time is None
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"],
                    (str, litellm.CustomStreamWrapper, BaseModel),
                )
                or inspect.iscoroutine(kwargs["original_response"])
                or inspect.isasyncgen(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_stream")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print(f"\n\nkwargs={kwargs}\n\n")
            print(
                json.dumps(kwargs, default=str)
            )  # this is a test to confirm no circular references are in the logging object

            self.states.append("sync_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.ImageResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and (
                    isinstance(kwargs["input"][0], dict)
                    or isinstance(kwargs["input"][0], str)
                )
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert isinstance(
                kwargs["original_response"],
                (str, litellm.CustomStreamWrapper, BaseModel),
            ), "Original Response={}. Allowed types=[str, litellm.CustomStreamWrapper, BaseModel]".format(
                kwargs["original_response"]
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print(f"kwargs: {kwargs}")
            self.states.append("sync_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )

            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            # metadata may be None or a dict (avoids isinstance() on typing.Optional)
            assert kwargs["litellm_params"]["metadata"] is None or isinstance(
                kwargs["litellm_params"]["metadata"], dict
            )
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert (
                isinstance(kwargs["input"], list)
                and isinstance(kwargs["input"][0], dict)
            ) or isinstance(kwargs["input"], (dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or kwargs["original_response"] is None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_pre_api_call(self, model, messages, kwargs):
        try:
            self.states.append("async_pre_api_call")
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list) and isinstance(
                kwargs["messages"][0], dict
            )
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            print(
                "in async_log_success_event", kwargs, response_obj, start_time, end_time
            )
            self.states.append("async_success")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(
                response_obj,
                (
                    litellm.ModelResponse,
                    litellm.EmbeddingResponse,
                    litellm.TextCompletionResponse,
                ),
            )
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["litellm_params"]["api_base"], str)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["completion_start_time"], datetime)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, dict, str))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
            assert kwargs["cache_hit"] is None or isinstance(kwargs["cache_hit"], bool)
            assert isinstance(kwargs["response_cost"], (float, type(None)))
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            self.states.append("async_failure")
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            assert isinstance(kwargs["model"], str)
            assert isinstance(kwargs["messages"], list)
            assert isinstance(kwargs["optional_params"], dict)
            assert isinstance(kwargs["litellm_params"], dict)
            assert isinstance(kwargs["start_time"], (datetime, type(None)))
            assert isinstance(kwargs["stream"], bool)
            assert isinstance(kwargs["user"], (str, type(None)))
            assert isinstance(kwargs["input"], (list, str, dict))
            assert isinstance(kwargs["api_key"], (str, type(None)))
            assert (
                isinstance(
                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
                )
                or inspect.isasyncgen(kwargs["original_response"])
                or inspect.iscoroutine(kwargs["original_response"])
                or kwargs["original_response"] is None
            )
            assert isinstance(kwargs["additional_args"], (dict, type(None)))
            assert isinstance(kwargs["log_event_type"], str)
        except Exception:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())
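For orientation, here is a minimal sketch of how a handler like CompletionCustomHandler is typically exercised against litellm.completion(). It assumes litellm's documented callback registration (litellm.callbacks = [...]) and the mock_response testing parameter; the test name, model name, message, and assertions are illustrative, not copied from the diff above.

def test_completion_custom_handler_sketch():
    # register the handler, drive one (mocked) completion, and confirm the
    # handler recorded no assertion errors and saw the sync success hook
    customHandler = CompletionCustomHandler()
    litellm.callbacks = [customHandler]
    response = litellm.completion(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=[{"role": "user", "content": "Hi - this is a callback input test"}],
        mock_response="Hello!",  # assumes litellm's mock_response helper
    )
    time.sleep(1)  # give any background logging threads time to run
    assert len(customHandler.errors) == 0, customHandler.errors
    assert "sync_success" in customHandler.states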