test(test_custom_callback_unit.py): adding unit tests for custom callbacks + fixing related bugs

This commit is contained in:
Krrish Dholakia 2023-12-11 11:38:28 -08:00
parent 1d2f5ce975
commit ea89a8a938
8 changed files with 501 additions and 122 deletions

View file

@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
else:
azure_client = client
response = azure_client.chat.completions.create(**data) # type: ignore
response.model = "azure/" + str(response.model)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=messages,
api_key=api_key,
original_response=stringified_response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except AzureOpenAIError as e:
exception_mapping_worked = True
raise e

View file

@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
else:
openai_client = client
response = openai_client.chat.completions.create(**data) # type: ignore
stringified_response = response.model_dump_json()
logging_obj.post_call(
input=None,
input=messages,
api_key=api_key,
original_response=response,
original_response=stringified_response,
additional_args={"complete_input_dict": data},
)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
except Exception as e:
if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
# reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility

View file

@ -319,7 +319,6 @@ def completion(
######### unpacking kwargs #####################
args = locals()
api_base = kwargs.get('api_base', None)
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
force_timeout= kwargs.get('force_timeout', 600) ## deprecated
logger_fn = kwargs.get('logger_fn', None)
@ -351,7 +350,7 @@ def completion(
client = kwargs.get("client", None)
######## end of unpacking kwargs ###########
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
@ -449,7 +448,6 @@ def completion(
# For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params(
acompletion=acompletion,
return_async=return_async,
api_key=api_key,
force_timeout=force_timeout,
logger_fn=logger_fn,
@ -526,17 +524,18 @@ def completion(
client=client # pass AsyncAzureOpenAI, AzureOpenAI client
)
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={
"headers": headers,
"api_version": api_version,
"api_base": api_base,
},
)
elif (
model in litellm.open_ai_chat_completion_models
or custom_llm_provider == "custom_openai"
@ -606,13 +605,14 @@ def completion(
)
raise e
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": headers},
)
if optional_params.get("stream", False) or acompletion == True:
## LOGGING
logging.post_call(
input=messages,
api_key=api_key,
original_response=response,
additional_args={"headers": headers},
)
elif (
custom_llm_provider == "text-completion-openai"
or "ft:babbage-002" in model
@ -1787,7 +1787,7 @@ def embedding(
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.pop("aembedding", None)
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = {}

View file

@ -36,9 +36,12 @@ class MyCustomHandler(CustomLogger):
def log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose("On Success!")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print_verbose(f"On Async Success!")
response_cost = litellm.completion_cost(completion_response=response_obj)
assert response_cost > 0.0
return
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):

View file

@ -262,41 +262,43 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
if prisma_client:
## check for cache hit (In-Memory Cache)
valid_token = user_api_key_cache.get_cache(key=api_key)
print(f"valid_token from cache: {valid_token}")
if valid_token is None:
## check db
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
if valid_token:
litellm.model_alias_map = valid_token.aliases
config = valid_token.config
if config != {}:
model_list = config.get("model_list", [])
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
data = await request.json()
model = data.get("model", None)
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")
## check for cache hit (In-Memory Cache)
valid_token = user_api_key_cache.get_cache(key=api_key)
print(f"valid_token from cache: {valid_token}")
if valid_token is None:
## check db
valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
elif valid_token is not None:
print(f"API Key Cache Hit!")
if valid_token:
litellm.model_alias_map = valid_token.aliases
config = valid_token.config
if config != {}:
model_list = config.get("model_list", [])
llm_model_list = model_list
print("\n new llm router model list", llm_model_list)
if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
data = await request.json()
model = data.get("model", None)
if model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
raise Exception(f"Token not allowed to access model")
api_key = valid_token.token
valid_token_dict = _get_pydantic_json_dict(valid_token)
valid_token_dict.pop("token", None)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception(f"Invalid token")
except Exception as e:
print(f"An exception occurred - {traceback.format_exc()}")
if isinstance(e, HTTPException):
@ -380,25 +382,14 @@ async def track_cost_callback(
if "complete_streaming_response" in kwargs:
# for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
completion_response=kwargs["complete_streaming_response"]
input_text = kwargs["messages"]
output_text = completion_response["choices"][0]["message"]["content"]
response_cost = litellm.completion_cost(
model = kwargs["model"],
messages = input_text,
completion=output_text
)
response_cost = litellm.completion_cost(completion_response=completion_response)
print("streaming response_cost", response_cost)
user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
if user_api_key and prisma_client:
await update_prisma_database(token=user_api_key, response_cost=response_cost)
elif kwargs["stream"] == False: # for non streaming responses
input_text = kwargs.get("messages", "")
print(f"type of input_text: {type(input_text)}")
if isinstance(input_text, list):
response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
elif isinstance(input_text, str):
response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
response_cost = litellm.completion_cost(completion_response=completion_response)
print(f"received completion response: {completion_response}")
print(f"regular response_cost: {response_cost}")

View file

@ -104,7 +104,7 @@ class ProxyLogging:
2. /embeddings
"""
# check if max parallel requests set
if user_api_key_dict.max_parallel_requests is not None:
if user_api_key_dict is not None and user_api_key_dict.max_parallel_requests is not None:
## decrement call count if call failed
if (hasattr(original_exception, "status_code")
and original_exception.status_code == 429

View file

@ -0,0 +1,386 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional
from litellm import completion, embedding
import litellm
from litellm.integrations.custom_logger import CustomLogger
# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure
# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock
# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## 2. router.completion() + router.embeddings()
## 3. proxy.completions + proxy.embeddings
class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    Custom callback handler that type-checks the arguments passed to each hook.

    Every hook asserts the types of its inputs. A failed check is not raised to
    the caller; its formatted traceback is appended to ``self.errors`` so a test
    can assert that the list is empty once all calls have finished.
    """

    def __init__(self):
        super().__init__()
        self.errors = []  # one formatted traceback per hook invocation that failed a check

    # ------------------------------------------------------------------ helpers
    def _record_failure(self):
        """Print and store the traceback of the exception currently being handled."""
        print(f"Assertion Error: {traceback.format_exc()}")
        self.errors.append(traceback.format_exc())

    @staticmethod
    def _is_text_or_stream(value):
        """Return True when value is a str or a litellm.CustomStreamWrapper."""
        return isinstance(value, (str, litellm.CustomStreamWrapper))

    @staticmethod
    def _is_async_handle(value):
        """Return True when value is an async generator or a coroutine object."""
        return inspect.isasyncgen(value) or inspect.iscoroutine(value)

    @staticmethod
    def _assert_request_kwargs(kwargs):
        """Assert the kwargs entries that every hook checks."""
        assert isinstance(kwargs['model'], str)
        assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
        assert isinstance(kwargs['optional_params'], dict)
        assert isinstance(kwargs['litellm_params'], dict)
        # isinstance(x, Optional[T]) raises TypeError before Python 3.10,
        # so the "or None" case is spelled out with type(None) instead.
        assert isinstance(kwargs['start_time'], (datetime, type(None)))
        assert isinstance(kwargs['stream'], bool)
        assert isinstance(kwargs['user'], (str, type(None)))

    @classmethod
    def _assert_pre_call_inputs(cls, model, messages, kwargs):
        """Assert the inputs shared by the sync and async pre-API-call hooks."""
        ## MODEL
        assert isinstance(model, str)
        ## MESSAGES
        assert isinstance(messages, list) and isinstance(messages[0], dict)
        ## KWARGS
        cls._assert_request_kwargs(kwargs)

    @classmethod
    def _assert_response_kwargs(cls, kwargs, original_response_ok):
        """
        Assert the kwargs entries checked by every hook that runs after the API call.

        ``original_response_ok`` is a predicate applied to
        ``kwargs['original_response']``; the accepted types differ per hook.
        """
        cls._assert_request_kwargs(kwargs)
        assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
        assert isinstance(kwargs['api_key'], str)
        assert original_response_ok(kwargs['original_response'])
        assert isinstance(kwargs['additional_args'], (dict, type(None)))
        assert isinstance(kwargs['log_event_type'], str)

    # -------------------------------------------------------------- sync hooks
    def log_pre_api_call(self, model, messages, kwargs):
        """Type-check the arguments of the sync pre-API-call hook."""
        try:
            self._assert_pre_call_inputs(model, messages, kwargs)
        except Exception:
            self._record_failure()

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        """Type-check the arguments of the post-API-call hook (no end time / response object yet)."""
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time is None
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            self._assert_response_kwargs(
                kwargs,
                lambda raw: self._is_text_or_stream(raw) or self._is_async_handle(raw),
            )
        except Exception:
            self._record_failure()

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Type-check the arguments of the sync success hook."""
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            self._assert_response_kwargs(kwargs, self._is_text_or_stream)
        except Exception:
            self._record_failure()

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Type-check the arguments of the sync failure hook."""
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            self._assert_response_kwargs(
                kwargs,
                lambda raw: raw is None or self._is_text_or_stream(raw),
            )
        except Exception:
            self._record_failure()

    # ------------------------------------------------------------- async hooks
    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        """Type-check the arguments of the async per-chunk stream hook."""
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            self._assert_response_kwargs(kwargs, self._is_async_handle)
        except Exception:
            self._record_failure()

    async def async_log_pre_api_call(self, model, messages, kwargs):
        """Type-check the arguments of the async pre-API-call hook."""
        try:
            self._assert_pre_call_inputs(model, messages, kwargs)
        except Exception:
            self._record_failure()

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        """Type-check the arguments of the async success hook."""
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            self._assert_response_kwargs(
                kwargs,
                lambda raw: isinstance(raw, str) or self._is_async_handle(raw),
            )
        except Exception:
            self._record_failure()

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        """Type-check the arguments of the async failure hook."""
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj is None
            ## KWARGS
            self._assert_response_kwargs(
                kwargs,
                lambda raw: self._is_text_or_stream(raw) or inspect.isasyncgen(raw),
            )
        except Exception:
            self._record_failure()
## Test OpenAI + sync
def test_chat_openai_stream():
    """
    Run sync OpenAI completions (regular, streaming, and streaming with a bad
    key) with the custom handler registered, then assert that no hook recorded
    a failed check.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm sync openai"
            }],
        )
        ## test streaming
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm openai"
            }],
            stream=True,
        )
        for _ in response:
            pass
        ## test failure callback
        try:
            response = litellm.completion(
                model="gpt-3.5-turbo",
                messages=[{
                    "role": "user",
                    "content": "Hi 👋 - i'm openai"
                }],
                api_key="my-bad-key",
                stream=True,
            )
            for _ in response:
                pass
        except Exception:
            # the call is expected to fail; only the failure callbacks matter here
            pass
        time.sleep(1)  # give the logging threads time to run the callbacks
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always unregister the handler so a failing test cannot leak it into later tests
        litellm.callbacks = []
# test_chat_openai_stream()
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
    """
    Run async OpenAI completions (regular, streaming, and streaming with a bad
    key) with the custom handler registered, then assert that no hook recorded
    a failed check.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm openai"
            }],
        )
        ## test streaming
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm openai"
            }],
            stream=True,
        )
        async for _ in response:
            pass
        ## test failure callback
        try:
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{
                    "role": "user",
                    "content": "Hi 👋 - i'm openai"
                }],
                api_key="my-bad-key",
                stream=True,
            )
            async for _ in response:
                pass
        except Exception:
            # the call is expected to fail; only the failure callbacks matter here
            pass
        # Yield to the event loop while waiting: a blocking time.sleep() here would
        # stop the logging tasks scheduled on this loop from running before the assertion.
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always unregister the handler so a failing test cannot leak it into later tests
        litellm.callbacks = []
