test(test_custom_callback_router.py): add async azure testing for router

Krrish Dholakia 2023-12-11 16:40:23 -08:00
parent 5c1322e574
commit 3b6099633c
4 changed files with 404 additions and 7 deletions


@@ -1835,7 +1835,7 @@ def embedding(
     try:
         response = None
         logging = litellm_logging_obj
-        logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding})
+        logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding, "preset_cache_key": None, "stream_response": {}})
         if azure == True or custom_llm_provider == "azure":
             # azure configs
             api_type = get_secret("AZURE_API_TYPE") or "azure"
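With this change, embedding calls initialize `preset_cache_key` and `stream_response` in `litellm_params`, so callbacks can rely on both keys being present (the new router test below asserts exactly that). A minimal sketch of reading them from a callback, assuming a CustomLogger subclass with a hypothetical name:

# Illustration only (not part of this commit): reading the two
# litellm_params keys added above from inside a custom callback.
from litellm.integrations.custom_logger import CustomLogger

class EmbeddingParamsLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        litellm_params = kwargs.get("litellm_params", {})
        # "preset_cache_key" starts as None unless caching populates it upstream
        print(f"preset_cache_key: {litellm_params.get('preset_cache_key')}")
        # "stream_response" starts as an empty dict
        print(f"stream_response: {litellm_params.get('stream_response')}")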


@@ -313,8 +313,8 @@ class Router:
                          **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
-        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
-        kwargs["model_info"] = deployment.get("model_info", {})
+        kwargs.setdefault("model_info", {})
+        kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]}) # [TODO]: move to using async_function_with_fallbacks
         data = deployment["litellm_params"].copy()
         for k, v in self.default_litellm_params.items():
             if k not in data: # prioritize model-specific params > default router params
@@ -339,7 +339,7 @@ class Router:
                          **kwargs) -> Union[List[float], None]:
         # pick the one that is available (lowest TPM/RPM)
         deployment = self.get_available_deployment(model=model, input=input, specific_deployment=kwargs.pop("specific_deployment", None))
-        kwargs.setdefault("metadata", {}).update({"deployment": deployment["litellm_params"]["model"]})
+        kwargs.setdefault("metadata", {}).update({"model_group": model, "deployment": deployment["litellm_params"]["model"]})
         data = deployment["litellm_params"].copy()
         kwargs["model_info"] = deployment.get("model_info", {})
         for k, v in self.default_litellm_params.items():
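Both embedding paths now stamp the router's model group alongside the chosen deployment, so a callback can attribute a call both to the alias the caller used and to the deployment the router picked. A minimal sketch of reading these keys (handler name hypothetical; key names taken from the diff above):

# Illustration only (not part of this commit).
from litellm.integrations.custom_logger import CustomLogger

class RouterAttributionLogger(CustomLogger):
    def log_pre_api_call(self, model, messages, kwargs):
        metadata = kwargs["litellm_params"]["metadata"]
        # alias the caller passed to the router, e.g. "gpt-3.5-turbo"
        print(f"model_group: {metadata['model_group']}")
        # deployment the router selected, e.g. "azure/chatgpt-v-2"
        print(f"deployment: {metadata['deployment']}")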


@@ -22,8 +22,7 @@ from litellm.integrations.custom_logger import CustomLogger
 # Test interfaces
 ## 1. litellm.completion() + litellm.embeddings()
-## 2. router.completion() + router.embeddings()
-## 3. proxy.completions + proxy.embeddings
+## refer to test_custom_callback_input_router.py for the router + proxy tests
 class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
     """
@@ -577,4 +576,4 @@ async def test_async_embedding_bedrock():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
 
-asyncio.run(test_async_embedding_bedrock())
+# asyncio.run(test_async_embedding_bedrock())

test_custom_callback_router.py (new file)

@@ -0,0 +1,398 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional, Literal, List
from litellm import Router
import litellm
from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
## fallbacks
## retries
# Test cases
## 1. Simple Azure OpenAI acompletion + streaming call
## 2. Simple Azure OpenAI aembedding call
## 3. Azure OpenAI acompletion + streaming call with retries
## 4. Azure OpenAI aembedding call with retries
## 5. Azure OpenAI acompletion + streaming call with fallbacks
## 6. Azure OpenAI aembedding call with fallbacks
# Test interfaces
## 1. router.completion() + router.embeddings()
## 2. proxy.completions + proxy.embeddings
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
"""
The set of expected inputs to a custom handler for a router call.
"""
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
print(f'received kwargs in pre-input: {kwargs}')
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], Optional[str])
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], Optional[str])
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert end_time == None
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], Optional[dict])
assert isinstance(kwargs['log_event_type'], str)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], Optional[str])
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], Optional[str])
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_stream")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], Optional[dict])
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
assert isinstance(kwargs['additional_args'], Optional[dict])
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
assert isinstance(kwargs['additional_args'], Optional[dict])
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_pre_api_call(self, model, messages, kwargs):
try:
"""
No-op.
Not implemented yet.
"""
pass
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], Optional[dict])
assert isinstance(kwargs['log_event_type'], str)
### ROUTER-SPECIFIC KWARGS
assert isinstance(kwargs["litellm_params"]["metadata"], dict)
assert isinstance(kwargs["litellm_params"]["metadata"]["model_group"], str)
assert isinstance(kwargs["litellm_params"]["metadata"]["deployment"], str)
assert isinstance(kwargs["litellm_params"]["model_info"], dict)
assert isinstance(kwargs["litellm_params"]["model_info"]["id"], str)
assert isinstance(kwargs["litellm_params"]["proxy_server_request"], Optional[str])
assert isinstance(kwargs["litellm_params"]["preset_cache_key"], Optional[str])
assert isinstance(kwargs["litellm_params"]["stream_response"], dict)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
print(f"received original response: {kwargs['original_response']}")
self.states.append("async_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert isinstance(kwargs['input'], (list, str, dict))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response']) or kwargs['original_response'] == None
assert isinstance(kwargs['additional_args'], Optional[dict])
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
# Simple Azure OpenAI call
## COMPLETION
@pytest.mark.asyncio
async def test_async_chat_azure():
try:
customHandler = CompletionCustomHandler()
customHandler_streaming = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler]
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
await asyncio.sleep(2)
assert len(customHandler.errors) == 0
assert len(customHandler.states) == 3 # pre, post, success
# streaming
litellm.callbacks = [customHandler_streaming]
router2 = Router(model_list=model_list) # type: ignore
response = await router2.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}],
stream=True)
async for chunk in response:
continue
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_streaming.states}")
assert len(customHandler_streaming.errors) == 0
assert len(customHandler_streaming.states) >= 4 # pre, post, stream (multiple times), success
# failure
model_list = [
{
"model_name": "gpt-3.5-turbo", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/chatgpt-v-2",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.acompletion(model="gpt-3.5-turbo",
messages=[{
"role": "user",
"content": "Hi 👋 - i'm openai"
}])
print(f"response in router3 acompletion: {response}")
except:
pass
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_chat_azure())
## EMBEDDING
@pytest.mark.asyncio
async def test_async_embedding_azure():
try:
customHandler = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler]
model_list = [
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": os.getenv("AZURE_API_KEY"),
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
router = Router(model_list=model_list) # type: ignore
response = await router.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
await asyncio.sleep(2)
assert len(customHandler.errors) == 0
assert len(customHandler.states) == 3 # pre, post, success
# failure
model_list = [
{
"model_name": "azure-embedding-model", # openai model name
"litellm_params": { # params for litellm completion/embedding call
"model": "azure/azure-embedding-model",
"api_key": "my-bad-key",
"api_version": os.getenv("AZURE_API_VERSION"),
"api_base": os.getenv("AZURE_API_BASE")
},
"tpm": 240000,
"rpm": 1800
},
]
litellm.callbacks = [customHandler_failure]
router3 = Router(model_list=model_list) # type: ignore
try:
response = await router3.aembedding(model="azure-embedding-model",
input=["hello from litellm!"])
print(f"response in router3 aembedding: {response}")
except:
pass
await asyncio.sleep(1)
print(f"customHandler.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, failure
assert "async_failure" in customHandler_failure.states
except Exception as e:
print(f"Assertion Error: {traceback.format_exc()}")
pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_embedding_azure())
# Azure OpenAI call w/ Retries
## COMPLETION
## EMBEDDING
# Azure OpenAI call w/ fallbacks
## COMPLETION
## EMBEDDING
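# A hedged sketch (not part of this commit) of what the retries COMPLETION
# placeholder above could look like. It reuses CompletionCustomHandler and
# assumes Router's num_retries parameter retries the failing deployment
# before the failure callback fires.
@pytest.mark.asyncio
async def test_async_chat_azure_with_retries():
    customHandler_retries = CompletionCustomHandler()
    litellm.callbacks = [customHandler_retries]
    model_list = [{
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            "model": "azure/chatgpt-v-2",
            "api_key": "my-bad-key",  # force failures so retries are exercised
            "api_version": os.getenv("AZURE_API_VERSION"),
            "api_base": os.getenv("AZURE_API_BASE")
        },
    }]
    router = Router(model_list=model_list, num_retries=2)  # type: ignore
    try:
        await router.acompletion(model="gpt-3.5-turbo",
                                 messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}])
    except:
        pass
    await asyncio.sleep(1)
    assert len(customHandler_retries.errors) == 0
    assert "async_failure" in customHandler_retries.states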