test(test_custom_callback_unit.py): adding unit tests for custom callbacks + fixing related bugs

This commit is contained in:
Krrish Dholakia 2023-12-11 11:38:28 -08:00
parent 1d2f5ce975
commit ea89a8a938
8 changed files with 501 additions and 122 deletions

View file

@@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             response = azure_client.chat.completions.create(**data) # type: ignore
-            response.model = "azure/" + str(response.model)
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=stringified_response,
+                additional_args={
+                    "headers": headers,
+                    "api_version": api_version,
+                    "api_base": api_base,
+                },
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e

View file

@@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
             else:
                 openai_client = client
             response = openai_client.chat.completions.create(**data) # type: ignore
+            stringified_response = response.model_dump_json()
             logging_obj.post_call(
-                input=None,
+                input=messages,
                 api_key=api_key,
-                original_response=response,
+                original_response=stringified_response,
                 additional_args={"complete_input_dict": data},
             )
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except Exception as e:
             if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
                 # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
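
Both the Azure and OpenAI handlers now serialize the provider response once with `response.model_dump_json()`, log that string via `logging_obj.post_call(...)`, and reuse it when building the `ModelResponse`. A custom callback can therefore read the raw provider payload from `kwargs["original_response"]` in its post-call hook. A minimal sketch, assuming a hypothetical handler name that is not part of this commit:

import json
from litellm.integrations.custom_logger import CustomLogger

class RawResponseInspector(CustomLogger):  # hypothetical handler, for illustration only
    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        original = kwargs.get("original_response")
        # after this commit, non-streaming calls pass the JSON-stringified provider response here
        if isinstance(original, str):
            payload = json.loads(original)
            print(f"provider returned model={payload.get('model')}")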

View file

@@ -319,7 +319,6 @@ def completion(
     ######### unpacking kwargs #####################
     args = locals()
     api_base = kwargs.get('api_base', None)
-    return_async = kwargs.get('return_async', False)
     mock_response = kwargs.get('mock_response', None)
     force_timeout= kwargs.get('force_timeout', 600) ## deprecated
     logger_fn = kwargs.get('logger_fn', None)
@@ -351,7 +350,7 @@ def completion(
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
@@ -449,7 +448,6 @@ def completion(
         # For logging - save the values of the litellm-specific params passed in
         litellm_params = get_litellm_params(
             acompletion=acompletion,
-            return_async=return_async,
             api_key=api_key,
             force_timeout=force_timeout,
             logger_fn=logger_fn,
@@ -526,17 +524,18 @@ def completion(
                 client=client # pass AsyncAzureOpenAI, AzureOpenAI client
             )
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-                additional_args={
-                    "headers": headers,
-                    "api_version": api_version,
-                    "api_base": api_base,
-                },
-            )
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
         elif (
             model in litellm.open_ai_chat_completion_models
             or custom_llm_provider == "custom_openai"
@@ -606,13 +605,14 @@ def completion(
                 )
                 raise e
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-                additional_args={"headers": headers},
-            )
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={"headers": headers},
+                )
         elif (
             custom_llm_provider == "text-completion-openai"
             or "ft:babbage-002" in model
@@ -1787,7 +1787,7 @@ def embedding(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.pop("aembedding", None)
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     optional_params = {}
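
With `return_async` dropped from the kwargs unpacking, the `litellm_params` allow-lists, and `get_litellm_params`, the async path is driven solely by `acompletion`, and the extra `post_call` log inside `completion()` now only fires when the call is streaming or async. A minimal sketch of that async path (model name and prompt are placeholders):

import asyncio
import litellm

async def main():
    # acompletion (async) and stream=True are the two cases that still hit the post_call log above
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
    )
    print(response)

asyncio.run(main())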

View file

@@ -36,9 +36,12 @@ class MyCustomHandler(CustomLogger):
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose("On Success!")
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose(f"On Async Success!")
+        response_cost = litellm.completion_cost(completion_response=response_obj)
+        assert response_cost > 0.0
         return
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
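
The async success hook in this handler now prices the response it receives with `litellm.completion_cost` and asserts the cost is positive. The same helper works on any finished response object; a minimal sketch (model and prompt are placeholders):

import litellm

# completion_cost accepts a finished response object and returns the estimated cost (USD)
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)
cost = litellm.completion_cost(completion_response=response)
print(f"response cost: ${cost:.6f}")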

View file

@@ -262,41 +262,43 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
-        if prisma_client:
-            ## check for cache hit (In-Memory Cache)
-            valid_token = user_api_key_cache.get_cache(key=api_key)
-            print(f"valid_token from cache: {valid_token}")
-            if valid_token is None:
-                ## check db
-                valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
-                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
-            elif valid_token is not None:
-                print(f"API Key Cache Hit!")
-            if valid_token:
-                litellm.model_alias_map = valid_token.aliases
-                config = valid_token.config
-                if config != {}:
-                    model_list = config.get("model_list", [])
-                    llm_model_list = model_list
-                    print("\n new llm router model list", llm_model_list)
-                if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                    api_key = valid_token.token
-                    valid_token_dict = _get_pydantic_json_dict(valid_token)
-                    valid_token_dict.pop("token", None)
-                    return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
-                else:
-                    data = await request.json()
-                    model = data.get("model", None)
-                    if model in litellm.model_alias_map:
-                        model = litellm.model_alias_map[model]
-                    if model and model not in valid_token.models:
-                        raise Exception(f"Token not allowed to access model")
-                api_key = valid_token.token
-                valid_token_dict = _get_pydantic_json_dict(valid_token)
-                valid_token_dict.pop("token", None)
-                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
-            else:
-                raise Exception(f"Invalid token")
+        if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
+            raise Exception("No connected db.")
+
+        ## check for cache hit (In-Memory Cache)
+        valid_token = user_api_key_cache.get_cache(key=api_key)
+        print(f"valid_token from cache: {valid_token}")
+        if valid_token is None:
+            ## check db
+            valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
+            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+        elif valid_token is not None:
+            print(f"API Key Cache Hit!")
+        if valid_token:
+            litellm.model_alias_map = valid_token.aliases
+            config = valid_token.config
+            if config != {}:
+                model_list = config.get("model_list", [])
+                llm_model_list = model_list
+                print("\n new llm router model list", llm_model_list)
+            if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
+                api_key = valid_token.token
+                valid_token_dict = _get_pydantic_json_dict(valid_token)
+                valid_token_dict.pop("token", None)
+                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+            else:
+                data = await request.json()
+                model = data.get("model", None)
+                if model in litellm.model_alias_map:
+                    model = litellm.model_alias_map[model]
+                if model and model not in valid_token.models:
+                    raise Exception(f"Token not allowed to access model")
+            api_key = valid_token.token
+            valid_token_dict = _get_pydantic_json_dict(valid_token)
+            valid_token_dict.pop("token", None)
+            return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+        else:
+            raise Exception(f"Invalid token")
     except Exception as e:
         print(f"An exception occurred - {traceback.format_exc()}")
         if isinstance(e, HTTPException):
@@ -380,25 +382,14 @@ async def track_cost_callback(
         if "complete_streaming_response" in kwargs:
             # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
             completion_response=kwargs["complete_streaming_response"]
-            input_text = kwargs["messages"]
-            output_text = completion_response["choices"][0]["message"]["content"]
-            response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
-                completion=output_text
-            )
+            response_cost = litellm.completion_cost(completion_response=completion_response)
             print("streaming response_cost", response_cost)
             user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
             print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
             if user_api_key and prisma_client:
                 await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
-            input_text = kwargs.get("messages", "")
-            print(f"type of input_text: {type(input_text)}")
-            if isinstance(input_text, list):
-                response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
-            elif isinstance(input_text, str):
-                response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
+            response_cost = litellm.completion_cost(completion_response=completion_response)
             print(f"received completion response: {completion_response}")
             print(f"regular response_cost: {response_cost}")

View file

@@ -104,7 +104,7 @@ class ProxyLogging:
         2. /embeddings
         """
         # check if max parallel requests set
-        if user_api_key_dict.max_parallel_requests is not None:
+        if user_api_key_dict is not None and user_api_key_dict.max_parallel_requests is not None:
             ## decrement call count if call failed
             if (hasattr(original_exception, "status_code")
                 and original_exception.status_code == 429

View file

@@ -0,0 +1,386 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional
from litellm import completion, embedding
import litellm
from litellm.integrations.custom_logger import CustomLogger

# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure

# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock

# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## 2. router.completion() + router.embeddings()
## 3. proxy.completions + proxy.embeddings

class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    The set of expected inputs to a custom handler for a
    """
    # Class variables or attributes
    def __init__(self):
        self.errors = []

    def log_pre_api_call(self, model, messages, kwargs):
        try:
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
        except Exception as e:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time == None
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_pre_api_call(self, model, messages, kwargs):
        try:
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
        except Exception as e:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], str) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

## Test OpenAI + sync
def test_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync openai"
                                      }])
        ## test streaming
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm openai"
                                      }],
                                      stream=True)
        for chunk in response:
            continue
        ## test failure callback
        try:
            response = litellm.completion(model="gpt-3.5-turbo",
                                          messages=[{
                                              "role": "user",
                                              "content": "Hi 👋 - i'm openai"
                                          }],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# test_chat_openai_stream()

## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm openai"
                                             }])
        ## test streaming
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm openai"
                                             }],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="gpt-3.5-turbo",
                                                 messages=[{
                                                     "role": "user",
                                                     "content": "Hi 👋 - i'm openai"
                                                 }],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_chat_openai_stream())

## Test Azure + sync
def test_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync azure"
                                      }])
        # test streaming
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{
                                          "role": "user",
                                          "content": "Hi 👋 - i'm sync azure"
                                      }],
                                      stream=True)
        for chunk in response:
            continue
        # test failure callback
        try:
            response = litellm.completion(model="azure/chatgpt-v-2",
                                          messages=[{
                                              "role": "user",
                                              "content": "Hi 👋 - i'm sync azure"
                                          }],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# test_chat_azure_stream()

## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm async azure"
                                             }])
        ## test streaming
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{
                                                 "role": "user",
                                                 "content": "Hi 👋 - i'm async azure"
                                             }],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                                 messages=[{
                                                     "role": "user",
                                                     "content": "Hi 👋 - i'm async azure"
                                                 }],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_chat_azure_stream())
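
The new test registers a `CompletionCustomHandler` instance on `litellm.callbacks` and then exercises the sync, async, streaming, and failure paths for both OpenAI and Azure. Outside the test suite, the wiring looks roughly like this; the handler below is a trimmed, hypothetical example rather than the class above:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyHandler(CustomLogger):  # hypothetical minimal handler
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"sync success for model={kwargs.get('model')}")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print(f"async success for model={kwargs.get('model')}")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print("async failure")

litellm.callbacks = [MyHandler()]

# every subsequent completion/acompletion call is routed through the handler above
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋"}],
)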

View file

@@ -801,9 +801,6 @@ class Logging:
                 end_time = datetime.datetime.now()
             self.model_call_details["log_event_type"] = "successful_api_call"
             self.model_call_details["end_time"] = end_time
-            if isinstance(result, OpenAIObject):
-                result = result.model_dump()
-
             if litellm.max_budget and self.stream:
                 time_diff = (end_time - start_time).total_seconds()
@@ -857,9 +854,6 @@ class Logging:
                             call_type = self.call_type,
                             stream = self.stream,
                         )
-                    if callback == "api_manager":
-                        print_verbose("reaches api manager for updating model cost")
-                        litellm.apiManager.update_cost(completion_obj=result, user=self.user)
                     if callback == "promptlayer":
                         print_verbose("reaches promptlayer for logging!")
                         promptLayerLogger.log_event(
@@ -994,7 +988,7 @@ class Logging:
                             end_time=end_time,
                             print_verbose=print_verbose,
                         )
-                    if isinstance(callback, CustomLogger): # custom logger class
+                    if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
                         print_verbose(f"success callbacks: Running Custom Logger Class")
                         if self.stream and complete_streaming_response is None:
                             callback.log_stream_event(
@@ -1044,7 +1038,6 @@ class Logging:
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
         print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
-
         ## BUILD COMPLETE STREAMED RESPONSE
         complete_streaming_response = None
         if self.stream:
@@ -1081,6 +1074,13 @@ class Logging:
                             start_time=start_time,
                             end_time=end_time,
                         )
+                    else:
+                        await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time
+                        )
                 else:
                     await callback.async_log_success_event(
                         kwargs=self.model_call_details,
@@ -1103,24 +1103,29 @@ class Logging:
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
                 )
 
+    def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
+        if start_time is None:
+            start_time = self.start_time
+        if end_time is None:
+            end_time = datetime.datetime.now()
+
+        # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
+        if not hasattr(self, "model_call_details"):
+            self.model_call_details = {}
+
+        self.model_call_details["log_event_type"] = "failed_api_call"
+        self.model_call_details["exception"] = exception
+        self.model_call_details["traceback_exception"] = traceback_exception
+        self.model_call_details["end_time"] = end_time
+        self.model_call_details.setdefault("original_response", None)
+        return start_time, end_time
+
     def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
         print_verbose(
             f"Logging Details LiteLLM-Failure Call"
         )
         try:
-            if start_time is None:
-                start_time = self.start_time
-            if end_time is None:
-                end_time = datetime.datetime.now()
-
-            # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-            if not hasattr(self, "model_call_details"):
-                self.model_call_details = {}
-
-            self.model_call_details["log_event_type"] = "failed_api_call"
-            self.model_call_details["exception"] = exception
-            self.model_call_details["traceback_exception"] = traceback_exception
-            self.model_call_details["end_time"] = end_time
+            start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
             result = None # result sent to all loggers, init this to None incase it's not created
             for callback in litellm.failure_callback:
                 try:
@@ -1212,16 +1217,8 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-        if not hasattr(self, "model_call_details"):
-            self.model_call_details = {}
-
-        self.model_call_details["log_event_type"] = "failed_api_call"
-        self.model_call_details["exception"] = exception
-        self.model_call_details["traceback_exception"] = traceback_exception
-        self.model_call_details["end_time"] = end_time
-
-        result = {} # result sent to all loggers, init this to None incase it's not created
+        start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
+        result = None # result sent to all loggers, init this to None incase it's not created
         for callback in litellm._async_failure_callback:
             try:
                 if isinstance(callback, CustomLogger): # custom logger class
@@ -2060,7 +2057,6 @@ def register_model(model_cost: Union[str, dict]):
     return model_cost
 
 def get_litellm_params(
-    return_async=False,
     api_key=None,
     force_timeout=600,
     azure=False,
@@ -2082,7 +2078,6 @@ def get_litellm_params(
 ):
     litellm_params = {
         "acompletion": acompletion,
-        "return_async": return_async,
         "api_key": api_key,
         "force_timeout": force_timeout,
         "logger_fn": logger_fn,
@@ -5094,9 +5089,6 @@ class CustomStreamWrapper:
         self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
         self.holding_chunk = ""
         self.complete_response = ""
-        if self.logging_obj:
-            # Log the type of the received item
-            self.logging_obj.post_call(str(type(completion_stream)))
 
     def __iter__(self):
         return self
@@ -5121,10 +5113,6 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e
 
-    def logging(self, text):
-        if self.logging_obj:
-            self.logging_obj.post_call(text)
-
     def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
         hold = False
         if finish_reason:
@@ -5638,16 +5626,12 @@ class CustomStreamWrapper:
                     completion_obj["role"] = "assistant"
                     self.sent_first_chunk = True
                 model_response.choices[0].delta = Delta(**completion_obj)
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                 print_verbose(f"model_response: {model_response}")
                 return model_response
             else:
                 return
         elif model_response.choices[0].finish_reason:
             model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
-            # LOGGING
-            threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
             return model_response
         elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
             # enter this branch when no content has been passed in response
@@ -5668,8 +5652,6 @@ class CustomStreamWrapper:
                 if self.sent_first_chunk == False:
                     model_response.choices[0].delta["role"] = "assistant"
                     self.sent_first_chunk = True
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
                 return model_response
             else:
                 return
@@ -5678,8 +5660,6 @@ class CustomStreamWrapper:
         except Exception as e:
             traceback_exception = traceback.format_exc()
             e.message = str(e)
-            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
-            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
             raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
 
     ## needs to handle the empty string case (even starting chunk can be an empty string)
@@ -5692,12 +5672,17 @@ class CustomStreamWrapper:
             chunk = next(self.completion_stream)
             if chunk is not None and chunk != b'':
                 response = self.chunk_creator(chunk=chunk)
-                if response is not None:
-                    return response
+                if response is None:
+                    continue
+                ## LOGGING
+                threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response
+                return response
         except StopIteration:
             raise # Re-raise StopIteration
         except Exception as e:
-            # Handle other exceptions if needed
+            traceback_exception = traceback.format_exc()
+            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
             raise e
@@ -5728,7 +5713,9 @@ class CustomStreamWrapper:
                 asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
                 return processed_chunk
         except Exception as e:
+            traceback_exception = traceback.format_exc()
             # Handle any exceptions that might occur during streaming
+            asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception))
             raise StopAsyncIteration
 
 class TextCompletionStreamWrapper:
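
In `CustomStreamWrapper`, per-chunk success and failure logging moves out of `chunk_creator` and into `__next__` / `__anext__`, and the async success handler now forwards intermediate chunks to `async_log_stream_event` on custom callbacks. A rough sketch of observing a stream through that hook, assuming a hypothetical handler name:

import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger

class ChunkCounter(CustomLogger):  # hypothetical handler, for illustration only
    def __init__(self):
        self.chunks = 0

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        # invoked for intermediate streamed chunks on the async path after this change
        self.chunks += 1

async def main():
    counter = ChunkCounter()
    litellm.callbacks = [counter]
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hi"}],
        stream=True,
    )
    async for _ in response:
        continue
    await asyncio.sleep(1)  # give the scheduled callback tasks a chance to run, as the tests do
    print(f"saw {counter.chunks} streamed chunks")

asyncio.run(main())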