Merge pull request #2025 from BerriAI/litellm_langfuse_image_gen_logging

fix(utils.py): support image gen logging to langfuse
This commit is contained in:
Krish Dholakia 2024-02-16 17:08:07 -08:00 committed by GitHub
commit 5e7dda4f88
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 60 additions and 92 deletions

View file

@@ -110,9 +110,16 @@ class LangFuseLogger:
): ):
input = prompt input = prompt
output = response_obj["data"] output = response_obj["data"]
elif response_obj is not None: elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt input = prompt
output = response_obj["choices"][0]["message"].json() output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj["data"]
print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}") print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
if self._is_langfuse_v2(): if self._is_langfuse_v2():
self._log_langfuse_v2( self._log_langfuse_v2(

View file

@@ -741,7 +741,7 @@ class AzureChatCompletion(BaseLLM):
response = azure_client.images.generate(**data, timeout=timeout) # type: ignore response = azure_client.images.generate(**data, timeout=timeout) # type: ignore
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=input, input=prompt,
api_key=api_key, api_key=api_key,
additional_args={"complete_input_dict": data}, additional_args={"complete_input_dict": data},
original_response=response, original_response=response,

View file

@@ -3197,6 +3197,7 @@ def image_generation(
"preset_cache_key": None, "preset_cache_key": None,
"stream_response": {}, "stream_response": {},
}, },
custom_llm_provider=custom_llm_provider,
) )
if custom_llm_provider == "azure": if custom_llm_provider == "azure":

View file

@@ -3,6 +3,7 @@
import sys, os, time, inspect, asyncio, traceback import sys, os, time, inspect, asyncio, traceback
from datetime import datetime from datetime import datetime
import pytest import pytest
from pydantic import BaseModel
sys.path.insert(0, os.path.abspath("../..")) sys.path.insert(0, os.path.abspath("../.."))
from typing import Optional, Literal, List, Union from typing import Optional, Literal, List, Union
@@ -94,7 +95,8 @@ class CompletionCustomHandler(
assert isinstance(kwargs["api_key"], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert ( assert (
isinstance( isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper) kwargs["original_response"],
(str, litellm.CustomStreamWrapper, BaseModel),
) )
or inspect.iscoroutine(kwargs["original_response"]) or inspect.iscoroutine(kwargs["original_response"])
or inspect.isasyncgen(kwargs["original_response"]) or inspect.isasyncgen(kwargs["original_response"])
@@ -174,7 +176,8 @@ class CompletionCustomHandler(
) or isinstance(kwargs["input"], (dict, str)) ) or isinstance(kwargs["input"], (dict, str))
assert isinstance(kwargs["api_key"], (str, type(None))) assert isinstance(kwargs["api_key"], (str, type(None)))
assert isinstance( assert isinstance(
kwargs["original_response"], (str, litellm.CustomStreamWrapper) kwargs["original_response"],
(str, litellm.CustomStreamWrapper, BaseModel),
) )
assert isinstance(kwargs["additional_args"], (dict, type(None))) assert isinstance(kwargs["additional_args"], (dict, type(None)))
assert isinstance(kwargs["log_event_type"], str) assert isinstance(kwargs["log_event_type"], str)
@@ -471,7 +474,7 @@ async def test_async_chat_azure_stream():
pytest.fail(f"An exception occurred: {str(e)}") pytest.fail(f"An exception occurred: {str(e)}")
asyncio.run(test_async_chat_azure_stream()) # asyncio.run(test_async_chat_azure_stream())
## Test Bedrock + sync ## Test Bedrock + sync
@@ -556,6 +559,7 @@ async def test_async_chat_bedrock_stream():
# asyncio.run(test_async_chat_bedrock_stream()) # asyncio.run(test_async_chat_bedrock_stream())
## Test Sagemaker + Async ## Test Sagemaker + Async
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_async_chat_sagemaker_stream(): async def test_async_chat_sagemaker_stream():
@@ -769,14 +773,18 @@ async def test_async_completion_azure_caching():
unique_time = time.time() unique_time = time.time()
response1 = await litellm.acompletion( response1 = await litellm.acompletion(
model="azure/chatgpt-v-2", model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True, caching=True,
) )
await asyncio.sleep(1) await asyncio.sleep(1)
print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}") print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
response2 = await litellm.acompletion( response2 = await litellm.acompletion(
model="azure/chatgpt-v-2", model="azure/chatgpt-v-2",
messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}], messages=[
{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
],
caching=True, caching=True,
) )
await asyncio.sleep(1) # success callbacks are done in parallel await asyncio.sleep(1) # success callbacks are done in parallel
@@ -825,21 +833,25 @@ def test_image_generation_openai():
try: try:
customHandler_success = CompletionCustomHandler() customHandler_success = CompletionCustomHandler()
customHandler_failure = CompletionCustomHandler() customHandler_failure = CompletionCustomHandler()
# litellm.callbacks = [customHandler_success] litellm.callbacks = [customHandler_success]
# litellm.set_verbose = True litellm.set_verbose = True
# response = litellm.image_generation( response = litellm.image_generation(
# prompt="A cute baby sea otter", model="dall-e-3" prompt="A cute baby sea otter",
# ) model="azure/",
api_base=os.getenv("AZURE_API_BASE"),
api_key=os.getenv("AZURE_API_KEY"),
api_version="2023-06-01-preview",
)
# print(f"response: {response}") print(f"response: {response}")
# assert len(response.data) > 0 assert len(response.data) > 0
# print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.errors: {customHandler_success.errors}")
# print(f"customHandler_success.states: {customHandler_success.states}") print(f"customHandler_success.states: {customHandler_success.states}")
# assert len(customHandler_success.errors) == 0 assert len(customHandler_success.errors) == 0
# assert len(customHandler_success.states) == 3 # pre, post, success assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback # test failure callback
litellm.callbacks = [customHandler_failure] litellm.callbacks = [customHandler_failure]
try: try:
@@ -862,7 +874,7 @@ def test_image_generation_openai():
pytest.fail(f"An exception occurred - {str(e)}") pytest.fail(f"An exception occurred - {str(e)}")
test_image_generation_openai() # test_image_generation_openai()
## Test OpenAI + Async ## Test OpenAI + Async
## Test Azure + Sync ## Test Azure + Sync