# asyncio.run(test_async_chat_openai_stream())
## Test Azure + sync
def test_chat_azure_stream():
    """
    Run sync Azure completions (regular, streaming, and streaming with a bad
    key) with the custom handler registered, then assert that no hook recorded
    a failed check.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        litellm.completion(
            model="azure/chatgpt-v-2",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm sync azure"
            }],
        )
        # test streaming
        response = litellm.completion(
            model="azure/chatgpt-v-2",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm sync azure"
            }],
            stream=True,
        )
        for _ in response:
            pass
        # test failure callback
        try:
            response = litellm.completion(
                model="azure/chatgpt-v-2",
                messages=[{
                    "role": "user",
                    "content": "Hi 👋 - i'm sync azure"
                }],
                api_key="my-bad-key",
                stream=True,
            )
            for _ in response:
                pass
        except Exception:
            # the call is expected to fail; only the failure callbacks matter here
            pass
        time.sleep(1)  # give the logging threads time to run the callbacks
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always unregister the handler so a failing test cannot leak it into later tests
        litellm.callbacks = []
# test_chat_azure_stream()
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
    """
    Run async Azure completions (regular, streaming, and streaming with a bad
    key) with the custom handler registered, then assert that no hook recorded
    a failed check.
    """
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm async azure"
            }],
        )
        ## test streaming
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",
            messages=[{
                "role": "user",
                "content": "Hi 👋 - i'm async azure"
            }],
            stream=True,
        )
        async for _ in response:
            pass
        ## test failure callback
        try:
            response = await litellm.acompletion(
                model="azure/chatgpt-v-2",
                messages=[{
                    "role": "user",
                    "content": "Hi 👋 - i'm async azure"
                }],
                api_key="my-bad-key",
                stream=True,
            )
            async for _ in response:
                pass
        except Exception:
            # the call is expected to fail; only the failure callbacks matter here
            pass
        # Yield to the event loop while waiting: a blocking time.sleep() here would
        # stop the logging tasks scheduled on this loop from running before the assertion.
        await asyncio.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")
    finally:
        # always unregister the handler so a failing test cannot leak it into later tests
        litellm.callbacks = []
# asyncio.run(test_async_chat_azure_stream())

View file

@ -801,9 +801,6 @@ class Logging:
end_time = datetime.datetime.now()
self.model_call_details["log_event_type"] = "successful_api_call"
self.model_call_details["end_time"] = end_time
if isinstance(result, OpenAIObject):
result = result.model_dump()
if litellm.max_budget and self.stream:
time_diff = (end_time - start_time).total_seconds()
@ -857,9 +854,6 @@ class Logging:
call_type = self.call_type,
stream = self.stream,
)
if callback == "api_manager":
print_verbose("reaches api manager for updating model cost")
litellm.apiManager.update_cost(completion_obj=result, user=self.user)
if callback == "promptlayer":
print_verbose("reaches promptlayer for logging!")
promptLayerLogger.log_event(
@ -994,7 +988,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if isinstance(callback, CustomLogger): # custom logger class
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
print_verbose(f"success callbacks: Running Custom Logger Class")
if self.stream and complete_streaming_response is None:
callback.log_stream_event(
@ -1044,7 +1038,6 @@ class Logging:
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response = None
if self.stream:
@ -1081,6 +1074,13 @@ class Logging:
start_time=start_time,
end_time=end_time,
)
else:
await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function
kwargs=self.model_call_details,
response_obj=result,
start_time=start_time,
end_time=end_time
)
else:
await callback.async_log_success_event(
kwargs=self.model_call_details,
@ -1103,24 +1103,29 @@ class Logging:
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
    """
    Resolve the failure timestamps and record the failure on ``self.model_call_details``.

    ``start_time`` falls back to ``self.start_time`` and ``end_time`` to the
    current time when not supplied. Returns the resolved ``(start_time, end_time)``.
    """
    resolved_start = self.start_time if start_time is None else start_time
    resolved_end = datetime.datetime.now() if end_time is None else end_time

    # on some exceptions model_call_details was never initialized; create it so the failure is still logged
    if not hasattr(self, "model_call_details"):
        self.model_call_details = {}

    details = self.model_call_details
    details.update(
        {
            "log_event_type": "failed_api_call",
            "exception": exception,
            "traceback_exception": traceback_exception,
            "end_time": resolved_end,
        }
    )
    # keep an already-recorded raw response; otherwise make sure the key exists
    if "original_response" not in details:
        details["original_response"] = None
    return resolved_start, resolved_end
def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
print_verbose(
f"Logging Details LiteLLM-Failure Call"
)
try:
if start_time is None:
start_time = self.start_time
if end_time is None:
end_time = datetime.datetime.now()
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
if not hasattr(self, "model_call_details"):
self.model_call_details = {}
self.model_call_details["log_event_type"] = "failed_api_call"
self.model_call_details["exception"] = exception
self.model_call_details["traceback_exception"] = traceback_exception
self.model_call_details["end_time"] = end_time
start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
result = None # result sent to all loggers, init this to None incase it's not created
for callback in litellm.failure_callback:
try:
@ -1212,16 +1217,8 @@ class Logging:
"""
Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
"""
# on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
if not hasattr(self, "model_call_details"):
self.model_call_details = {}
self.model_call_details["log_event_type"] = "failed_api_call"
self.model_call_details["exception"] = exception
self.model_call_details["traceback_exception"] = traceback_exception
self.model_call_details["end_time"] = end_time
result = {} # result sent to all loggers, init this to None incase it's not created
start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
result = None # result sent to all loggers, init this to None incase it's not created
for callback in litellm._async_failure_callback:
try:
if isinstance(callback, CustomLogger): # custom logger class
@ -2060,7 +2057,6 @@ def register_model(model_cost: Union[str, dict]):
return model_cost
def get_litellm_params(
return_async=False,
api_key=None,
force_timeout=600,
azure=False,
@ -2082,7 +2078,6 @@ def get_litellm_params(
):
litellm_params = {
"acompletion": acompletion,
"return_async": return_async,
"api_key": api_key,
"force_timeout": force_timeout,
"logger_fn": logger_fn,
@ -5094,9 +5089,6 @@ class CustomStreamWrapper:
self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
self.holding_chunk = ""
self.complete_response = ""
if self.logging_obj:
# Log the type of the received item
self.logging_obj.post_call(str(type(completion_stream)))
def __iter__(self):
return self
@ -5121,10 +5113,6 @@ class CustomStreamWrapper:
except Exception as e:
raise e
def logging(self, text):
if self.logging_obj:
self.logging_obj.post_call(text)
def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
hold = False
if finish_reason:
@ -5638,16 +5626,12 @@ class CustomStreamWrapper:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
print_verbose(f"model_response: {model_response}")
return model_response
else:
return
elif model_response.choices[0].finish_reason:
model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
return model_response
elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
# enter this branch when no content has been passed in response
@ -5668,8 +5652,6 @@ class CustomStreamWrapper:
if self.sent_first_chunk == False:
model_response.choices[0].delta["role"] = "assistant"
self.sent_first_chunk = True
# LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
return model_response
else:
return
@ -5678,8 +5660,6 @@ class CustomStreamWrapper:
except Exception as e:
traceback_exception = traceback.format_exc()
e.message = str(e)
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
## needs to handle the empty string case (even starting chunk can be an empty string)
@ -5692,12 +5672,17 @@ class CustomStreamWrapper:
chunk = next(self.completion_stream)
if chunk is not None and chunk != b'':
response = self.chunk_creator(chunk=chunk)
if response is not None:
return response
if response is None:
continue
## LOGGING
threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response
return response
except StopIteration:
raise # Re-raise StopIteration
except Exception as e:
# Handle other exceptions if needed
traceback_exception = traceback.format_exc()
# LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
raise e
@ -5728,7 +5713,9 @@ class CustomStreamWrapper:
asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
return processed_chunk
except Exception as e:
traceback_exception = traceback.format_exc()
# Handle any exceptions that might occur during streaming
asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception))
raise StopAsyncIteration
class TextCompletionStreamWrapper: