diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index 2677d12c5..b014667df 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -329,7 +329,10 @@ class AzureChatCompletion(BaseLLM):
         data: dict,
         model_response: ModelResponse,
         azure_client_params: dict,
+        api_key: str,
+        input: list,
         client=None,
+        logging_obj=None
     ):
         response = None
         try:
@@ -338,8 +341,23 @@ class AzureChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.embeddings.create(**data)
-            return convert_to_model_response_object(response_object=json.loads(response.model_dump_json()), model_response_object=model_response, response_type="embedding")
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding")
         except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             raise e
 
     def embedding(self,
@@ -383,13 +401,7 @@ class AzureChatCompletion(BaseLLM):
             azure_client_params["api_key"] = api_key
         elif azure_ad_token is not None:
             azure_client_params["azure_ad_token"] = azure_ad_token
-        if aembedding == True:
-            response = self.aembedding(data=data, model_response=model_response, azure_client_params=azure_client_params)
-            return response
-        if client is None:
-            azure_client = AzureOpenAI(**azure_client_params) # type: ignore
-        else:
-            azure_client = client
+
         ## LOGGING
         logging_obj.pre_call(
             input=input,
@@ -402,6 +414,14 @@ class AzureChatCompletion(BaseLLM):
                 }
             },
         )
+
+        if aembedding == True:
+            response = self.aembedding(data=data, input=input, logging_obj=logging_obj, api_key=api_key, model_response=model_response, azure_client_params=azure_client_params)
+            return response
+        if client is None:
+            azure_client = AzureOpenAI(**azure_client_params) # type: ignore
+        else:
+            azure_client = client
         ## COMPLETION CALL
         response = azure_client.embeddings.create(**data) # type: ignore
         ## LOGGING
diff --git a/litellm/llms/bedrock.py b/litellm/llms/bedrock.py
index 5b3659f88..5207c4cca 100644
--- a/litellm/llms/bedrock.py
+++ b/litellm/llms/bedrock.py
@@ -587,7 +587,7 @@ def _embedding_func_single(
         input=input,
         api_key="",
         additional_args={"complete_input_dict": data},
-        original_response=response_body,
+        original_response=json.dumps(response_body),
     )
     if provider == "cohere":
         response = response_body.get("embeddings")
@@ -650,14 +650,5 @@ def embedding(
         total_tokens=input_tokens + 0
     )
     model_response.usage = usage
-
-    ## LOGGING
-    logging_obj.post_call(
-        input=input,
-        api_key=api_key,
-        additional_args={"complete_input_dict": {"model": model,
-                                                 "texts": input}},
-        original_response=embeddings,
-    )
 
     return model_response
diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index 9c16708e2..33d3504bb 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -326,6 +326,7 @@ class OpenAIChatCompletion(BaseLLM):
             raise OpenAIError(status_code=500, message=f"{str(e)}")
     async def aembedding(
         self,
+        input: list,
         data: dict,
         model_response: ModelResponse,
         timeout: float,
@@ -333,6 +334,7 @@ class OpenAIChatCompletion(BaseLLM):
         api_base: Optional[str]=None,
         client=None,
         max_retries=None,
+        logging_obj=None
     ):
         response = None
         try:
@@ -341,9 +343,24 @@ class OpenAIChatCompletion(BaseLLM):
             else:
                 openai_aclient = client
             response = await openai_aclient.embeddings.create(**data) # type: ignore
-            return response
+            stringified_response = response.model_dump_json()
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=stringified_response,
+            )
+            return convert_to_model_response_object(response_object=json.loads(stringified_response), model_response_object=model_response, response_type="embedding") # type: ignore
         except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=input,
+                api_key=api_key,
+                original_response=str(e),
+            )
             raise e
+
     def embedding(self,
                 model: str,
                 input: list,
@@ -368,13 +385,7 @@ class OpenAIChatCompletion(BaseLLM):
             max_retries = data.pop("max_retries", 2)
             if not isinstance(max_retries, int):
                 raise OpenAIError(status_code=422, message="max retries must be an int")
-            if aembedding == True:
-                response = self.aembedding(data=data, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
-                return response
-            if client is None:
-                openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
-            else:
-                openai_client = client
+
             ## LOGGING
             logging_obj.pre_call(
                 input=input,
@@ -382,6 +393,14 @@ class OpenAIChatCompletion(BaseLLM):
                 additional_args={"complete_input_dict": data, "api_base": api_base},
             )
 
+            if aembedding == True:
+                response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
+                return response
+            if client is None:
+                openai_client = OpenAI(api_key=api_key, base_url=api_base, http_client=litellm.client_session, timeout=timeout, max_retries=max_retries)
+            else:
+                openai_client = client
+
             ## COMPLETION CALL
             response = openai_client.embeddings.create(**data) # type: ignore
             ## LOGGING
diff --git a/litellm/main.py b/litellm/main.py
index 3b50c075b..5ecfcf1db 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1823,7 +1823,7 @@ def embedding(
     proxy_server_request = kwargs.get("proxy_server_request", None)
     aembedding = kwargs.pop("aembedding", None)
     openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id", "organization", "base_url", "default_headers", "timeout", "response_format", "seed", "tools", "tool_choice", "max_retries", "encoding_format"]
-    litellm_params = ["metadata", "acompletion", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
+    litellm_params = ["metadata", "aembedding", "caching", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response", "self", "client", "rpm", "tpm", "input_cost_per_token", "output_cost_per_token", "hf_model_name", "proxy_server_request", "model_info", "preset_cache_key"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
     optional_params = {}
@@ -1835,7 +1835,7 @@ def embedding(
     try:
         response = None
         logging = litellm_logging_obj
-        logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata})
+        logging.update_environment_variables(model=model, user="", optional_params=optional_params, litellm_params={"timeout": timeout, "azure": azure, "litellm_call_id": litellm_call_id, "logger_fn": logger_fn, "proxy_server_request": proxy_server_request, "model_info": model_info, "metadata": metadata, "aembedding": aembedding})
         if azure == True or custom_llm_provider == "azure":
             # azure configs
             api_type = get_secret("AZURE_API_TYPE") or "azure"
@@ -1975,7 +1975,7 @@ def embedding(
         ## LOGGING
         logging.post_call(
             input=input,
-            api_key=openai.api_key,
+            api_key=api_key,
             original_response=str(e),
         )
         ## Map to OpenAI Exception
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index 0102f6eb7..d9bcc8947 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -4,7 +4,7 @@ import sys, os, time, inspect, asyncio, traceback
 from datetime import datetime
 import pytest
 sys.path.insert(0, os.path.abspath('../..'))
-from typing import Optional
+from typing import Optional, Literal, List
 from litellm import completion, embedding
 import litellm
 from litellm.integrations.custom_logger import CustomLogger
@@ -32,16 +32,18 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
     # Class variables or attributes
     def __init__(self):
         self.errors = []
+        self.states: Optional[List[Literal["sync_pre_api_call", "async_pre_api_call", "post_api_call", "sync_stream", "async_stream", "sync_success", "async_success", "sync_failure", "async_failure"]]] = []
 
     def log_pre_api_call(self, model, messages, kwargs):
         try:
+            self.states.append("sync_pre_api_call")
             ## MODEL
             assert isinstance(model, str)
             ## MESSAGES
-            assert isinstance(messages, list) and isinstance(messages[0], dict)
+            assert isinstance(messages, list)
             ## KWARGS
             assert isinstance(kwargs['model'], str)
-            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['messages'], list)
             assert isinstance(kwargs['optional_params'], dict)
             assert isinstance(kwargs['litellm_params'], dict)
             assert isinstance(kwargs['start_time'], Optional[datetime])
@@ -53,9 +55,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
 
     def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
         try:
-            print("IN POST CALL API")
-            print(f"kwargs input: {kwargs['input']}")
-            print(f"kwargs original response: {kwargs['original_response']}")
+            self.states.append("post_api_call")
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
@@ -64,13 +64,13 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
             assert response_obj == None
             ## KWARGS
             assert isinstance(kwargs['model'], str)
-            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['messages'], list)
             assert isinstance(kwargs['optional_params'], dict)
             assert isinstance(kwargs['litellm_params'], dict)
             assert isinstance(kwargs['start_time'], Optional[datetime])
             assert isinstance(kwargs['stream'], bool)
             assert isinstance(kwargs['user'], Optional[str])
-            assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
+            assert isinstance(kwargs['input'], (list, dict, str))
             assert isinstance(kwargs['api_key'], Optional[str])
             assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.iscoroutine(kwargs['original_response']) or inspect.isasyncgen(kwargs['original_response'])
             assert isinstance(kwargs['additional_args'], Optional[dict])
@@ -81,6 +81,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
 
     async def async_log_stream_event(self, kwargs, response_obj, start_time, end_time):
         try:
+            self.states.append("async_stream")
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
@@ -106,6 +107,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
 
     def log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
+            self.states.append("sync_success")
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
@@ -131,6 +133,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
 
     def log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
+            self.states.append("sync_failure")
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
@@ -156,6 +159,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
 
     async def async_log_pre_api_call(self, model, messages, kwargs):
         try:
+            self.states.append("async_pre_api_call")
             ## MODEL
             assert isinstance(model, str)
             ## MESSAGES
@@ -174,21 +178,22 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
 
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         try:
+            self.states.append("async_success")
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
             assert isinstance(end_time, datetime)
             ## RESPONSE OBJECT
-            assert isinstance(response_obj, litellm.ModelResponse)
+            assert isinstance(response_obj, (litellm.ModelResponse, litellm.EmbeddingResponse))
             ## KWARGS
             assert isinstance(kwargs['model'], str)
-            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['messages'], list)
             assert isinstance(kwargs['optional_params'], dict)
             assert isinstance(kwargs['litellm_params'], dict)
             assert isinstance(kwargs['start_time'], Optional[datetime])
             assert isinstance(kwargs['stream'], bool)
             assert isinstance(kwargs['user'], Optional[str])
-            assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
+            assert isinstance(kwargs['input'], (list, dict, str))
             assert isinstance(kwargs['api_key'], Optional[str])
             assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or inspect.iscoroutine(kwargs['original_response'])
             assert isinstance(kwargs['additional_args'], Optional[dict])
@@ -199,6 +204,7 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
 
     async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
         try:
+            self.states.append("async_failure")
             ## START TIME
             assert isinstance(start_time, datetime)
             ## END TIME
@@ -207,21 +213,23 @@ class CompletionCustomHandler(CustomLogger): # https://docs.litellm.ai/docs/obse
             assert response_obj == None
             ## KWARGS
             assert isinstance(kwargs['model'], str)
-            assert isinstance(kwargs['messages'], list) and isinstance(kwargs['messages'][0], dict)
+            assert isinstance(kwargs['messages'], list)
             assert isinstance(kwargs['optional_params'], dict)
             assert isinstance(kwargs['litellm_params'], dict)
             assert isinstance(kwargs['start_time'], Optional[datetime])
             assert isinstance(kwargs['stream'], bool)
             assert isinstance(kwargs['user'], Optional[str])
-            assert (isinstance(kwargs['input'], list) and isinstance(kwargs['input'][0], dict)) or isinstance(kwargs['input'], (dict, str))
+            assert isinstance(kwargs['input'], (list, str, dict))
             assert isinstance(kwargs['api_key'], Optional[str])
-            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response'])
+            assert isinstance(kwargs['original_response'], (str, litellm.CustomStreamWrapper)) or inspect.isasyncgen(kwargs['original_response']) or kwargs['original_response'] == None
             assert isinstance(kwargs['additional_args'], Optional[dict])
             assert isinstance(kwargs['log_event_type'], str)
         except:
             print(f"Assertion Error: {traceback.format_exc()}")
             self.errors.append(traceback.format_exc())
 
+
+# COMPLETION
 ## Test OpenAI + sync
 def test_chat_openai_stream():
     try:
@@ -379,7 +387,7 @@ async def test_async_chat_azure_stream():
                 continue
         except:
             pass
-        time.sleep(1)
+        await asyncio.sleep(1)
         print(f"customHandler.errors: {customHandler.errors}")
         assert len(customHandler.errors) == 0
         litellm.callbacks = []
@@ -471,4 +479,102 @@ async def test_async_chat_bedrock_stream():
     except Exception as e:
         pytest.fail(f"An exception occurred: {str(e)}")
 
-# asyncio.run(test_async_chat_bedrock_stream())
\ No newline at end of file
+# asyncio.run(test_async_chat_bedrock_stream())
+
+# EMBEDDING
+## Test OpenAI + Async
+@pytest.mark.asyncio
+async def test_async_embedding_openai():
+    try:
+        customHandler_success = CompletionCustomHandler()
+        customHandler_failure = CompletionCustomHandler()
+        litellm.callbacks = [customHandler_success]
+        response = await litellm.aembedding(model="azure/azure-embedding-model",
+                                            input=["good morning from litellm"])
+        await asyncio.sleep(1)
+        print(f"customHandler_success.errors: {customHandler_success.errors}")
+        print(f"customHandler_success.states: {customHandler_success.states}")
+        assert len(customHandler_success.errors) == 0
+        assert len(customHandler_success.states) == 3 # pre, post, success
+        # test failure callback
+        litellm.callbacks = [customHandler_failure]
+        try:
+            response = await litellm.aembedding(model="text-embedding-ada-002",
+                                                input=["good morning from litellm"],
+                                                api_key="my-bad-key")
+        except:
+            pass
+        await asyncio.sleep(1)
+        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+        print(f"customHandler_failure.states: {customHandler_failure.states}")
+        assert len(customHandler_failure.errors) == 0
+        assert len(customHandler_failure.states) == 3 # pre, post, failure
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# asyncio.run(test_async_embedding_openai())
+
+## Test Azure + Async
+@pytest.mark.asyncio
+async def test_async_embedding_azure():
+    try:
+        customHandler_success = CompletionCustomHandler()
+        customHandler_failure = CompletionCustomHandler()
+        litellm.callbacks = [customHandler_success]
+        response = await litellm.aembedding(model="azure/azure-embedding-model",
+                                            input=["good morning from litellm"])
+        await asyncio.sleep(1)
+        print(f"customHandler_success.errors: {customHandler_success.errors}")
+        print(f"customHandler_success.states: {customHandler_success.states}")
+        assert len(customHandler_success.errors) == 0
+        assert len(customHandler_success.states) == 3 # pre, post, success
+        # test failure callback
+        litellm.callbacks = [customHandler_failure]
+        try:
+            response = await litellm.aembedding(model="azure/azure-embedding-model",
+                                                input=["good morning from litellm"],
+                                                api_key="my-bad-key")
+        except:
+            pass
+        await asyncio.sleep(1)
+        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+        print(f"customHandler_failure.states: {customHandler_failure.states}")
+        assert len(customHandler_failure.errors) == 0
+        assert len(customHandler_failure.states) == 3 # pre, post, failure
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+# asyncio.run(test_async_embedding_azure())
+
+## Test Bedrock + Async
+@pytest.mark.asyncio
+async def test_async_embedding_bedrock():
+    try:
+        customHandler_success = CompletionCustomHandler()
+        customHandler_failure = CompletionCustomHandler()
+        litellm.callbacks = [customHandler_success]
+        litellm.set_verbose = True
+        response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
+                                            input=["good morning from litellm"], aws_region_name="os.environ/AWS_REGION_NAME_2")
+        await asyncio.sleep(1)
+        print(f"customHandler_success.errors: {customHandler_success.errors}")
+        print(f"customHandler_success.states: {customHandler_success.states}")
+        assert len(customHandler_success.errors) == 0
+        assert len(customHandler_success.states) == 3 # pre, post, success
+        # test failure callback
+        litellm.callbacks = [customHandler_failure]
+        try:
+            response = await litellm.aembedding(model="bedrock/cohere.embed-multilingual-v3",
+                                                input=["good morning from litellm"],
+                                                aws_region_name="my-bad-region")
+        except:
+            pass
+        await asyncio.sleep(1)
+        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+        print(f"customHandler_failure.states: {customHandler_failure.states}")
+        assert len(customHandler_failure.errors) == 0
+        assert len(customHandler_failure.states) == 3 # pre, post, failure
+    except Exception as e:
+        pytest.fail(f"An exception occurred: {str(e)}")
+
+asyncio.run(test_async_embedding_bedrock())
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 1e61d7989..efc146413 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -989,7 +989,7 @@ class Logging:
                             end_time=end_time,
                             print_verbose=print_verbose,
                         )
-                if isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class - only call for sync callbacks
+                elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
                     print_verbose(f"success callbacks: Running Custom Logger Class")
                     if self.stream and complete_streaming_response is None:
                         callback.log_stream_event(
@@ -1192,7 +1192,7 @@ class Logging:
                         print_verbose=print_verbose,
                         callback_func=callback
                     )
-                elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False: # custom logger class
+                elif isinstance(callback, CustomLogger) and self.model_call_details.get("litellm_params", {}).get("acompletion", False) == False and self.model_call_details.get("litellm_params", {}).get("aembedding", False) == False: # custom logger class
                     callback.log_failure_event(
                         start_time=start_time,
                         end_time=end_time,
@@ -1641,7 +1641,7 @@ def client(original_function):
                 if litellm.caching or litellm.caching_with_models or litellm.cache != None: # user init a cache object
                     litellm.cache.add_cache(result, *args, **kwargs)
             # LOG SUCCESS - handle streaming success logging in the _next_ object
-            print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler")
+            print_verbose(f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}")
             asyncio.create_task(logging_obj.async_success_handler(result, start_time, end_time))
             threading.Thread(target=logging_obj.success_handler, args=(result, start_time, end_time)).start()
             # RETURN RESULT
@@ -1678,7 +1678,7 @@ def client(original_function):
             end_time = datetime.datetime.now()
             if logging_obj:
                 logging_obj.failure_handler(e, traceback_exception, start_time, end_time) # DO NOT MAKE THREADED - router retry fallback relies on this!
-                asyncio.create_task(logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time))
+                await logging_obj.async_failure_handler(e, traceback_exception, start_time, end_time)
             raise e
         is_coroutine = inspect.iscoroutinefunction(original_function)
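
The tests in this patch exercise the new callback flow through pytest. As a minimal standalone sketch (not part of the diff, and under the assumption that a valid OPENAI_API_KEY is set and the model name below is reachable), the same pre -> post -> success sequence introduced here can be observed with a bare-bones CustomLogger subclass:

# Minimal sketch of the async embedding callback flow added in this patch.
# Assumes OPENAI_API_KEY is set; the model name and state list are illustrative.
import asyncio
import litellm
from litellm.integrations.custom_logger import CustomLogger

class StateTracker(CustomLogger):
    def __init__(self):
        self.states = []

    async def async_log_pre_api_call(self, model, messages, kwargs):
        self.states.append("async_pre_api_call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        self.states.append("post_api_call")

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_success")

    async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("async_failure")

async def main():
    tracker = StateTracker()
    litellm.callbacks = [tracker]
    await litellm.aembedding(model="text-embedding-ada-002", input=["hello from litellm"])
    await asyncio.sleep(1)  # the success callback runs as a background task
    print(tracker.states)   # expected: ['async_pre_api_call', 'post_api_call', 'async_success']

asyncio.run(main())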