mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
test(test_custom_callback_input.py): embedding callback tests for azure, openai, bedrock
This commit is contained in:
parent
8ee77d7b82
commit
ad39afc0ad
6 changed files with 185 additions and 49 deletions
|
@ -329,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
|
|||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
azure_client_params: dict,
|
||||
api_key: str,
|
||||
input: list,
|
||||
client=None,
|
||||
logging_obj=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -338,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_aclient = client
|
||||
response = await openai_aclient.embeddings.create(**data)
|
||||
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(self,
|
||||
|
@ -383,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
|
|||
azure_client_params["api_key"] = api_key
|
||||
elif azure_ad_token is not None:
|
||||
azure_client_params["azure_ad_token"] = azure_ad_token
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -402,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
|
|||
}
|
||||
},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
|
||||
return response
|
||||
if client is None:
|
||||
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
|
||||
else:
|
||||
azure_client = client
|
||||
## COMPLETION CALL
|
||||
response = azure_client.embeddings.create(**data) # type: ignore
|
||||
## LOGGING
|
||||
|
|
|
@ -587,7 +587,7 @@ def _embedding_func_single(
|
|||
input=input,
|
||||
api_key="",
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=response_body,
|
||||
original_response=json.dumps(response_body),
|
||||
)
|
||||
if provider == "cohere":
|
||||
response = response_body.get("embeddings")
|
||||
|
@ -651,13 +651,4 @@ def embedding(
|
|||
)
|
||||
model_response.usage = usage
|
||||
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": {"model": model,
|
||||
"texts": input}},
|
||||
original_response=embeddings,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
|
|
@ -326,6 +326,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
raise OpenAIError(status_code=500, message=f"{str(e)}")
|
||||
async def aembedding(
|
||||
self,
|
||||
input: list,
|
||||
data: dict,
|
||||
model_response: ModelResponse,
|
||||
timeout: float,
|
||||
|
@ -333,6 +334,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
api_base: Optional[str]=None,
|
||||
client=None,
|
||||
max_retries=None,
|
||||
logging_obj=None
|
||||
):
|
||||
response = None
|
||||
try:
|
||||
|
@ -341,9 +343,24 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
else:
|
||||
openai_aclient = client
|
||||
response = await openai_aclient.embeddings.create(**data) # type: ignore
|
||||
return response
|
||||
stringified_response = response.model_dump_json()
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
additional_args={"complete_input_dict": data},
|
||||
original_response=stringified_response,
|
||||
)
|
||||
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
raise e
|
||||
|
||||
def embedding(self,
|
||||
model: str,
|
||||
input: list,
|
||||
|
@ -368,13 +385,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
max_retries = data.pop("max_retries", 2)
|
||||
if not isinstance(max_retries, int):
|
||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
return response
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=input,
|
||||
|
@ -382,6 +393,14 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
return response
|
||||
if client is None:
|
||||
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
|
||||
else:
|
||||
openai_client = client
|
||||
|
||||
## COMPLETION CALL
|
||||
response = openai_client.embeddings.create(**data) # type: ignore
|
||||
## LOGGING
|
||||
|
|
|
@ -1823,7 +1823,7 @@ def embedding(
|
|||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
aembedding = kwargs.pop("aembedding", None)
|
||||
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
|
||||
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
|
||||
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
|
||||
default_params = openai_params + litellm_params
|
||||
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
|
||||
optional_params = {}
|
||||
|
@ -1835,7 +1835,7 @@ def embedding(
|
|||
try:
|
||||
response = None
|
||||
logging = litellm_logging_obj
|
||||
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata})
|
||||
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding})
|
||||
if azure == True or custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
|
@ -1975,7 +1975,7 @@ def embedding(
|
|||
## LOGGING
|
||||
logging.post_call(
|
||||
input=input,
|
||||
api_key=openai.api_key,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
)
|
||||
## Map to OpenAI Exception
|
||||
|
|
|
@ -4,7 +4,7 @@ import sys, os, time, inspect, asyncio, traceback
|
|||
from datetime import datetime
|
||||
import pytest
|
||||
sys.path.insert(0, os.path.abspath('../..'))
|
||||
from typing import Optional
|
||||
from typing import Optional, Literal, List
|
||||
from litellm import completion, embedding
|
||||
import litellm
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
|
@ -32,16 +32,18 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
self.errors = []
|
||||
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
|
||||
|
||||
def log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
self.states.append("sync_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
assert isinstance(messages, list) and isinstance(messages[0], dict)
|
||||
assert isinstance(messages, list)
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||
|
@ -53,9 +55,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
|
||||
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
print("IN POST CALL API")
|
||||
print(f"kwargs input: {kwargs['input']}")
|
||||
print(f"kwargs original response: {kwargs['original_response']}")
|
||||
self.states.append("post_api_call")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
|
@ -64,13 +64,13 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], Optional[str])
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], Optional[str])
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], Optional[dict])
|
||||
|
@ -81,6 +81,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
|
||||
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_stream")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
|
@ -106,6 +107,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
|
||||
def log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
|
@ -131,6 +133,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
|
||||
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("sync_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
|
@ -156,6 +159,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
|
||||
async def async_log_pre_api_call(self, model, messages, kwargs):
|
||||
try:
|
||||
self.states.append("async_pre_api_call")
|
||||
## MODEL
|
||||
assert isinstance(model, str)
|
||||
## MESSAGES
|
||||
|
@ -174,21 +178,22 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_success")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
assert isinstance(end_time, datetime)
|
||||
## RESPONSE OBJECT
|
||||
assert isinstance(response_obj, litellm.ModelResponse)
|
||||
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], Optional[str])
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['input'], (list, dict, str))
|
||||
assert isinstance(kwargs['api_key'], Optional[str])
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
|
||||
assert isinstance(kwargs['additional_args'], Optional[dict])
|
||||
|
@ -199,6 +204,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
|
||||
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
|
||||
try:
|
||||
self.states.append("async_failure")
|
||||
## START TIME
|
||||
assert isinstance(start_time, datetime)
|
||||
## END TIME
|
||||
|
@ -207,21 +213,23 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
|
|||
assert response_obj == None
|
||||
## KWARGS
|
||||
assert isinstance(kwargs['model'], str)
|
||||
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
|
||||
assert isinstance(kwargs['messages'], list)
|
||||
assert isinstance(kwargs['optional_params'], dict)
|
||||
assert isinstance(kwargs['litellm_params'], dict)
|
||||
assert isinstance(kwargs['start_time'], Optional[datetime])
|
||||
assert isinstance(kwargs['stream'], bool)
|
||||
assert isinstance(kwargs['user'], Optional[str])
|
||||
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
|
||||
assert isinstance(kwargs['input'], (list, str, dict))
|
||||
assert isinstance(kwargs['api_key'], Optional[str])
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
|
||||
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
|
||||
assert isinstance(kwargs['additional_args'], Optional[dict])
|
||||
assert isinstance(kwargs['log_event_type'], str)
|
||||
except:
|
||||
print(f"Assertion Error: {traceback.format_exc()}")
|
||||
self.errors.append(traceback.format_exc())
|
||||
|
||||
|
||||
# COMPLETION
|
||||
## Test OpenAI + sync
|
||||
def test_chat_openai_stream():
|
||||
try:
|
||||
|
@ -379,7 +387,7 @@ async def test_async_chat_azure_stream():
|
|||
continue
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler.errors: {customHandler.errors}")
|
||||
assert len(customHandler.errors) == 0
|
||||
litellm.callbacks = []
|
||||
|
@ -472,3 +480,101 @@ async def test_async_chat_bedrock_stream():
|
|||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_chat_bedrock_stream())
|
||||
|
||||
# EMBEDDING
|
||||
## Test OpenAI + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_openai():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"])
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="text-embedding-ada-002",
|
||||
input=["good morning from litellm"],
|
||||
api_key="my-bad-key")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_openai())
|
||||
|
||||
## Test Azure + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_azure():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"])
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="azure/azure-embedding-model",
|
||||
input=["good morning from litellm"],
|
||||
api_key="my-bad-key")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
# asyncio.run(test_async_embedding_azure())
|
||||
|
||||
## Test Bedrock + Async
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_embedding_bedrock():
|
||||
try:
|
||||
customHandler_success = CompletionCustomHandler()
|
||||
customHandler_failure = CompletionCustomHandler()
|
||||
litellm.callbacks = [customHandler_success]
|
||||
litellm.set_verbose = True
|
||||
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||
assert len(customHandler_success.errors) == 0
|
||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||
# test failure callback
|
||||
litellm.callbacks = [customHandler_failure]
|
||||
try:
|
||||
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
|
||||
input=["good morning from litellm"],
|
||||
aws_region_name="my-bad-region")
|
||||
except:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
|
||||
print(f"customHandler_failure.states: {customHandler_failure.states}")
|
||||
assert len(customHandler_failure.errors) == 0
|
||||
assert len(customHandler_failure.states) == 3 # pre, post, success
|
||||
except Exception as e:
|
||||
pytest.fail(f"An exception occurred: {str(e)}")
|
||||
|
||||
asyncio.run(test_async_embedding_bedrock())
|
|
@ -989,7 +989,7 @@ class Logging:
|
|||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
|
||||
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
|
||||
print_verbose(f"success callbacks: Running Custom Logger Class")
|
||||
if self.stream and complete_streaming_response is None:
|
||||
callback.log_stream_event(
|
||||
|
@ -1192,7 +1192,7 @@ class Logging:
|
|||
print_verbose=print_verbose,
|
||||
callback_func=callback
|
||||
)
|
||||
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class
|
||||
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
|
||||
callback.log_failure_event(
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
|
@ -1641,7 +1641,7 @@ def client(original_function):
|
|||
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
|
||||
litellm.cache.add_cache(result, *args, **kwargs)
|
||||
# LOG SUCCESS - handle streaming success logging in the _next_ object
|
||||
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
|
||||
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
|
||||
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
|
||||
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
|
||||
# RETURN RESULT
|
||||
|
@ -1678,7 +1678,7 @@ def client(original_function):
|
|||
end_time = datetime.datetime.now()
|
||||
if logging_obj:
|
||||
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
|
||||
asyncio.create_task(logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time))
|
||||
await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time)
|
||||
raise e
|
||||
|
||||
is_coroutine = inspect.iscoroutinefunction(original_function)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue