mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)
test(test_custom_callback_unit.py): adding unit tests for custom callbacks + fixing related bugs
parent 1d2f5ce975
commit ea89a8a938
8 changed files with 501 additions and 122 deletions
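For context, the interface these changes exercise is litellm's CustomLogger callback class. Below is a minimal sketch of how a handler is wired up, based on the handler and the `litellm.callbacks = [...]` registration shown in the diffs that follow; the class name MyHandler is illustrative and not part of the commit:

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class MyHandler(CustomLogger):
        # sync hook, used for non-async completion calls
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            print("sync success")

        # async hooks, used for acompletion / async streaming paths
        async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
            print("async success")

        async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
            print("async failure")

    # register the handler, as the new tests do
    litellm.callbacks = [MyHandler()]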
@@ -196,8 +196,19 @@ class AzureChatCompletion(BaseLLM):
             else:
                 azure_client = client
             response = azure_client.chat.completions.create(**data) # type: ignore
-            response.model = "azure/" + str(response.model)
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=messages,
+                api_key=api_key,
+                original_response=stringified_response,
+                additional_args={
+                    "headers": headers,
+                    "api_version": api_version,
+                    "api_base": api_base,
+                },
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except AzureOpenAIError as e:
             exception_mapping_worked = True
             raise e
@@ -219,13 +219,14 @@ class OpenAIChatCompletion(BaseLLM):
             else:
                 openai_client = client
             response = openai_client.chat.completions.create(**data) # type: ignore
+            stringified_response = response.model_dump_json()
             logging_obj.post_call(
-                input=None,
+                input=messages,
                 api_key=api_key,
-                original_response=response,
+                original_response=stringified_response,
                 additional_args={"complete_input_dict": data},
             )
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response)
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response)
         except Exception as e:
             if "Conversation roles must alternate user/assistant" in str(e) or "user and assistant roles should be alternating" in str(e):
                 # reformat messages to ensure user/assistant are alternating, if there's either 2 consecutive 'user' messages or 2 consecutive 'assistant' message, add a blank 'user' or 'assistant' message to ensure compatibility
@@ -319,7 +319,6 @@ def completion(
    ######### unpacking kwargs #####################
    args = locals()
    api_base = kwargs.get('api_base', None)
-   return_async = kwargs.get('return_async', False)
    mock_response = kwargs.get('mock_response', None)
    force_timeout= kwargs.get('force_timeout', 600) ## deprecated
    logger_fn = kwargs.get('logger_fn', None)
@@ -351,7 +350,7 @@ def completion(
    client = kwargs.get("client", None)
    ######## end of unpacking kwargs ###########
    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries"]
-   litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
+   litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "model_info", "proxy_server_request", "preset_cache_key"]
    default_params = openai_params + litellm_params
    non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
    if mock_response:
@@ -449,7 +448,6 @@ def completion(
    # For logging - save the values of the litellm-specific params passed in
    litellm_params = get_litellm_params(
        acompletion=acompletion,
-       return_async=return_async,
        api_key=api_key,
        force_timeout=force_timeout,
        logger_fn=logger_fn,
@@ -526,17 +524,18 @@ def completion(
                client=client # pass AsyncAzureOpenAI, AzureOpenAI client
            )
 
-           ## LOGGING
-           logging.post_call(
-               input=messages,
-               api_key=api_key,
-               original_response=response,
-               additional_args={
-                   "headers": headers,
-                   "api_version": api_version,
-                   "api_base": api_base,
-               },
-           )
+           if optional_params.get("stream", False) or acompletion == True:
+               ## LOGGING
+               logging.post_call(
+                   input=messages,
+                   api_key=api_key,
+                   original_response=response,
+                   additional_args={
+                       "headers": headers,
+                       "api_version": api_version,
+                       "api_base": api_base,
+                   },
+               )
        elif (
            model in litellm.open_ai_chat_completion_models
            or custom_llm_provider == "custom_openai"
@@ -606,13 +605,14 @@ def completion(
                )
                raise e
 
-           ## LOGGING
-           logging.post_call(
-               input=messages,
-               api_key=api_key,
-               original_response=response,
-               additional_args={"headers": headers},
-           )
+           if optional_params.get("stream", False) or acompletion == True:
+               ## LOGGING
+               logging.post_call(
+                   input=messages,
+                   api_key=api_key,
+                   original_response=response,
+                   additional_args={"headers": headers},
+               )
        elif (
            custom_llm_provider == "text-completion-openai"
            or "ft:babbage-002" in model
@@ -1787,7 +1787,7 @@ def embedding(
    proxy_server_request = kwargs.get("proxy_server_request", None)
    aembedding = kwargs.pop("aembedding", None)
    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
-   litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
+   litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
    default_params = openai_params + litellm_params
    non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
    optional_params = {}
@@ -36,9 +36,12 @@ class MyCustomHandler(CustomLogger):
 
    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose("On Success!")
 
+
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print_verbose(f"On Async Success!")
+       response_cost = litellm.completion_cost(completion_response=response_obj)
+       assert response_cost > 0.0
        return
 
    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
@@ -262,41 +262,43 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
        if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
            raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
 
-       if prisma_client:
-           ## check for cache hit (In-Memory Cache)
-           valid_token = user_api_key_cache.get_cache(key=api_key)
-           print(f"valid_token from cache: {valid_token}")
-           if valid_token is None:
-               ## check db
-               valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
-               user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
-           elif valid_token is not None:
-               print(f"API Key Cache Hit!")
-           if valid_token:
-               litellm.model_alias_map = valid_token.aliases
-               config = valid_token.config
-               if config != {}:
-                   model_list = config.get("model_list", [])
-                   llm_model_list = model_list
-                   print("\n new llm router model list", llm_model_list)
-               if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
-                   api_key = valid_token.token
-                   valid_token_dict = _get_pydantic_json_dict(valid_token)
-                   valid_token_dict.pop("token", None)
-                   return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
-               else:
-                   data = await request.json()
-                   model = data.get("model", None)
-                   if model in litellm.model_alias_map:
-                       model = litellm.model_alias_map[model]
-                   if model and model not in valid_token.models:
-                       raise Exception(f"Token not allowed to access model")
+       if prisma_client is None: # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
+           raise Exception("No connected db.")
+
+       ## check for cache hit (In-Memory Cache)
+       valid_token = user_api_key_cache.get_cache(key=api_key)
+       print(f"valid_token from cache: {valid_token}")
+       if valid_token is None:
+           ## check db
+           valid_token = await prisma_client.get_data(token=api_key, expires=datetime.utcnow())
+           user_api_key_cache.set_cache(key=api_key, value=valid_token, ttl=60)
+       elif valid_token is not None:
+           print(f"API Key Cache Hit!")
+       if valid_token:
+           litellm.model_alias_map = valid_token.aliases
+           config = valid_token.config
+           if config != {}:
+               model_list = config.get("model_list", [])
+               llm_model_list = model_list
+               print("\n new llm router model list", llm_model_list)
+           if len(valid_token.models) == 0: # assume an empty model list means all models are allowed to be called
                api_key = valid_token.token
                valid_token_dict = _get_pydantic_json_dict(valid_token)
                valid_token_dict.pop("token", None)
                return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
            else:
-               raise Exception(f"Invalid token")
+               data = await request.json()
+               model = data.get("model", None)
+               if model in litellm.model_alias_map:
+                   model = litellm.model_alias_map[model]
+               if model and model not in valid_token.models:
+                   raise Exception(f"Token not allowed to access model")
+               api_key = valid_token.token
+               valid_token_dict = _get_pydantic_json_dict(valid_token)
+               valid_token_dict.pop("token", None)
+               return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
+       else:
+           raise Exception(f"Invalid token")
    except Exception as e:
        print(f"An exception occurred - {traceback.format_exc()}")
        if isinstance(e, HTTPException):
@@ -380,25 +382,14 @@ async def track_cost_callback(
        if "complete_streaming_response" in kwargs:
            # for tracking streaming cost we pass the "messages" and the output_text to litellm.completion_cost
            completion_response=kwargs["complete_streaming_response"]
-           input_text = kwargs["messages"]
-           output_text = completion_response["choices"][0]["message"]["content"]
-           response_cost = litellm.completion_cost(
-               model = kwargs["model"],
-               messages = input_text,
-               completion=output_text
-           )
+           response_cost = litellm.completion_cost(completion_response=completion_response)
            print("streaming response_cost", response_cost)
            user_api_key = kwargs["litellm_params"]["metadata"].get("user_api_key", None)
            print(f"user_api_key - {user_api_key}; prisma_client - {prisma_client}")
            if user_api_key and prisma_client:
                await update_prisma_database(token=user_api_key, response_cost=response_cost)
        elif kwargs["stream"] == False: # for non streaming responses
-           input_text = kwargs.get("messages", "")
-           print(f"type of input_text: {type(input_text)}")
-           if isinstance(input_text, list):
-               response_cost = litellm.completion_cost(completion_response=completion_response, messages=input_text)
-           elif isinstance(input_text, str):
-               response_cost = litellm.completion_cost(completion_response=completion_response, prompt=input_text)
+           response_cost = litellm.completion_cost(completion_response=completion_response)
            print(f"received completion response: {completion_response}")
 
            print(f"regular response_cost: {response_cost}")
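Note on the cost-tracking change above: the simplified call relies on litellm.completion_cost being able to derive the cost from the response object alone (its model and token usage), rather than from separately passed messages and completion text. A sketch of the new call shape, with `response` standing in for either a regular or a complete streaming response:

    # assumption: completion_cost reads the model and token usage off the response object itself
    response_cost = litellm.completion_cost(completion_response=response)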
@@ -104,7 +104,7 @@ class ProxyLogging:
        2. /embeddings
        """
        # check if max parallel requests set
-       if user_api_key_dict.max_parallel_requests is not None:
+       if user_api_key_dict is not None and user_api_key_dict.max_parallel_requests is not None:
            ## decrement call count if call failed
            if (hasattr(original_exception, "status_code")
                and original_exception.status_code == 429
litellm/tests/test_custom_callback_input.py (new file, 386 lines)
@@ -0,0 +1,386 @@
+### What this tests ####
+## This test asserts the type of data passed into each method of the custom callback handler
+import sys, os, time, inspect, asyncio, traceback
+from datetime import datetime
+import pytest
+sys.path.insert(0, os.path.abspath('../..'))
+from typing import Optional
+from litellm import completion, embedding
+import litellm
+from litellm.integrations.custom_logger import CustomLogger
+
+# Test Scenarios (test across completion, streaming, embedding)
+## 1: Pre-API-Call
+## 2: Post-API-Call
+## 3: On LiteLLM Call success
+## 4: On LiteLLM Call failure
+
+# Test models
+## 1. OpenAI
+## 2. Azure OpenAI
+## 3. Non-OpenAI/Azure - e.g. Bedrock
+
+# Test interfaces
+## 1. litellm.completion() + litellm.embeddings()
+## 2. router.completion() + router.embeddings()
+## 3. proxy.completions + proxy.embeddings
+
+class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
+    """
+    The set of expected inputs to a custom handler for a
+    """
+    # Class variables or attributes
+    def __init__(self):
+        self.errors = []
+
+    def log_pre_api_call(self, model, messages, kwargs):
+        try:
+            ## MODEL
+            assert isinstance(model, str)
+            ## MESSAGES
+            assert isinstance(messages, list) and isinstance(messages[0], dict)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+        except Exception as e:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert end_time == None
+            ## RESPONSE OBJECT
+            assert response_obj == None
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert isinstance(response_obj, litellm.ModelResponse)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert isinstance(response_obj, litellm.ModelResponse)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper))
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert response_obj == None
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or kwargs["original_response"] == None
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_pre_api_call(self, model, messages, kwargs):
+        try:
+            ## MODEL
+            assert isinstance(model, str)
+            ## MESSAGES
+            assert isinstance(messages, list) and isinstance(messages[0], dict)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+        except Exception as e:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert isinstance(response_obj, litellm.ModelResponse)
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], str) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
+        try:
+            ## START TIME
+            assert isinstance(start_time, datetime)
+            ## END TIME
+            assert isinstance(end_time, datetime)
+            ## RESPONSE OBJECT
+            assert response_obj == None
+            ## KWARGS
+            assert isinstance(kwargs['model'], str)
+            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['optional_params'], dict)
+            assert isinstance(kwargs['litellm_params'], dict)
+            assert isinstance(kwargs['start_time'], Optional[datetime])
+            assert isinstance(kwargs['stream'], bool)
+            assert isinstance(kwargs['user'], Optional[str])
+            assert isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)
+            assert isinstance(kwargs['api_key'], str)
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
+            assert isinstance(kwargs['additional_args'], Optional[dict])
+            assert isinstance(kwargs['log_event_type'], str)
+        except:
+            print(f"Assertion Error: {traceback.format_exc()}")
+            self.errors.append(traceback.format_exc())
+
+## Test OpenAI + sync
+def test_chat_openai_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = litellm.completion(model="gpt-3.5-turbo",
+                                      messages=[{
+                                          "role": "user",
+                                          "content": "Hi 👋 - i'm sync openai"
+                                      }])
+        ## test streaming
+        response = litellm.completion(model="gpt-3.5-turbo",
+                                      messages=[{
+                                          "role": "user",
+                                          "content": "Hi 👋 - i'm openai"
+                                      }],
+                                      stream=True)
+        for chunk in response:
+            continue
+        ## test failure callback
+        try:
+            response = litellm.completion(model="gpt-3.5-turbo",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": "Hi 👋 - i'm openai"
+                                          }],
+                                          api_key="my-bad-key",
+                                          stream=True)
+            for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# test_chat_openai_stream()
+
+## Test OpenAI + Async
+@pytest.mark.asyncio
+async def test_async_chat_openai_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.acompletion(model="gpt-3.5-turbo",
+                                             messages=[{
+                                                 "role": "user",
+                                                 "content": "Hi 👋 - i'm openai"
+                                             }])
+        ## test streaming
+        response = await litellm.acompletion(model="gpt-3.5-turbo",
+                                             messages=[{
+                                                 "role": "user",
+                                                 "content": "Hi 👋 - i'm openai"
+                                             }],
+                                             stream=True)
+        async for chunk in response:
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.acompletion(model="gpt-3.5-turbo",
+                                                 messages=[{
+                                                     "role": "user",
+                                                     "content": "Hi 👋 - i'm openai"
+                                                 }],
+                                                 api_key="my-bad-key",
+                                                 stream=True)
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# asyncio.run(test_async_chat_openai_stream())
+
+## Test Azure + sync
+def test_chat_azure_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = litellm.completion(model="azure/chatgpt-v-2",
+                                      messages=[{
+                                          "role": "user",
+                                          "content": "Hi 👋 - i'm sync azure"
+                                      }])
+        # test streaming
+        response = litellm.completion(model="azure/chatgpt-v-2",
+                                      messages=[{
+                                          "role": "user",
+                                          "content": "Hi 👋 - i'm sync azure"
+                                      }],
+                                      stream=True)
+        for chunk in response:
+            continue
+        # test failure callback
+        try:
+            response = litellm.completion(model="azure/chatgpt-v-2",
+                                          messages=[{
+                                              "role": "user",
+                                              "content": "Hi 👋 - i'm sync azure"
+                                          }],
+                                          api_key="my-bad-key",
+                                          stream=True)
+            for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# test_chat_azure_stream()
+
+## Test OpenAI + Async
+@pytest.mark.asyncio
+async def test_async_chat_azure_stream():
+    try:
+        customHandler = CompletionCustomHandler()
+        litellm.callbacks = [customHandler]
+        response = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                             messages=[{
+                                                 "role": "user",
+                                                 "content": "Hi 👋 - i'm async azure"
+                                             }])
+        ## test streaming
+        response = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                             messages=[{
+                                                 "role": "user",
+                                                 "content": "Hi 👋 - i'm async azure"
+                                             }],
+                                             stream=True)
+        async for chunk in response:
+            continue
+        ## test failure callback
+        try:
+            response = await litellm.acompletion(model="azure/chatgpt-v-2",
+                                                 messages=[{
+                                                     "role": "user",
+                                                     "content": "Hi 👋 - i'm async azure"
+                                                 }],
+                                                 api_key="my-bad-key",
+                                                 stream=True)
+            async for chunk in response:
+                continue
+        except:
+            pass
+        time.sleep(1)
+        print(f"customHandler.errors: {customHandler.errors}")
+        assert len(customHandler.errors) == 0
+        litellm.callbacks = []
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# asyncio.run(test_async_chat_azure_stream())
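To run just the new test file locally, something like the following should work (a sketch only — it assumes OpenAI and Azure credentials are configured in the environment, since the tests make live calls):

    # hypothetical local invocation; -x stops on first failure, -s shows print output
    import pytest

    pytest.main(["-x", "-s", "litellm/tests/test_custom_callback_input.py"])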
@@ -801,9 +801,6 @@ class Logging:
                end_time = datetime.datetime.now()
            self.model_call_details["log_event_type"] = "successful_api_call"
            self.model_call_details["end_time"] = end_time
 
-           if isinstance(result, OpenAIObject):
-               result = result.model_dump()
-
            if litellm.max_budget and self.stream:
                time_diff = (end_time - start_time).total_seconds()
@@ -857,9 +854,6 @@ class Logging:
                        call_type = self.call_type,
                        stream = self.stream,
                    )
-               if callback == "api_manager":
-                   print_verbose("reaches api manager for updating model cost")
-                   litellm.apiManager.update_cost(completion_obj=result, user=self.user)
                if callback == "promptlayer":
                    print_verbose("reaches promptlayer for logging!")
                    promptLayerLogger.log_event(
@@ -994,7 +988,7 @@ class Logging:
                        end_time=end_time,
                        print_verbose=print_verbose,
                    )
-               if isinstance(callback, CustomLogger): # custom logger class
+               if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
                    print_verbose(f"success callbacks: Running Custom Logger Class")
                    if self.stream and complete_streaming_response is None:
                        callback.log_stream_event(
@@ -1044,7 +1038,6 @@ class Logging:
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
        print_verbose(f"Async success callbacks: {litellm._async_success_callback}")
-
        ## BUILD COMPLETE STREAMED RESPONSE
        complete_streaming_response = None
        if self.stream:
@@ -1081,6 +1074,13 @@ class Logging:
                            start_time=start_time,
                            end_time=end_time,
                        )
+                   else:
+                       await callback.async_log_stream_event( # [TODO]: move this to being an async log stream event function
+                           kwargs=self.model_call_details,
+                           response_obj=result,
+                           start_time=start_time,
+                           end_time=end_time
+                       )
                else:
                    await callback.async_log_success_event(
                        kwargs=self.model_call_details,
@@ -1103,24 +1103,29 @@ class Logging:
                        f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
                    )
 
+   def _failure_handler_helper_fn(self, exception, traceback_exception, start_time=None, end_time=None):
+       if start_time is None:
+           start_time = self.start_time
+       if end_time is None:
+           end_time = datetime.datetime.now()
+
+       # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
+       if not hasattr(self, "model_call_details"):
+           self.model_call_details = {}
+
+       self.model_call_details["log_event_type"] = "failed_api_call"
+       self.model_call_details["exception"] = exception
+       self.model_call_details["traceback_exception"] = traceback_exception
+       self.model_call_details["end_time"] = end_time
+       self.model_call_details.setdefault("original_response", None)
+       return start_time, end_time
+
    def failure_handler(self, exception, traceback_exception, start_time=None, end_time=None):
        print_verbose(
            f"Logging Details LiteLLM-Failure Call"
        )
        try:
-           if start_time is None:
-               start_time = self.start_time
-           if end_time is None:
-               end_time = datetime.datetime.now()
-
-           # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-           if not hasattr(self, "model_call_details"):
-               self.model_call_details = {}
-
-           self.model_call_details["log_event_type"] = "failed_api_call"
-           self.model_call_details["exception"] = exception
-           self.model_call_details["traceback_exception"] = traceback_exception
-           self.model_call_details["end_time"] = end_time
+           start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
            result = None # result sent to all loggers, init this to None incase it's not created
            for callback in litellm.failure_callback:
                try:
@@ -1212,16 +1217,8 @@ class Logging:
        """
        Implementing async callbacks, to handle asyncio event loop issues when custom integrations need to use async functions.
        """
-       # on some exceptions, model_call_details is not always initialized, this ensures that we still log those exceptions
-       if not hasattr(self, "model_call_details"):
-           self.model_call_details = {}
-
-       self.model_call_details["log_event_type"] = "failed_api_call"
-       self.model_call_details["exception"] = exception
-       self.model_call_details["traceback_exception"] = traceback_exception
-       self.model_call_details["end_time"] = end_time
-       result = {} # result sent to all loggers, init this to None incase it's not created
-
+       start_time, end_time = self._failure_handler_helper_fn(exception=exception, traceback_exception=traceback_exception, start_time=start_time, end_time=end_time)
+       result = None # result sent to all loggers, init this to None incase it's not created
        for callback in litellm._async_failure_callback:
            try:
                if isinstance(callback, CustomLogger): # custom logger class
@@ -2060,7 +2057,6 @@ def register_model(model_cost: Union[str, dict]):
    return model_cost
 
def get_litellm_params(
-   return_async=False,
    api_key=None,
    force_timeout=600,
    azure=False,
@@ -2082,7 +2078,6 @@ def get_litellm_params(
):
    litellm_params = {
        "acompletion": acompletion,
-       "return_async": return_async,
        "api_key": api_key,
        "force_timeout": force_timeout,
        "logger_fn": logger_fn,
@@ -5094,9 +5089,6 @@ class CustomStreamWrapper:
        self.special_tokens = ["<|assistant|>", "<|system|>", "<|user|>", "<s>", "</s>"]
        self.holding_chunk = ""
        self.complete_response = ""
-       if self.logging_obj:
-           # Log the type of the received item
-           self.logging_obj.post_call(str(type(completion_stream)))
 
    def __iter__(self):
        return self
@@ -5121,10 +5113,6 @@ class CustomStreamWrapper:
        except Exception as e:
            raise e
 
-   def logging(self, text):
-       if self.logging_obj:
-           self.logging_obj.post_call(text)
-
    def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
        hold = False
        if finish_reason:
@@ -5638,16 +5626,12 @@ class CustomStreamWrapper:
                        completion_obj["role"] = "assistant"
                        self.sent_first_chunk = True
                    model_response.choices[0].delta = Delta(**completion_obj)
-                   # LOGGING
-                   threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                    print_verbose(f"model_response: {model_response}")
                    return model_response
                else:
                    return
            elif model_response.choices[0].finish_reason:
                model_response.choices[0].finish_reason = map_finish_reason(model_response.choices[0].finish_reason) # ensure consistent output to openai
-               # LOGGING
-               threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start()
                return model_response
            elif response_obj is not None and response_obj.get("original_chunk", None) is not None: # function / tool calling branch - only set for openai/azure compatible endpoints
                # enter this branch when no content has been passed in response
@@ -5668,8 +5652,6 @@ class CustomStreamWrapper:
                    if self.sent_first_chunk == False:
                        model_response.choices[0].delta["role"] = "assistant"
                        self.sent_first_chunk = True
-                   # LOGGING
-                   threading.Thread(target=self.logging_obj.success_handler, args=(model_response,)).start() # log response
                    return model_response
                else:
                    return
@@ -5678,8 +5660,6 @@ class CustomStreamWrapper:
        except Exception as e:
            traceback_exception = traceback.format_exc()
            e.message = str(e)
-           # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
-           threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
            raise exception_type(model=self.model, custom_llm_provider=self.custom_llm_provider, original_exception=e)
 
    ## needs to handle the empty string case (even starting chunk can be an empty string)
@@ -5692,12 +5672,17 @@ class CustomStreamWrapper:
            chunk = next(self.completion_stream)
            if chunk is not None and chunk != b'':
                response = self.chunk_creator(chunk=chunk)
-               if response is not None:
-                   return response
+               if response is None:
+                   continue
+               ## LOGGING
+               threading.Thread(target=self.logging_obj.success_handler, args=(response,)).start() # log response
+               return response
        except StopIteration:
            raise # Re-raise StopIteration
        except Exception as e:
-           # Handle other exceptions if needed
+           traceback_exception = traceback.format_exc()
+           # LOG FAILURE - handle streaming failure logging in the _next_ object, remove `handle_failure` once it's deprecated
+           threading.Thread(target=self.logging_obj.failure_handler, args=(e, traceback_exception)).start()
            raise e
 
 
@@ -5728,7 +5713,9 @@ class CustomStreamWrapper:
            asyncio.create_task(self.logging_obj.async_success_handler(processed_chunk,))
            return processed_chunk
        except Exception as e:
+           traceback_exception = traceback.format_exc()
            # Handle any exceptions that might occur during streaming
+           asyncio.create_task(self.logging_obj.async_failure_handler(e, traceback_exception))
            raise StopAsyncIteration
 
class TextCompletionStreamWrapper: