forked from phoenix/litellm-mirror
test(test_custom_callback_input.py): embedding callback tests for azure, openai, bedrock
This commit is contained in:
parent
8ee77d7b82
commit
ad39afc0ad
6 changed files with 185 additions and 49 deletions
|
@ -329,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
azure_client_params: dict,
|
azure_client_params: dict,
|
||||||
|
api_key: str,
|
||||||
|
input: list,
|
||||||
client=None,
|
client=None,
|
||||||
|
logging_obj=None
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
@ -338,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
|
||||||
else:
|
else:
|
||||||
openai_aclient = client
|
openai_aclient = client
|
||||||
response = await openai_aclient.embeddings.create(**data)
|
response = await openai_aclient.embeddings.create(**data)
|
||||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
|
stringified_response = response.model_dump_json()
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=stringified_response,
|
||||||
|
)
|
||||||
|
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=str(e),
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def embedding(self,
|
def embedding(self,
|
||||||
|
@ -383,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
|
||||||
azure_client_params["api_key"] = api_key
|
azure_client_params["api_key"] = api_key
|
||||||
elif azure_ad_token is not None:
|
elif azure_ad_token is not None:
|
||||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||||
if aembedding == True:
|
|
||||||
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
|
|
||||||
return response
|
|
||||||
if client is None:
|
|
||||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
|
||||||
else:
|
|
||||||
azure_client = client
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=input,
|
input=input,
|
||||||
|
@ -402,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if aembedding == True:
|
||||||
|
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
|
||||||
|
return response
|
||||||
|
if client is None:
|
||||||
|
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||||
|
else:
|
||||||
|
azure_client = client
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
response = azure_client.embeddings.create(**data) # type: ignore
|
response = azure_client.embeddings.create(**data) # type: ignore
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
|
|
@ -587,7 +587,7 @@ def _embedding_func_single(
|
||||||
input=input,
|
input=input,
|
||||||
api_key="",
|
api_key="",
|
||||||
additional_args={"complete_input_dict": data},
|
additional_args={"complete_input_dict": data},
|
||||||
original_response=response_body,
|
original_response=json.dumps(response_body),
|
||||||
)
|
)
|
||||||
if provider == "cohere":
|
if provider == "cohere":
|
||||||
response = response_body.get("embeddings")
|
response = response_body.get("embeddings")
|
||||||
|
@ -651,13 +651,4 @@ def embedding(
|
||||||
)
|
)
|
||||||
model_response.usage = usage
|
model_response.usage = usage
|
||||||
|
|
||||||
## LOGGING
|
|
||||||
logging_obj.post_call(
|
|
||||||
input=input,
|
|
||||||
api_key=api_key,
|
|
||||||
additional_args={"complete_input_dict": {"model": model,
|
|
||||||
"texts": input}},
|
|
||||||
original_response=embeddings,
|
|
||||||
)
|
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
|
@ -326,6 +326,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
||||||
async def aembedding(
|
async def aembedding(
|
||||||
self,
|
self,
|
||||||
|
input: list,
|
||||||
data: dict,
|
data: dict,
|
||||||
model_response: ModelResponse,
|
model_response: ModelResponse,
|
||||||
timeout: float,
|
timeout: float,
|
||||||
|
@ -333,6 +334,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
api_base: Optional[str]=None,
|
api_base: Optional[str]=None,
|
||||||
client=None,
|
client=None,
|
||||||
max_retries=None,
|
max_retries=None,
|
||||||
|
logging_obj=None
|
||||||
):
|
):
|
||||||
response = None
|
response = None
|
||||||
try:
|
try:
|
||||||
|
@ -341,9 +343,24 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
else:
|
else:
|
||||||
openai_aclient = client
|
openai_aclient = client
|
||||||
response = await openai_aclient.embeddings.create(**data) # type: ignore
|
response = await openai_aclient.embeddings.create(**data) # type: ignore
|
||||||
return response
|
stringified_response = response.model_dump_json()
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
additional_args={"complete_input_dict": data},
|
||||||
|
original_response=stringified_response,
|
||||||
|
)
|
||||||
|
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
## LOGGING
|
||||||
|
logging_obj.post_call(
|
||||||
|
input=input,
|
||||||
|
api_key=api_key,
|
||||||
|
original_response=str(e),
|
||||||
|
)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
def embedding(self,
|
def embedding(self,
|
||||||
model: str,
|
model: str,
|
||||||
input: list,
|
input: list,
|
||||||
|
@ -368,13 +385,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
max_retries = data.pop("max_retries", 2)
|
max_retries = data.pop("max_retries", 2)
|
||||||
if not isinstance(max_retries, int):
|
if not isinstance(max_retries, int):
|
||||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||||
if aembedding == True:
|
|
||||||
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
|
||||||
return response
|
|
||||||
if client is None:
|
|
||||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
|
||||||
else:
|
|
||||||
openai_client = client
|
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.pre_call(
|
logging_obj.pre_call(
|
||||||
input=input,
|
input=input,
|
||||||
|
@ -382,6 +393,14 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if aembedding == True:
|
||||||
|
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||||
|
return response
|
||||||
|
if client is None:
|
||||||
|
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||||
|
else:
|
||||||
|
openai_client = client
|
||||||
|
|
||||||
## COMPLETION CALL
|
## COMPLETION CALL
|
||||||
response = openai_client.embeddings.create(**data) # type: ignore
|
response = openai_client.embeddings.create(**data) # type: ignore
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
|
|
@ -1823,7 +1823,7 @@ def embedding(
|
||||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||||
aembedding = kwargs.pop("aembedding", None)
|
aembedding = kwargs.pop("aembedding", None)
|
||||||
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
|
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
|
||||||
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
|
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
|
||||||
default_params = openai_params + litellm_params
|
default_params = openai_params + litellm_params
|
||||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||||
optional_params = {}
|
optional_params = {}
|
||||||
|
@ -1835,7 +1835,7 @@ def embedding(
|
||||||
try:
|
try:
|
||||||
response = None
|
response = None
|
||||||
logging = litellm_logging_obj
|
logging = litellm_logging_obj
|
||||||
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata})
|
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding})
|
||||||
if azure == True or custom_llm_provider == "azure":
|
if azure == True or custom_llm_provider == "azure":
|
||||||
# azure configs
|
# azure configs
|
||||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||||
|
@ -1975,7 +1975,7 @@ def embedding(
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging.post_call(
|
logging.post_call(
|
||||||
input=input,
|
input=input,
|
||||||
api_key=openai.api_key,
|
api_key=api_key,
|
||||||
original_response=str(e),
|
original_response=str(e),
|
||||||
)
|
)
|
||||||
## Map to OpenAI Exception
|
## Map to OpenAI Exception
|
||||||
|
|
|
@ -4,7 +4,7 @@ import sys, os, time, inspect, asyncio, traceback
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
import pytest
|
import pytest
|
||||||
sys.path.insert(0, os.path.abspath('../..'))
|
sys.path.insert(0, os.path.abspath('../..'))
|
||||||
from typing import Optional
|
from typing import Optional, Literal, List
|
||||||
from litellm import completion, embedding
|
from litellm import completion, embedding
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.integrations.custom_logger import CustomLogger
|
from litellm.integrations.custom_logger import CustomLogger
|
||||||
|
@ -32,16 +32,18 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
# Class variables or attributes
|
# Class variables or attributes
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.errors = []
|
self.errors = []
|
||||||
|
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||||
|
|
||||||
def log_pre_api_call(self, model, messages, kwargs):
|
def log_pre_api_call(self, model, messages, kwargs):
|
||||||
try:
|
try:
|
||||||
|
self.states.append("sync_pre_api_call")
|
||||||
## MODEL
|
## MODEL
|
||||||
assert isinstance(model, str)
|
assert isinstance(model, str)
|
||||||
## MESSAGES
|
## MESSAGES
|
||||||
assert isinstance(messages, list) and isinstance(messages[0], dict)
|
assert isinstance(messages, list)
|
||||||
## KWARGS
|
## KWARGS
|
||||||
assert isinstance(kwargs['model'], str)
|
assert isinstance(kwargs['model'], str)
|
||||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
assert isinstance(kwargs['messages'], list)
|
||||||
assert isinstance(kwargs['optional_params'], dict)
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
assert isinstance(kwargs['litellm_params'], dict)
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||||
|
@ -53,9 +55,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
|
|
||||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
print("IN POST CALL API")
|
self.states.append("post_api_call")
|
||||||
print(f"kwargs input: {kwargs['input']}")
|
|
||||||
print(f"kwargs original response: {kwargs['original_response']}")
|
|
||||||
## START TIME
|
## START TIME
|
||||||
assert isinstance(start_time, datetime)
|
assert isinstance(start_time, datetime)
|
||||||
## END TIME
|
## END TIME
|
||||||
|
@ -64,13 +64,13 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
assert response_obj == None
|
assert response_obj == None
|
||||||
## KWARGS
|
## KWARGS
|
||||||
assert isinstance(kwargs['model'], str)
|
assert isinstance(kwargs['model'], str)
|
||||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
assert isinstance(kwargs['messages'], list)
|
||||||
assert isinstance(kwargs['optional_params'], dict)
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
assert isinstance(kwargs['litellm_params'], dict)
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||||
assert isinstance(kwargs['stream'], bool)
|
assert isinstance(kwargs['stream'], bool)
|
||||||
assert isinstance(kwargs['user'], Optional[str])
|
assert isinstance(kwargs['user'], Optional[str])
|
||||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
assert isinstance(kwargs['input'], (list, dict, str))
|
||||||
assert isinstance(kwargs['api_key'], Optional[str])
|
assert isinstance(kwargs['api_key'], Optional[str])
|
||||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||||
assert isinstance(kwargs['additional_args'], Optional[dict])
|
assert isinstance(kwargs['additional_args'], Optional[dict])
|
||||||
|
@ -81,6 +81,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
|
|
||||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
|
self.states.append("async_stream")
|
||||||
## START TIME
|
## START TIME
|
||||||
assert isinstance(start_time, datetime)
|
assert isinstance(start_time, datetime)
|
||||||
## END TIME
|
## END TIME
|
||||||
|
@ -106,6 +107,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
|
|
||||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
|
self.states.append("sync_success")
|
||||||
## START TIME
|
## START TIME
|
||||||
assert isinstance(start_time, datetime)
|
assert isinstance(start_time, datetime)
|
||||||
## END TIME
|
## END TIME
|
||||||
|
@ -131,6 +133,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
|
|
||||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
|
self.states.append("sync_failure")
|
||||||
## START TIME
|
## START TIME
|
||||||
assert isinstance(start_time, datetime)
|
assert isinstance(start_time, datetime)
|
||||||
## END TIME
|
## END TIME
|
||||||
|
@ -156,6 +159,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
|
|
||||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||||
try:
|
try:
|
||||||
|
self.states.append("async_pre_api_call")
|
||||||
## MODEL
|
## MODEL
|
||||||
assert isinstance(model, str)
|
assert isinstance(model, str)
|
||||||
## MESSAGES
|
## MESSAGES
|
||||||
|
@ -174,21 +178,22 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
|
|
||||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
|
self.states.append("async_success")
|
||||||
## START TIME
|
## START TIME
|
||||||
assert isinstance(start_time, datetime)
|
assert isinstance(start_time, datetime)
|
||||||
## END TIME
|
## END TIME
|
||||||
assert isinstance(end_time, datetime)
|
assert isinstance(end_time, datetime)
|
||||||
## RESPONSE OBJECT
|
## RESPONSE OBJECT
|
||||||
assert isinstance(response_obj, litellm.ModelResponse)
|
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||||
## KWARGS
|
## KWARGS
|
||||||
assert isinstance(kwargs['model'], str)
|
assert isinstance(kwargs['model'], str)
|
||||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
assert isinstance(kwargs['messages'], list)
|
||||||
assert isinstance(kwargs['optional_params'], dict)
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
assert isinstance(kwargs['litellm_params'], dict)
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||||
assert isinstance(kwargs['stream'], bool)
|
assert isinstance(kwargs['stream'], bool)
|
||||||
assert isinstance(kwargs['user'], Optional[str])
|
assert isinstance(kwargs['user'], Optional[str])
|
||||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
assert isinstance(kwargs['input'], (list, dict, str))
|
||||||
assert isinstance(kwargs['api_key'], Optional[str])
|
assert isinstance(kwargs['api_key'], Optional[str])
|
||||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||||
assert isinstance(kwargs['additional_args'], Optional[dict])
|
assert isinstance(kwargs['additional_args'], Optional[dict])
|
||||||
|
@ -199,6 +204,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
|
|
||||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||||
try:
|
try:
|
||||||
|
self.states.append("async_failure")
|
||||||
## START TIME
|
## START TIME
|
||||||
assert isinstance(start_time, datetime)
|
assert isinstance(start_time, datetime)
|
||||||
## END TIME
|
## END TIME
|
||||||
|
@ -207,21 +213,23 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
||||||
assert response_obj == None
|
assert response_obj == None
|
||||||
## KWARGS
|
## KWARGS
|
||||||
assert isinstance(kwargs['model'], str)
|
assert isinstance(kwargs['model'], str)
|
||||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
assert isinstance(kwargs['messages'], list)
|
||||||
assert isinstance(kwargs['optional_params'], dict)
|
assert isinstance(kwargs['optional_params'], dict)
|
||||||
assert isinstance(kwargs['litellm_params'], dict)
|
assert isinstance(kwargs['litellm_params'], dict)
|
||||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||||
assert isinstance(kwargs['stream'], bool)
|
assert isinstance(kwargs['stream'], bool)
|
||||||
assert isinstance(kwargs['user'], Optional[str])
|
assert isinstance(kwargs['user'], Optional[str])
|
||||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
assert isinstance(kwargs['input'], (list, str, dict))
|
||||||
assert isinstance(kwargs['api_key'], Optional[str])
|
assert isinstance(kwargs['api_key'], Optional[str])
|
||||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
|
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
|
||||||
assert isinstance(kwargs['additional_args'], Optional[dict])
|
assert isinstance(kwargs['additional_args'], Optional[dict])
|
||||||
assert isinstance(kwargs['log_event_type'], str)
|
assert isinstance(kwargs['log_event_type'], str)
|
||||||
except:
|
except:
|
||||||
print(f"Assertion Error: {traceback.format_exc()}")
|
print(f"Assertion Error: {traceback.format_exc()}")
|
||||||
self.errors.append(traceback.format_exc())
|
self.errors.append(traceback.format_exc())
|
||||||
|
|
||||||
|
|
||||||
|
# COMPLETION
|
||||||
## Test OpenAI + sync
|
## Test OpenAI + sync
|
||||||
def test_chat_openai_stream():
|
def test_chat_openai_stream():
|
||||||
try:
|
try:
|
||||||
|
@ -379,7 +387,7 @@ async def test_async_chat_azure_stream():
|
||||||
continue
|
continue
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
time.sleep(1)
|
await asyncio.sleep(1)
|
||||||
print(f"customHandler.errors: {customHandler.errors}")
|
print(f"customHandler.errors: {customHandler.errors}")
|
||||||
assert len(customHandler.errors) == 0
|
assert len(customHandler.errors) == 0
|
||||||
litellm.callbacks = []
|
litellm.callbacks = []
|
||||||
|
@ -472,3 +480,101 @@ async def test_async_chat_bedrock_stream():
|
||||||
pytest.fail(f"An exception occurred: {str(e)}")
|
pytest.fail(f"An exception occurred: {str(e)}")
|
||||||
|
|
||||||
# asyncio.run(test_async_chat_bedrock_stream())
|
# asyncio.run(test_async_chat_bedrock_stream())
|
||||||
|
|
||||||
|
# EMBEDDING
|
||||||
|
## Test OpenAI + Async
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_embedding_openai():
|
||||||
|
try:
|
||||||
|
customHandler_success = CompletionCustomHandler()
|
||||||
|
customHandler_failure = CompletionCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler_success]
|
||||||
|
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||||
|
input=["good morning from litellm"])
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||||
|
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||||
|
assert len(customHandler_success.errors) == 0
|
||||||
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
|
# test failure callback
|
||||||
|
litellm.callbacks = [customHandler_failure]
|
||||||
|
try:
|
||||||
|
response = await litellm.aembedding(model="text-embedding-ada-002",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
api_key="my-bad-key")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||||
|
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||||
|
assert len(customHandler_failure.errors) == 0
|
||||||
|
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred: {str(e)}")
|
||||||
|
|
||||||
|
# asyncio.run(test_async_embedding_openai())
|
||||||
|
|
||||||
|
## Test Azure + Async
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_embedding_azure():
|
||||||
|
try:
|
||||||
|
customHandler_success = CompletionCustomHandler()
|
||||||
|
customHandler_failure = CompletionCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler_success]
|
||||||
|
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||||
|
input=["good morning from litellm"])
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||||
|
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||||
|
assert len(customHandler_success.errors) == 0
|
||||||
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
|
# test failure callback
|
||||||
|
litellm.callbacks = [customHandler_failure]
|
||||||
|
try:
|
||||||
|
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
api_key="my-bad-key")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||||
|
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||||
|
assert len(customHandler_failure.errors) == 0
|
||||||
|
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred: {str(e)}")
|
||||||
|
|
||||||
|
# asyncio.run(test_async_embedding_azure())
|
||||||
|
|
||||||
|
## Test Bedrock + Async
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_embedding_bedrock():
|
||||||
|
try:
|
||||||
|
customHandler_success = CompletionCustomHandler()
|
||||||
|
customHandler_failure = CompletionCustomHandler()
|
||||||
|
litellm.callbacks = [customHandler_success]
|
||||||
|
litellm.set_verbose = True
|
||||||
|
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||||
|
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||||
|
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||||
|
assert len(customHandler_success.errors) == 0
|
||||||
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
|
# test failure callback
|
||||||
|
litellm.callbacks = [customHandler_failure]
|
||||||
|
try:
|
||||||
|
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||||
|
input=["good morning from litellm"],
|
||||||
|
aws_region_name="my-bad-region")
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
await asyncio.sleep(1)
|
||||||
|
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||||
|
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||||
|
assert len(customHandler_failure.errors) == 0
|
||||||
|
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"An exception occurred: {str(e)}")
|
||||||
|
|
||||||
|
asyncio.run(test_async_embedding_bedrock())
|
|
@ -989,7 +989,7 @@ class Logging:
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
)
|
)
|
||||||
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
|
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
|
||||||
print_verbose(f"success callbacks: Running Custom Logger Class")
|
print_verbose(f"success callbacks: Running Custom Logger Class")
|
||||||
if self.stream and complete_streaming_response is None:
|
if self.stream and complete_streaming_response is None:
|
||||||
callback.log_stream_event(
|
callback.log_stream_event(
|
||||||
|
@ -1192,7 +1192,7 @@ class Logging:
|
||||||
print_verbose=print_verbose,
|
print_verbose=print_verbose,
|
||||||
callback_func=callback
|
callback_func=callback
|
||||||
)
|
)
|
||||||
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class
|
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
|
||||||
callback.log_failure_event(
|
callback.log_failure_event(
|
||||||
start_time=start_time,
|
start_time=start_time,
|
||||||
end_time=end_time,
|
end_time=end_time,
|
||||||
|
@ -1641,7 +1641,7 @@ def client(original_function):
|
||||||
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
|
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
|
||||||
litellm.cache.add_cache(result, *args, **kwargs)
|
litellm.cache.add_cache(result, *args, **kwargs)
|
||||||
# LOG SUCCESS - handle streaming success logging in the _next_ object
|
# LOG SUCCESS - handle streaming success logging in the _next_ object
|
||||||
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
|
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
|
||||||
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
|
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
|
||||||
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
|
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
|
||||||
# RETURN RESULT
|
# RETURN RESULT
|
||||||
|
@ -1678,7 +1678,7 @@ def client(original_function):
|
||||||
end_time = datetime.datetime.now()
|
end_time = datetime.datetime.now()
|
||||||
if logging_obj:
|
if logging_obj:
|
||||||
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
|
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
|
||||||
asyncio.create_task(logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time))
|
await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time)
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
is_coroutine = inspect.iscoroutinefunction(original_function)
|
is_coroutine = inspect.iscoroutinefunction(original_function)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue