diff --git a/litellm/integrations/langfuse.py b/litellm/integrations/langfuse.py
index c60f33d03..f7f2634de 100644
--- a/litellm/integrations/langfuse.py
+++ b/litellm/integrations/langfuse.py
@@ -110,9 +110,16 @@ class LangFuseLogger:
             ):
                 input = prompt
                 output = response_obj["data"]
-            elif response_obj is not None:
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ModelResponse
+            ):
                 input = prompt
                 output = response_obj["choices"][0]["message"].json()
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ImageResponse
+            ):
+                input = prompt
+                output = response_obj["data"]
             print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
             if self._is_langfuse_v2():
                 self._log_langfuse_v2(
diff --git a/litellm/llms/azure.py b/litellm/llms/azure.py
index f20a2e939..01b54987b 100644
--- a/litellm/llms/azure.py
+++ b/litellm/llms/azure.py
@@ -741,7 +741,7 @@ class AzureChatCompletion(BaseLLM):
             response = azure_client.images.generate(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
-                input=input,
+                input=prompt,
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
                 original_response=response,
diff --git a/litellm/main.py b/litellm/main.py
index 2539039cd..e4dd684c8 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -3197,6 +3197,7 @@ def image_generation(
             "preset_cache_key": None,
             "stream_response": {},
         },
+        custom_llm_provider=custom_llm_provider,
     )

     if custom_llm_provider == "azure":
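
Note on the langfuse.py hunk above: image generation responses keep their
payload under "data" and have no "choices" key, so the logger now dispatches
on the response type instead of assuming a chat-style ModelResponse. The
main.py hunk threads custom_llm_provider into the logging object so the cost
tracking changes in utils.py (below) can special-case Azure image generation.
A minimal standalone sketch of the dispatch pattern; the function name and
arguments are illustrative stand-ins for the logger's locals, not litellm API:

    import litellm

    def pick_langfuse_output(prompt, response_obj):
        # chat/completion responses: log the first choice's message
        if response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
            return prompt, response_obj["choices"][0]["message"].json()
        # image generation responses: no "choices" key, log the image data list
        if response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
            return prompt, response_obj["data"]
        return prompt, None
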
diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py
index 080754ca8..579fe6583 100644
--- a/litellm/tests/test_custom_callback_input.py
+++ b/litellm/tests/test_custom_callback_input.py
@@ -3,6 +3,7 @@ import sys, os, time, inspect, asyncio, traceback
 from datetime import datetime
 import pytest
+from pydantic import BaseModel

 sys.path.insert(0, os.path.abspath("../.."))
 from typing import Optional, Literal, List, Union
@@ -94,7 +95,8 @@ class CompletionCustomHandler(
             assert isinstance(kwargs["api_key"], (str, type(None)))
             assert (
                 isinstance(
-                    kwargs["original_response"], (str, litellm.CustomStreamWrapper)
+                    kwargs["original_response"],
+                    (str, litellm.CustomStreamWrapper, BaseModel),
                 )
                 or inspect.iscoroutine(kwargs["original_response"])
                 or inspect.isasyncgen(kwargs["original_response"])
@@ -174,7 +176,8 @@
             ) or isinstance(kwargs["input"], (dict, str))
             assert isinstance(kwargs["api_key"], (str, type(None)))
             assert isinstance(
-                kwargs["original_response"], (str, litellm.CustomStreamWrapper)
+                kwargs["original_response"],
+                (str, litellm.CustomStreamWrapper, BaseModel),
             )
             assert isinstance(kwargs["additional_args"], (dict, type(None)))
             assert isinstance(kwargs["log_event_type"], str)
@@ -471,7 +474,7 @@ async def test_async_chat_azure_stream():
         pytest.fail(f"An exception occurred: {str(e)}")


-asyncio.run(test_async_chat_azure_stream())
+# asyncio.run(test_async_chat_azure_stream())


 ## Test Bedrock + sync
@@ -556,6 +559,7 @@ async def test_async_chat_bedrock_stream():

 # asyncio.run(test_async_chat_bedrock_stream())

+
 ## Test Sagemaker + Async
 @pytest.mark.asyncio
 async def test_async_chat_sagemaker_stream():
@@ -769,14 +773,18 @@ async def test_async_completion_azure_caching():
     unique_time = time.time()
     response1 = await litellm.acompletion(
         model="azure/chatgpt-v-2",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)
     print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
     response2 = await litellm.acompletion(
         model="azure/chatgpt-v-2",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
@@ -825,21 +833,25 @@ def test_image_generation_openai():
     try:
         customHandler_success = CompletionCustomHandler()
         customHandler_failure = CompletionCustomHandler()
-        # litellm.callbacks = [customHandler_success]
+        litellm.callbacks = [customHandler_success]

-        # litellm.set_verbose = True
+        litellm.set_verbose = True

-        # response = litellm.image_generation(
-        #     prompt="A cute baby sea otter", model="dall-e-3"
-        # )
+        response = litellm.image_generation(
+            prompt="A cute baby sea otter",
+            model="azure/",
+            api_base=os.getenv("AZURE_API_BASE"),
+            api_key=os.getenv("AZURE_API_KEY"),
+            api_version="2023-06-01-preview",
+        )

-        # print(f"response: {response}")
-        # assert len(response.data) > 0
+        print(f"response: {response}")
+        assert len(response.data) > 0

-        # print(f"customHandler_success.errors: {customHandler_success.errors}")
-        # print(f"customHandler_success.states: {customHandler_success.states}")
-        # assert len(customHandler_success.errors) == 0
-        # assert len(customHandler_success.states) == 3  # pre, post, success
+        print(f"customHandler_success.errors: {customHandler_success.errors}")
+        print(f"customHandler_success.states: {customHandler_success.states}")
+        assert len(customHandler_success.errors) == 0
+        assert len(customHandler_success.states) == 3  # pre, post, success
         # test failure callback
         litellm.callbacks = [customHandler_failure]
         try:
@@ -862,7 +874,7 @@ def test_image_generation_openai():
         pytest.fail(f"An exception occurred - {str(e)}")


-test_image_generation_openai()
+# test_image_generation_openai()

 ## Test OpenAI + Async
 ## Test Azure + Sync
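
Note on the test changes above: for image generation calls, the
original_response handed to callbacks is a pydantic response object rather
than a raw string or a litellm.CustomStreamWrapper, which is why BaseModel
joins the accepted types. A toy version of the widened check; it assumes,
as the test does, that litellm's image responses subclass pydantic's
BaseModel, and additionally that ImageResponse() is constructible with
default arguments:

    from pydantic import BaseModel

    import litellm

    def check_original_response(original_response):
        # BaseModel must be in the accepted types, or image-generation
        # callbacks fail this assert with their pydantic response objects
        assert isinstance(
            original_response, (str, litellm.CustomStreamWrapper, BaseModel)
        ), f"unexpected type: {type(original_response)}"

    check_original_response("raw provider payload")  # plain string: ok
    check_original_response(litellm.ImageResponse())  # pydantic model: ok
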
""" try: + + if ( + (call_type == "aimage_generation" or call_type == "image_generation") + and model is not None + and isinstance(model, str) + and len(model) == 0 + and custom_llm_provider == "azure" + ): + model = "dall-e-2" # for dall-e-2, azure expects an empty model name # Handle Inputs to completion_cost prompt_tokens = 0 completion_tokens = 0 @@ -3565,12 +3577,15 @@ def completion_cost( or call_type == CallTypes.aimage_generation.value ): ### IMAGE GENERATION COST CALCULATION ### + # fix size to match naming convention + if "x" in size and "-x-" not in size: + size = size.replace("x", "-x-") image_gen_model_name = f"{size}/{model}" image_gen_model_name_with_quality = image_gen_model_name if quality is not None: image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}" size = size.split("-x-") - height = int(size[0]) + height = int(size[0]) # if it's 1024-x-1024 vs. 1024x1024 width = int(size[1]) verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}") verbose_logger.debug( @@ -5968,73 +5983,6 @@ def convert_to_model_response_object( raise Exception(f"Invalid response object {e}") -# NOTE: DEPRECATING this in favor of using success_handler() in Logging: -def handle_success(args, kwargs, result, start_time, end_time): - global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger - try: - model = args[0] if len(args) > 0 else kwargs["model"] - input = ( - args[1] - if len(args) > 1 - else kwargs.get("messages", kwargs.get("input", None)) - ) - success_handler = additional_details.pop("success_handler", None) - failure_handler = additional_details.pop("failure_handler", None) - additional_details["Event_Name"] = additional_details.pop( - "successful_event_name", "litellm.succes_query" - ) - for callback in litellm.success_callback: - try: - if callback == "posthog": - ph_obj = {} - for detail in additional_details: - ph_obj[detail] = additional_details[detail] - event_name = additional_details["Event_Name"] - if "user_id" in additional_details: - posthog.capture( - additional_details["user_id"], event_name, ph_obj - ) - else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python - unique_id = str(uuid.uuid4()) - posthog.capture(unique_id, event_name, ph_obj) - pass - elif callback == "slack": - slack_msg = "" - for detail in additional_details: - slack_msg += f"{detail}: {additional_details[detail]}\n" - slack_app.client.chat_postMessage( - channel=alerts_channel, text=slack_msg - ) - elif callback == "aispend": - print_verbose("reaches aispend for logging!") - model = args[0] if len(args) > 0 else kwargs["model"] - aispendLogger.log_event( - model=model, - response_obj=result, - start_time=start_time, - end_time=end_time, - print_verbose=print_verbose, - ) - except Exception as e: - # LOGGING - exception_logging(logger_fn=user_logger_fn, exception=e) - print_verbose( - f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}" - ) - pass - - if success_handler and callable(success_handler): - success_handler(args, kwargs) - pass - except Exception as e: - # LOGGING - exception_logging(logger_fn=user_logger_fn, exception=e) - print_verbose( - f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}" - ) - pass - - def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call return litellm.acompletion(*args, **kwargs)