Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

Commit ea89a8a938 (parent 1d2f5ce975)
test(test_custom_callback_unit.py): adding unit tests for custom callbacks + fixing related bugs

8 changed files with 501 additions and 122 deletions
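For context, this is the callback pattern the new unit tests exercise: a CustomLogger subclass is registered on litellm.callbacks, and litellm then invokes its sync/async hooks around each call. A minimal sketch, assuming a configured OpenAI key; the handler name and prompt are illustrative, not part of the commit:

import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyHandler(CustomLogger):
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("sync success:", kwargs.get("model"))

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("async success:", kwargs.get("model"))

# register the handler, then call as usual
litellm.callbacks = [MyHandler()]
response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi - testing callbacks"}],
)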
@@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             response = azure_client.chat.completions.create(**data) # type: ignore
             response.model = "azure/" + str(response.model)
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=stringified_response,
+                additional_args={
+                    "headers": headers,
+                    "api_version": api_version,
+                    "api_base": api_base,
+                },
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
@@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
                 else:
                     openai_client = client
                 response = openai_client.chat.completions.create(**data) # type: ignore
+                stringified_response = response.model_dump_json()
                 logging_obj.post_call(
-                    input=None,
+                    input=messages,
                     api_key=api_key,
-                    original_response=response,
+                    original_response=stringified_response,
                     additional_args={"complete_input_dict": data},
                 )
-                return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+                return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except Exception as e:
             if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
                 # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
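The comment above describes the fallback for providers that require strictly alternating roles. A minimal sketch of that reformatting, using a hypothetical helper name (not litellm's internal function):

def alternate_roles(messages):
    # Insert a blank message of the opposite role between two consecutive
    # messages that share the same role.
    fixed = []
    for message in messages:
        if fixed and fixed[-1]["role"] == message["role"]:
            filler = "assistant" if message["role"] == "user" else "user"
            fixed.append({"role": filler, "content": ""})
        fixed.append(message)
    return fixed

print(alternate_roles([
    {"role": "user", "content": "first"},
    {"role": "user", "content": "second"},
]))
# -> user, blank assistant, user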
@@ -319,7 +319,6 @@ def completion(
     ######### unpacking kwargs #####################
     args = locals()
     api_base = kwargs.get('api_base', None)
-    return_async = kwargs.get('return_async', False)
     mock_response = kwargs.get('mock_response', None)
     force_timeout= kwargs.get('force_timeout', 600) ## deprecated
     logger_fn = kwargs.get('logger_fn', None)
@@ -351,7 +350,7 @@ def completion(
     client = kwargs.get("client", None)
     ######## end of unpacking kwargs ###########
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     if mock_response:
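The dict comprehension above is how provider-specific kwargs are separated out: anything that is neither an OpenAI parameter nor a litellm-internal parameter is forwarded to the model untouched. A standalone sketch with shortened lists (not litellm's exact lists):

openai_params = ["temperature", "top_p", "max_tokens", "stream"]
litellm_params = ["metadata", "api_key", "api_base", "caching"]
default_params = openai_params + litellm_params

kwargs = {"temperature": 0.2, "metadata": {"run": 1}, "top_k": 40}
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
print(non_default_params)  # {'top_k': 40} -> passed straight to the provider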
@@ -449,7 +448,6 @@ def completion(
         # For logging - save the values of the litellm-specific params passed in
         litellm_params = get_litellm_params(
             acompletion=acompletion,
-            return_async=return_async,
             api_key=api_key,
             force_timeout=force_timeout,
             logger_fn=logger_fn,
@@ -526,17 +524,18 @@ def completion(
                 client=client # pass AsyncAzureOpenAI, AzureOpenAI client
             )
 
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-                additional_args={
-                    "headers": headers,
-                    "api_version": api_version,
-                    "api_base": api_base,
-                },
-            )
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={
+                        "headers": headers,
+                        "api_version": api_version,
+                        "api_base": api_base,
+                    },
+                )
         elif (
             model in litellm.open_ai_chat_completion_models
             or custom_llm_provider == "custom_openai"
@@ -606,13 +605,14 @@ def completion(
                     )
                     raise e
 
-            ## LOGGING
-            logging.post_call(
-                input=messages,
-                api_key=api_key,
-                original_response=response,
-                additional_args={"headers": headers},
-            )
+            if optional_params.get("stream", False) or acompletion == True:
+                ## LOGGING
+                logging.post_call(
+                    input=messages,
+                    api_key=api_key,
+                    original_response=response,
+                    additional_args={"headers": headers},
+                )
         elif (
             custom_llm_provider == "text-completion-openai"
             or "ft:babbage-002" in model
@@ -1787,7 +1787,7 @@ def embedding(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.pop("aembedding", None)
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
-    litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
+    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     optional_params = {}
@@ -36,9 +36,12 @@ class MyCustomHandler(CustomLogger):
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose("On Success!")
 
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         print_verbose(f"On Async Success!")
+        response_cost = litellm.completion_cost(completion_response=response_obj)
+        assert response_cost > 0.0
+        return
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
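The hunk above adds a cost assertion to the async success hook. A self-contained sketch of the same check, based on the call shown in the diff (the class name is illustrative):

import litellm
from litellm.integrations.custom_logger import CostHandler if False else object  # placeholder guard, see below
from litellm.integrations.custom_logger import CustomLogger

class CostHandler(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # Price the finished call from the response object itself.
        response_cost = litellm.completion_cost(completion_response=response_obj)
        assert response_cost > 0.0
        print(f"response_cost: {response_cost}")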
@@ -262,41 +262,43 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
         if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
             raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
 
-        if prisma_client:
-            ## check for cache hit (In-Memory Cache)
-            valid_token = user_api_key_cache.get_cache(key=api_key)
-            print(f"valid_token from cache: {valid_token}")
-            if valid_token is None:
-                ## check db
-                valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
-                user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
-            elif valid_token is not None:
-                print(f"API Key Cache Hit!")
-            if valid_token:
-                litellm.model_alias_map = valid_token.aliases
-                config = valid_token.config
-                if config != {}:
-                    model_list = config.get("model_list", [])
-                    llm_model_list = model_list
-                    print("\n new llm router model list", llm_model_list)
-                if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                    api_key = valid_token.token
-                    valid_token_dict = _get_pydantic_json_dict(valid_token)
-                    valid_token_dict.pop("token", None)
-                    return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
-                else:
-                    data = await request.json()
-                    model = data.get("model", None)
-                    if model in litellm.model_alias_map:
-                        model = litellm.model_alias_map[model]
-                    if model and model not in valid_token.models:
-                        raise Exception(f"Token not allowed to access model")
-                api_key = valid_token.token
-                valid_token_dict = _get_pydantic_json_dict(valid_token)
-                valid_token_dict.pop("token", None)
-                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
-            else:
-                raise Exception(f"Invalid token")
+        if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
+            raise Exception("No connected db.")
+
+        ## check for cache hit (In-Memory Cache)
+        valid_token = user_api_key_cache.get_cache(key=api_key)
+        print(f"valid_token from cache: {valid_token}")
+        if valid_token is None:
+            ## check db
+            valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
+            user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+        elif valid_token is not None:
+            print(f"API Key Cache Hit!")
+        if valid_token:
+            litellm.model_alias_map = valid_token.aliases
+            config = valid_token.config
+            if config != {}:
+                model_list = config.get("model_list", [])
+                llm_model_list = model_list
+                print("\n new llm router model list", llm_model_list)
+            if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
+                api_key = valid_token.token
+                valid_token_dict = _get_pydantic_json_dict(valid_token)
+                valid_token_dict.pop("token", None)
+                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+            else:
+                data = await request.json()
+                model = data.get("model", None)
+                if model in litellm.model_alias_map:
+                    model = litellm.model_alias_map[model]
+                if model and model not in valid_token.models:
+                    raise Exception(f"Token not allowed to access model")
+            api_key = valid_token.token
+            valid_token_dict = _get_pydantic_json_dict(valid_token)
+            valid_token_dict.pop("token", None)
+            return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+        else:
+            raise Exception(f"Invalid token")
     except Exception as e:
         print(f"An exception occurred - {traceback.format_exc()}")
         if isinstance(e, HTTPException):
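The restructured auth flow above fails fast when no database is configured, then checks the in-memory cache before the database and caches misses for 60 seconds. A simplified, self-contained sketch of that lookup order; the cache is a plain dict and prisma_client is a stand-in, not litellm's actual classes:

import time

_cache = {}  # api_key -> (token_row, expires_at); stand-in for user_api_key_cache

async def lookup_token(api_key, prisma_client):
    if prisma_client is None:
        raise Exception("No connected db.")
    hit = _cache.get(api_key)
    if hit is not None and hit[1] > time.time():
        return hit[0]  # in-memory cache hit
    token_row = await prisma_client.get_data(token=api_key)  # fall back to the db
    _cache[api_key] = (token_row, time.time() + 60)  # cache for 60s
    return token_row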
@@ -380,25 +382,14 @@ async def track_cost_callback(
         if "complete_streaming_response" in kwargs:
-            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
             completion_response=kwargs["complete_streaming_response"]
-            input_text = kwargs["messages"]
-            output_text = completion_response["choices"][0]["message"]["content"]
-            response_cost = litellm.completion_cost(
-                model = kwargs["model"],
-                messages = input_text,
-                completion=output_text
-            )
+            response_cost = litellm.completion_cost(completion_response=completion_response)
             print("streaming response_cost", response_cost)
             user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
             print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
             if user_api_key and prisma_client:
                 await update_prisma_database(token=user_api_key, response_cost=response_cost)
         elif kwargs["stream"] == False: # for non streaming responses
-            input_text = kwargs.get("messages", "")
-            print(f"type of input_text: {type(input_text)}")
-            if isinstance(input_text, list):
-                response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
-            elif isinstance(input_text, str):
-                response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
+            response_cost = litellm.completion_cost(completion_response=completion_response)
             print(f"received completion response: {completion_response}")
 
             print(f"regular response_cost: {response_cost}")
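Both branches now price the call from the response object alone instead of re-deriving cost from messages and output text. A sketch of that consolidation under a hypothetical helper name:

import litellm

def price_call(kwargs, completion_response):
    # Streaming calls are priced from the stitched-together final response;
    # non-streaming calls from the response object directly.
    if "complete_streaming_response" in kwargs:
        completion_response = kwargs["complete_streaming_response"]
    return litellm.completion_cost(completion_response=completion_response)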
@@ -104,7 +104,7 @@ class ProxyLogging:
             2. /embeddings
         """
         # check if max parallel requests set
-        if user_api_key_dict.max_parallel_requests is not None:
+        if user_api_key_dict is not None and user_api_key_dict.max_parallel_requests is not None:
             ## decrement call count if call failed
             if (hasattr(original_exception, "status_code")
                 and original_exception.status_code == 429
litellm/tests/test_custom_callback_input.py (new file, 386 lines added)

@@ -0,0 +1,386 @@
### What this tests ####
## This test asserts the type of data passed into each method of the custom callback handler
import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional
from litellm import completion, embedding
import litellm
from litellm.integrations.custom_logger import CustomLogger

# Test Scenarios (test across completion, streaming, embedding)
## 1: Pre-API-Call
## 2: Post-API-Call
## 3: On LiteLLM Call success
## 4: On LiteLLM Call failure

# Test models
## 1. OpenAI
## 2. Azure OpenAI
## 3. Non-OpenAI/Azure - e.g. Bedrock

# Test interfaces
## 1. litellm.completion() + litellm.embeddings()
## 2. router.completion() + router.embeddings()
## 3. proxy.completions + proxy.embeddings

class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
    """
    The set of expected inputs to a custom handler for a
    """
    # Class variables or attributes
    def __init__(self):
        self.errors = []

    def log_pre_api_call(self, model, messages, kwargs):
        try:
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
        except Exception as e:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert end_time == None
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_pre_api_call(self, model, messages, kwargs):
        try:
            ## MODEL
            assert isinstance(model, str)
            ## MESSAGES
            assert isinstance(messages, list) and isinstance(messages[0], dict)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
        except Exception as e:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert isinstance(response_obj, litellm.ModelResponse)
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], str) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        try:
            ## START TIME
            assert isinstance(start_time, datetime)
            ## END TIME
            assert isinstance(end_time, datetime)
            ## RESPONSE OBJECT
            assert response_obj == None
            ## KWARGS
            assert isinstance(kwargs['model'], str)
            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
            assert isinstance(kwargs['optional_params'], dict)
            assert isinstance(kwargs['litellm_params'], dict)
            assert isinstance(kwargs['start_time'], Optional[datetime])
            assert isinstance(kwargs['stream'], bool)
            assert isinstance(kwargs['user'], Optional[str])
            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
            assert isinstance(kwargs['api_key'], str)
            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
            assert isinstance(kwargs['additional_args'], Optional[dict])
            assert isinstance(kwargs['log_event_type'], str)
        except:
            print(f"Assertion Error: {traceback.format_exc()}")
            self.errors.append(traceback.format_exc())

## Test OpenAI + sync
def test_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync openai"}])
        ## test streaming
        response = litellm.completion(model="gpt-3.5-turbo",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                      stream=True)
        for chunk in response:
            continue
        ## test failure callback
        try:
            response = litellm.completion(model="gpt-3.5-turbo",
                                          messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# test_chat_openai_stream()

## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_chat_openai_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}])
        ## test streaming
        response = await litellm.acompletion(model="gpt-3.5-turbo",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="gpt-3.5-turbo",
                                                 messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_chat_openai_stream())

## Test Azure + sync
def test_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}])
        # test streaming
        response = litellm.completion(model="azure/chatgpt-v-2",
                                      messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
                                      stream=True)
        for chunk in response:
            continue
        # test failure callback
        try:
            response = litellm.completion(model="azure/chatgpt-v-2",
                                          messages=[{"role": "user", "content": "Hi 👋 - i'm sync azure"}],
                                          api_key="my-bad-key",
                                          stream=True)
            for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# test_chat_azure_stream()

## Test Azure + Async
@pytest.mark.asyncio
async def test_async_chat_azure_stream():
    try:
        customHandler = CompletionCustomHandler()
        litellm.callbacks = [customHandler]
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}])
        ## test streaming
        response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                             messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
                                             stream=True)
        async for chunk in response:
            continue
        ## test failure callback
        try:
            response = await litellm.acompletion(model="azure/chatgpt-v-2",
                                                 messages=[{"role": "user", "content": "Hi 👋 - i'm async azure"}],
                                                 api_key="my-bad-key",
                                                 stream=True)
            async for chunk in response:
                continue
        except:
            pass
        time.sleep(1)
        print(f"customHandler.errors: {customHandler.errors}")
        assert len(customHandler.errors) == 0
        litellm.callbacks = []
    except Exception as e:
        pytest.fail(f"An exception occurred: {str(e)}")

# asyncio.run(test_async_chat_azure_stream())
@@ -801,9 +801,6 @@ class Logging:
             end_time = datetime.datetime.now()
-            self.model_call_details["log_event_type"] = "successful_api_call"
-            self.model_call_details["end_time"] = end_time
-
             if isinstance(result, OpenAIObject):
                 result = result.model_dump()
 
             if litellm.max_budget and self.stream:
                 time_diff = (end_time - start_time).total_seconds()

@@ -857,9 +854,6 @@ class Logging:
                             call_type = self.call_type,
                             stream = self.stream,
                         )
-                    if callback == "api_manager":
-                        print_verbose("reaches api manager for updating model cost")
-                        litellm.apiManager.update_cost(completion_obj=result, user=self.user)
                     if callback == "promptlayer":
                         print_verbose("reaches promptlayer for logging!")
                         promptLayerLogger.log_event(

@@ -994,7 +988,7 @@ class Logging:
                         end_time=end_time,
                         print_verbose=print_verbose,
                     )
-                if isinstance(callback, CustomLogger): # custom logger class
+                if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
                     print_verbose(f"success callbacks: Running Custom Logger Class")
                     if self.stream and complete_streaming_response is None:
                         callback.log_stream_event(
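The extra condition above appears to keep the synchronous success path from invoking CustomLogger instances for async completions, which the async success handler logs instead. A small sketch of that gating, with a hypothetical helper name:

from litellm.integrations.custom_logger import CustomLogger

def should_run_sync_custom_logger(callback, model_call_details):
    # Skip CustomLogger instances on async calls; the async handler covers those.
    is_async_call = model_call_details.get("litellm_params", {}).get("acompletion", False)
    return isinstance(callback, CustomLogger) and not is_async_call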
@@ -1044,7 +1038,6 @@ class Logging:
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
 
         ## BUILD COMPLETE STREAMED RESPONSE
         complete_streaming_response = None
         if self.stream:

@@ -1081,6 +1074,13 @@ class Logging:
                             start_time=start_time,
                             end_time=end_time,
                         )
+                    else:
+                        await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function
+                            kwargs=self.model_call_details,
+                            response_obj=result,
+                            start_time=start_time,
+                            end_time=end_time
+                        )
                 else:
                     await callback.async_log_success_event(
                         kwargs=self.model_call_details,
@@ -1103,24 +1103,29 @@ class Logging:
                     f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
                 )
 
+    def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
+        if start_time is None:
+            start_time = self.start_time
+        if end_time is None:
+            end_time = datetime.datetime.now()
+
+        # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
+        if not hasattr(self, "model_call_details"):
+            self.model_call_details = {}
+
+        self.model_call_details["log_event_type"] = "failed_api_call"
+        self.model_call_details["exception"] = exception
+        self.model_call_details["traceback_exception"] = traceback_exception
+        self.model_call_details["end_time"] = end_time
+        self.model_call_details.setdefault("original_response", None)
+        return start_time, end_time
+
     def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
         print_verbose(
             f"Logging Details LiteLLM-Failure Call"
         )
         try:
-            if start_time is None:
-                start_time = self.start_time
-            if end_time is None:
-                end_time = datetime.datetime.now()
-
-            # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-            if not hasattr(self, "model_call_details"):
-                self.model_call_details = {}
-
-            self.model_call_details["log_event_type"] = "failed_api_call"
-            self.model_call_details["exception"] = exception
-            self.model_call_details["traceback_exception"] = traceback_exception
-            self.model_call_details["end_time"] = end_time
+            start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
             result = None # result sent to all loggers, init this to None incase it's not created
             for callback in litellm.failure_callback:
                 try:
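The hunk above pulls the shared failure bookkeeping into a helper so the sync and async failure handlers stay in step. An illustrative standalone sketch of the same refactor; the class is simplified, not litellm's Logging class:

import datetime

class MiniLogging:
    def __init__(self):
        self.start_time = datetime.datetime.now()
        self.model_call_details = {}

    def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
        # Shared bookkeeping used by both the sync and async failure handlers.
        start_time = start_time or self.start_time
        end_time = end_time or datetime.datetime.now()
        self.model_call_details["log_event_type"] = "failed_api_call"
        self.model_call_details["exception"] = exception
        self.model_call_details["traceback_exception"] = traceback_exception
        self.model_call_details["end_time"] = end_time
        self.model_call_details.setdefault("original_response", None)
        return start_time, end_time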
@@ -1212,16 +1217,8 @@ class Logging:
         """
         Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
         """
-        # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-        if not hasattr(self, "model_call_details"):
-            self.model_call_details = {}
-
-        self.model_call_details["log_event_type"] = "failed_api_call"
-        self.model_call_details["exception"] = exception
-        self.model_call_details["traceback_exception"] = traceback_exception
-        self.model_call_details["end_time"] = end_time
-        result = {} # result sent to all loggers, init this to None incase it's not created
-
+        start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
+        result = None # result sent to all loggers, init this to None incase it's not created
         for callback in litellm._async_failure_callback:
             try:
                 if isinstance(callback, CustomLogger): # custom logger class

@@ -2060,7 +2057,6 @@ def register_model(model_cost: Union[str, dict]):
     return model_cost
 
 def get_litellm_params(
-    return_async=False,
     api_key=None,
     force_timeout=600,
     azure=False,

@@ -2082,7 +2078,6 @@ def get_litellm_params(
 ):
     litellm_params = {
         "acompletion": acompletion,
-        "return_async": return_async,
         "api_key": api_key,
         "force_timeout": force_timeout,
         "logger_fn": logger_fn,
@@ -5094,9 +5089,6 @@ class CustomStreamWrapper:
         self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
         self.holding_chunk = ""
         self.complete_response = ""
-        if self.logging_obj:
-            # Log the type of the received item
-            self.logging_obj.post_call(str(type(completion_stream)))
 
     def __iter__(self):
         return self

@@ -5121,10 +5113,6 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e
 
-    def logging(self, text):
-        if self.logging_obj:
-            self.logging_obj.post_call(text)
-
     def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
         hold = False
         if finish_reason:

@@ -5638,16 +5626,12 @@ class CustomStreamWrapper:
                         completion_obj["role"] = "assistant"
                         self.sent_first_chunk = True
                     model_response.choices[0].delta = Delta(**completion_obj)
-                    # LOGGING
-                    threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                     print_verbose(f"model_response: {model_response}")
                     return model_response
                 else:
                     return
             elif model_response.choices[0].finish_reason:
                 model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                 return model_response
             elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
                 # enter this branch when no content has been passed in response

@@ -5668,8 +5652,6 @@ class CustomStreamWrapper:
                 if self.sent_first_chunk == False:
                     model_response.choices[0].delta["role"] = "assistant"
                     self.sent_first_chunk = True
-                # LOGGING
-                threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
                 return model_response
             else:
                 return

@@ -5678,8 +5660,6 @@ class CustomStreamWrapper:
         except Exception as e:
             traceback_exception = traceback.format_exc()
             e.message = str(e)
-            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
-            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
             raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
 
     ## needs to handle the empty string case (even starting chunk can be an empty string)
@@ -5692,12 +5672,17 @@ class CustomStreamWrapper:
                 chunk = next(self.completion_stream)
                 if chunk is not None and chunk != b'':
                     response = self.chunk_creator(chunk=chunk)
-                    if response is not None:
-                        return response
+                    if response is None:
+                        continue
+                    ## LOGGING
+                    threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response
+                    return response
         except StopIteration:
             raise  # Re-raise StopIteration
         except Exception as e:
             # Handle other exceptions if needed
+            traceback_exception = traceback.format_exc()
+            # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+            threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
             raise e
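The change above moves success logging out of the individual chunk_creator branches and into the iterator itself, so every chunk that is actually returned gets logged once, off the hot path. A minimal stand-in wrapper showing the same shape (not litellm's CustomStreamWrapper):

import threading

class MiniStream:
    def __init__(self, completion_stream, success_handler):
        self.completion_stream = completion_stream
        self.success_handler = success_handler

    def __iter__(self):
        return self

    def __next__(self):
        chunk = next(self.completion_stream)  # StopIteration ends the stream
        # Log the chunk that is actually returned, on a background thread.
        threading.Thread(target=self.success_handler, args=(chunk,)).start()
        return chunk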
@@ -5728,7 +5713,9 @@ class CustomStreamWrapper:
                 asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
                 return processed_chunk
         except Exception as e:
+            traceback_exception = traceback.format_exc()
             # Handle any exceptions that might occur during streaming
+            asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception))
             raise StopAsyncIteration
 
 class TextCompletionStreamWrapper: