test(test_custom_callback_input.py): embedding callback tests for azure, openai, bedrock

Krrish Dholakia 2023-12-11 15:32:34 -08:00
parent 8ee77d7b82
commit ad39afc0ad
6 changed files with 185 additions and 49 deletions
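
For reference, a minimal sketch of the pattern these tests exercise (illustrative only, distilled from the test code in this commit; the model name and the one-second sleep mirror the tests and are not required values):

import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger

class MyHandler(CustomLogger):
    # async hooks fire for awaited calls such as litellm.aembedding(...)
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        print("embedding succeeded:", kwargs.get("input"))

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        print("embedding failed:", kwargs.get("original_response"))

async def main():
    litellm.callbacks = [MyHandler()]  # register the custom logger
    await litellm.aembedding(model="text-embedding-ada-002",
                             input=["good morning from litellm"])
    await asyncio.sleep(1)  # success callbacks run as background tasks; give them time to fire

asyncio.run(main())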


@@ -329,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
data: dict,
model_response: ModelResponse,
azure_client_params: dict,
api_key: str,
input: list,
client=None,
logging_obj=None
):
response = None
try:
@@ -338,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data)
return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=str(e),
)
raise e
def embedding(self,
@@ -383,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["api_key"] = api_key
elif azure_ad_token is not None:
azure_client_params["azure_ad_token"] = azure_ad_token
if aembedding == True:
response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
## LOGGING
logging_obj.pre_call(
input=input,
@@ -402,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
}
},
)
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:
azure_client = client
## COMPLETION CALL
response = azure_client.embeddings.create(**data) # type: ignore
## LOGGING


@@ -587,7 +587,7 @@ def _embedding_func_single(
input=input,
api_key="",
additional_args={"complete_input_dict": data},
original_response=response_body,
original_response=json.dumps(response_body),
)
if provider == "cohere":
response = response_body.get("embeddings")
@@ -651,13 +651,4 @@ def embedding(
)
model_response.usage = usage
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": {"model": model,
"texts": input}},
original_response=embeddings,
)
return model_response


@@ -326,6 +326,7 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=500, message=f"{str(e)}")
async def aembedding(
self,
input: list,
data: dict,
model_response: ModelResponse,
timeout: float,
@@ -333,6 +334,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str]=None,
client=None,
max_retries=None,
logging_obj=None
):
response = None
try:
@@ -341,9 +343,24 @@ class OpenAIChatCompletion(BaseLLM):
else:
openai_aclient = client
response = await openai_aclient.embeddings.create(**data) # type: ignore
return response
stringified_response = response.model_dump_json()
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
except Exception as e:
## LOGGING
logging_obj.post_call(
input=input,
api_key=api_key,
original_response=str(e),
)
raise e
def embedding(self,
model: str,
input: list,
@@ -368,13 +385,7 @@ class OpenAIChatCompletion(BaseLLM):
max_retries = data.pop("max_retries", 2)
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
if aembedding == True:
response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## LOGGING
logging_obj.pre_call(
input=input,
@@ -382,6 +393,14 @@ class OpenAIChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base},
)
if aembedding == True:
response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
if client is None:
openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
else:
openai_client = client
## COMPLETION CALL
response = openai_client.embeddings.create(**data) # type: ignore
## LOGGING


@@ -1823,7 +1823,7 @@ def embedding(
proxy_server_request = kwargs.get("proxy_server_request", None)
aembedding = kwargs.pop("aembedding", None)
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
optional_params = {}
@@ -1835,7 +1835,7 @@ def embedding(
try:
response = None
logging = litellm_logging_obj
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata})
logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding})
if azure == True or custom_llm_provider == "azure":
# azure configs
api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -1975,7 +1975,7 @@ def embedding(
## LOGGING
logging.post_call(
input=input,
api_key=openai.api_key,
api_key=api_key,
original_response=str(e),
)
## Map to OpenAI Exception


@@ -4,7 +4,7 @@ import sys, os, time, inspect, asyncio, traceback
from datetime import datetime
import pytest
sys.path.insert(0, os.path.abspath('../..'))
from typing import Optional
from typing import Optional, Literal, List
from litellm import completion, embedding
import litellm
from litellm.integrations.custom_logger import CustomLogger
@@ -32,16 +32,18 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
# Class variables or attributes
def __init__(self):
self.errors = []
self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
def log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("sync_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
assert isinstance(messages, list) and isinstance(messages[0], dict)
assert isinstance(messages, list)
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
@@ -53,9 +55,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
try:
print("IN POST CALL API")
print(f"kwargs input: {kwargs['input']}")
print(f"kwargs original response: {kwargs['original_response']}")
self.states.append("post_api_call")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
@@ -64,13 +64,13 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], Optional[dict])
@@ -81,6 +81,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_stream")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
@@ -106,6 +107,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
@@ -131,6 +133,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
def log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("sync_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
@@ -156,6 +159,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
async def async_log_pre_api_call(self, model, messages, kwargs):
try:
self.states.append("async_pre_api_call")
## MODEL
assert isinstance(model, str)
## MESSAGES
@@ -174,21 +178,22 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_success")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
assert isinstance(end_time, datetime)
## RESPONSE OBJECT
assert isinstance(response_obj, litellm.ModelResponse)
assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['input'], (list, dict, str))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
assert isinstance(kwargs['additional_args'], Optional[dict])
@@ -199,6 +204,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
try:
self.states.append("async_failure")
## START TIME
assert isinstance(start_time, datetime)
## END TIME
@@ -207,21 +213,23 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
assert response_obj == None
## KWARGS
assert isinstance(kwargs['model'], str)
assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
assert isinstance(kwargs['messages'], list)
assert isinstance(kwargs['optional_params'], dict)
assert isinstance(kwargs['litellm_params'], dict)
assert isinstance(kwargs['start_time'], Optional[datetime])
assert isinstance(kwargs['stream'], bool)
assert isinstance(kwargs['user'], Optional[str])
assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
assert isinstance(kwargs['input'], (list, str, dict))
assert isinstance(kwargs['api_key'], Optional[str])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
assert isinstance(kwargs['additional_args'], Optional[dict])
assert isinstance(kwargs['log_event_type'], str)
except:
print(f"Assertion Error: {traceback.format_exc()}")
self.errors.append(traceback.format_exc())
# COMPLETION
## Test OpenAI + sync
def test_chat_openai_stream():
try:
@@ -379,7 +387,7 @@ async def test_async_chat_azure_stream():
continue
except:
pass
time.sleep(1)
await asyncio.sleep(1)
print(f"customHandler.errors: {customHandler.errors}")
assert len(customHandler.errors) == 0
litellm.callbacks = []
@@ -472,3 +480,101 @@ async def test_async_chat_bedrock_stream():
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_chat_bedrock_stream())
# EMBEDDING
## Test OpenAI + Async
@pytest.mark.asyncio
async def test_async_embedding_openai():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"])
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="text-embedding-ada-002",
input=["good morning from litellm"],
api_key="my-bad-key")
except:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, success
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_openai())
## Test Azure + Async
@pytest.mark.asyncio
async def test_async_embedding_azure():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"])
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="azure/azure-embedding-model",
input=["good morning from litellm"],
api_key="my-bad-key")
except:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, success
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
# asyncio.run(test_async_embedding_azure())
## Test Bedrock + Async
@pytest.mark.asyncio
async def test_async_embedding_bedrock():
try:
customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler()
litellm.callbacks = [customHandler_success]
litellm.set_verbose = True
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
await asyncio.sleep(1)
print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}")
assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback
litellm.callbacks = [customHandler_failure]
try:
response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
input=["good morning from litellm"],
aws_region_name="my-bad-region")
except:
pass
await asyncio.sleep(1)
print(f"customHandler_failure.errors: {customHandler_failure.errors}")
print(f"customHandler_failure.states: {customHandler_failure.states}")
assert len(customHandler_failure.errors) == 0
assert len(customHandler_failure.states) == 3 # pre, post, success
except Exception as e:
pytest.fail(f"An exception occurred: {str(e)}")
asyncio.run(test_async_embedding_bedrock())


@@ -989,7 +989,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
print_verbose(f"success callbacks: Running Custom Logger Class")
if self.stream and complete_streaming_response is None:
callback.log_stream_event(
@@ -1192,7 +1192,7 @@ class Logging:
print_verbose=print_verbose,
callback_func=callback
)
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class
elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
callback.log_failure_event(
start_time=start_time,
end_time=end_time,
@@ -1641,7 +1641,7 @@ def client(original_function):
if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
litellm.cache.add_cache(result, *args, **kwargs)
# LOG SUCCESS - handle streaming success logging in the _next_ object
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
# RETURN RESULT
@@ -1678,7 +1678,7 @@ def client(original_function):
end_time = datetime.datetime.now()
if logging_obj:
logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
asyncio.create_task(logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time))
await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time)
raise e
is_coroutine = inspect.iscoroutinefunction(original_function)
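
In condensed form, the dispatch rule the two Logging hunks above implement (a restatement for clarity, not the actual utils.py code): a sync CustomLogger hook runs only when the call is neither an async completion nor an async embedding; async calls go through async_success_handler / async_failure_handler instead.

from litellm.integrations.custom_logger import CustomLogger

def _should_run_sync_custom_logger(callback, model_call_details):
    # Hypothetical helper name; mirrors the inline condition added in Logging above.
    litellm_params = model_call_details.get("litellm_params", {})
    is_async_call = litellm_params.get("acompletion", False) or litellm_params.get("aembedding", False)
    return isinstance(callback, CustomLogger) and not is_async_call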