View file

@@ -44,9 +44,9 @@ except:
filename = str( filename = str(
resources.files(litellm).joinpath("llms/tokenizers") # for python 3.10 resources.files(litellm).joinpath("llms/tokenizers") # for python 3.10
) # for python 3.10+ ) # for python 3.10+
os.environ[ os.environ["TIKTOKEN_CACHE_DIR"] = (
"TIKTOKEN_CACHE_DIR" filename # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
] = filename # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071 )
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")
import importlib.metadata import importlib.metadata
@@ -1110,6 +1110,9 @@ class Logging:
completion_response=result, completion_response=result,
model=self.model, model=self.model,
call_type=self.call_type, call_type=self.call_type,
custom_llm_provider=self.model_call_details.get(
"custom_llm_provider", None
), # set for img gen models
) )
) )
else: else:
@@ -1789,14 +1792,14 @@ class Logging:
input = self.model_call_details["input"] input = self.model_call_details["input"]
type = ( _type = (
"embed" "embed"
if self.call_type == CallTypes.embedding.value if self.call_type == CallTypes.embedding.value
else "llm" else "llm"
) )
llmonitorLogger.log_event( llmonitorLogger.log_event(
type=type, type=_type,
event="error", event="error",
user_id=self.model_call_details.get("user", "default"), user_id=self.model_call_details.get("user", "default"),
model=model, model=model,
@@ -3512,6 +3515,15 @@ def completion_cost(
- If an error occurs during execution, the function returns 0.0 without blocking the user's execution path. - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
""" """
try: try:
if (
(call_type == "aimage_generation" or call_type == "image_generation")
and model is not None
and isinstance(model, str)
and len(model) == 0
and custom_llm_provider == "azure"
):
model = "dall-e-2" # for dall-e-2, azure expects an empty model name
# Handle Inputs to completion_cost # Handle Inputs to completion_cost
prompt_tokens = 0 prompt_tokens = 0
completion_tokens = 0 completion_tokens = 0
@@ -3565,12 +3577,15 @@ def completion_cost(
or call_type == CallTypes.aimage_generation.value or call_type == CallTypes.aimage_generation.value
): ):
### IMAGE GENERATION COST CALCULATION ### ### IMAGE GENERATION COST CALCULATION ###
# fix size to match naming convention
if "x" in size and "-x-" not in size:
size = size.replace("x", "-x-")
image_gen_model_name = f"{size}/{model}" image_gen_model_name = f"{size}/{model}"
image_gen_model_name_with_quality = image_gen_model_name image_gen_model_name_with_quality = image_gen_model_name
if quality is not None: if quality is not None:
image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}" image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"
size = size.split("-x-") size = size.split("-x-")
height = int(size[0]) height = int(size[0]) # if it's 1024-x-1024 vs. 1024x1024
width = int(size[1]) width = int(size[1])
verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}") verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}")
verbose_logger.debug( verbose_logger.debug(
@@ -5968,73 +5983,6 @@ def convert_to_model_response_object(
raise Exception(f"Invalid response object {e}") raise Exception(f"Invalid response object {e}")
# NOTE: DEPRECATING this in favor of using success_handler() in Logging:
def handle_success(args, kwargs, result, start_time, end_time):
global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
try:
model = args[0] if len(args) > 0 else kwargs["model"]
input = (
args[1]
if len(args) > 1
else kwargs.get("messages", kwargs.get("input", None))
)
success_handler = additional_details.pop("success_handler", None)
failure_handler = additional_details.pop("failure_handler", None)
additional_details["Event_Name"] = additional_details.pop(
"successful_event_name", "litellm.succes_query"
)
for callback in litellm.success_callback:
try:
if callback == "posthog":
ph_obj = {}
for detail in additional_details:
ph_obj[detail] = additional_details[detail]
event_name = additional_details["Event_Name"]
if "user_id" in additional_details:
posthog.capture(
additional_details["user_id"], event_name, ph_obj
)
else: # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
unique_id = str(uuid.uuid4())
posthog.capture(unique_id, event_name, ph_obj)
pass
elif callback == "slack":
slack_msg = ""
for detail in additional_details:
slack_msg += f"{detail}: {additional_details[detail]}\n"
slack_app.client.chat_postMessage(
channel=alerts_channel, text=slack_msg
)
elif callback == "aispend":
print_verbose("reaches aispend for logging!")
model = args[0] if len(args) > 0 else kwargs["model"]
aispendLogger.log_event(
model=model,
response_obj=result,
start_time=start_time,
end_time=end_time,
print_verbose=print_verbose,
)
except Exception as e:
# LOGGING
exception_logging(logger_fn=user_logger_fn, exception=e)
print_verbose(
f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
)
pass
if success_handler and callable(success_handler):
success_handler(args, kwargs)
pass
except Exception as e:
# LOGGING
exception_logging(logger_fn=user_logger_fn, exception=e)
print_verbose(
f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
)
pass
def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call def acreate(*args, **kwargs): ## Thin client to handle the acreate langchain call
return litellm.acompletion(*args, **kwargs) return litellm.acompletion(*args, **kwargs)