Merge pull request #2025 from BerriAI/litellm_langfuse_image_gen_logging

fix(utils.py): support image gen logging to langfuse
Krish Dholakia, 2024-02-16 17:08:07 -08:00 (committed by GitHub)
Commit: 5e7dda4f88
5 changed files with 60 additions and 92 deletions
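
In practice, this change means that once Langfuse is registered as a success callback, image generation calls are logged too: the prompt as input, the ImageResponse "data" field as output. A minimal usage sketch (assumes LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY and OpenAI credentials are set in the environment; the model choice is illustrative):

    import litellm

    litellm.success_callback = ["langfuse"]  # route success logs to Langfuse

    response = litellm.image_generation(
        prompt="A cute baby sea otter",
        model="dall-e-2",
    )
    print(response.data)  # generated images, now also logged to Langfuse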


@@ -110,9 +110,16 @@ class LangFuseLogger:
         ):
             input = prompt
             output = response_obj["data"]
-        elif response_obj is not None:
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ModelResponse
+        ):
             input = prompt
             output = response_obj["choices"][0]["message"].json()
+        elif response_obj is not None and isinstance(
+            response_obj, litellm.ImageResponse
+        ):
+            input = prompt
+            output = response_obj["data"]
         print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
         if self._is_langfuse_v2():
             self._log_langfuse_v2(
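
The logger now dispatches on the response type instead of assuming a chat completion: a litellm.ModelResponse logs choices[0]["message"], while a litellm.ImageResponse logs its "data" field. A hypothetical standalone version of that dispatch, for illustration only:

    import litellm

    def extract_output(response_obj):
        # Mirrors the branching above; dict-style access works because litellm
        # response objects support item access.
        if isinstance(response_obj, litellm.ImageResponse):
            return response_obj["data"]  # list of generated images
        if isinstance(response_obj, litellm.ModelResponse):
            return response_obj["choices"][0]["message"].json()  # chat output
        return None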


@@ -741,7 +741,7 @@ class AzureChatCompletion(BaseLLM):
             response = azure_client.images.generate(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
-                input=input,
+                input=prompt,
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
                 original_response=response,


@@ -3197,6 +3197,7 @@ def image_generation(
                 "preset_cache_key": None,
                 "stream_response": {},
             },
+            custom_llm_provider=custom_llm_provider,
         )
         if custom_llm_provider == "azure":


@@ -3,6 +3,7 @@
 import sys, os, time, inspect, asyncio, traceback
 from datetime import datetime
 import pytest
+from pydantic import BaseModel

 sys.path.insert(0, os.path.abspath("../.."))
 from typing import Optional, Literal, List, Union
@@ -94,7 +95,8 @@ class CompletionCustomHandler(
         assert isinstance(kwargs["api_key"], (str, type(None)))
         assert (
             isinstance(
-                kwargs["original_response"], (str, litellm.CustomStreamWrapper)
+                kwargs["original_response"],
+                (str, litellm.CustomStreamWrapper, BaseModel),
             )
             or inspect.iscoroutine(kwargs["original_response"])
             or inspect.isasyncgen(kwargs["original_response"])
@@ -174,7 +176,8 @@ class CompletionCustomHandler(
         ) or isinstance(kwargs["input"], (dict, str))
         assert isinstance(kwargs["api_key"], (str, type(None)))
         assert isinstance(
-            kwargs["original_response"], (str, litellm.CustomStreamWrapper)
+            kwargs["original_response"],
+            (str, litellm.CustomStreamWrapper, BaseModel),
         )
         assert isinstance(kwargs["additional_args"], (dict, type(None)))
         assert isinstance(kwargs["log_event_type"], str)
@@ -471,7 +474,7 @@ async def test_async_chat_azure_stream():
         pytest.fail(f"An exception occurred: {str(e)}")


-asyncio.run(test_async_chat_azure_stream())
+# asyncio.run(test_async_chat_azure_stream())


 ## Test Bedrock + sync
@@ -556,6 +559,7 @@ async def test_async_chat_bedrock_stream():
 # asyncio.run(test_async_chat_bedrock_stream())


 ## Test Sagemaker + Async
+@pytest.mark.asyncio
 async def test_async_chat_sagemaker_stream():
@@ -769,14 +773,18 @@ async def test_async_completion_azure_caching():
     unique_time = time.time()
     response1 = await litellm.acompletion(
         model="azure/chatgpt-v-2",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)
     print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
     response2 = await litellm.acompletion(
         model="azure/chatgpt-v-2",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
@@ -825,21 +833,25 @@ def test_image_generation_openai():
     try:
         customHandler_success = CompletionCustomHandler()
         customHandler_failure = CompletionCustomHandler()
-        # litellm.callbacks = [customHandler_success]
-        # litellm.set_verbose = True
-        # response = litellm.image_generation(
-        #     prompt="A cute baby sea otter", model="dall-e-3"
-        # )
-        # print(f"response: {response}")
-        # assert len(response.data) > 0
-        # print(f"customHandler_success.errors: {customHandler_success.errors}")
-        # print(f"customHandler_success.states: {customHandler_success.states}")
-        # assert len(customHandler_success.errors) == 0
-        # assert len(customHandler_success.states) == 3  # pre, post, success
+        litellm.callbacks = [customHandler_success]
+        litellm.set_verbose = True
+        response = litellm.image_generation(
+            prompt="A cute baby sea otter",
+            model="azure/",
+            api_base=os.getenv("AZURE_API_BASE"),
+            api_key=os.getenv("AZURE_API_KEY"),
+            api_version="2023-06-01-preview",
+        )
+        print(f"response: {response}")
+        assert len(response.data) > 0
+        print(f"customHandler_success.errors: {customHandler_success.errors}")
+        print(f"customHandler_success.states: {customHandler_success.states}")
+        assert len(customHandler_success.errors) == 0
+        assert len(customHandler_success.states) == 3  # pre, post, success
         # test failure callback
         litellm.callbacks = [customHandler_failure]
         try:
@@ -862,7 +874,7 @@ def test_image_generation_openai():
         pytest.fail(f"An exception occurred - {str(e)}")


-test_image_generation_openai()
+# test_image_generation_openai()

 ## Test OpenAI + Async

 ## Test Azure + Sync


@@ -44,9 +44,9 @@ except:
     filename = str(
         resources.files(litellm).joinpath("llms/tokenizers")  # for python 3.10
     )  # for python 3.10+
-os.environ[
-    "TIKTOKEN_CACHE_DIR"
-] = filename  # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
+os.environ["TIKTOKEN_CACHE_DIR"] = (
+    filename  # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
+)
 encoding = tiktoken.get_encoding("cl100k_base")

 import importlib.metadata
@@ -1110,6 +1110,9 @@ class Logging:
                             completion_response=result,
                             model=self.model,
                             call_type=self.call_type,
+                            custom_llm_provider=self.model_call_details.get(
+                                "custom_llm_provider", None
+                            ),  # set for img gen models
                         )
                     )
                 else:
@@ -1789,14 +1792,14 @@ class Logging:
                     input = self.model_call_details["input"]

-                    type = (
+                    _type = (
                         "embed"
                         if self.call_type == CallTypes.embedding.value
                         else "llm"
                     )
                     llmonitorLogger.log_event(
-                        type=type,
+                        type=_type,
                         event="error",
                         user_id=self.model_call_details.get("user", "default"),
                         model=model,
@@ -3512,6 +3515,15 @@ def completion_cost(
     - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
     """
     try:
+        if (
+            (call_type == "aimage_generation" or call_type == "image_generation")
+            and model is not None
+            and isinstance(model, str)
+            and len(model) == 0
+            and custom_llm_provider == "azure"
+        ):
+            model = "dall-e-2"  # for dall-e-2, azure expects an empty model name
         # Handle Inputs to completion_cost
         prompt_tokens = 0
         completion_tokens = 0
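
The guard above matters because Azure image generation can be invoked with model="azure/" (deployment name elided), so completion_cost receives an empty model string; the guard substitutes "dall-e-2" before pricing is resolved. A hedged example of such a cost lookup (parameter names as used by the completion_cost code in this diff):

    import litellm

    # Empty model string + custom_llm_provider="azure" triggers the
    # dall-e-2 substitution added above.
    cost = litellm.completion_cost(
        model="",
        custom_llm_provider="azure",
        call_type="image_generation",
        size="1024-x-1024",
    )
    print(cost)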
@@ -3565,12 +3577,15 @@
             or call_type == CallTypes.aimage_generation.value
         ):
             ### IMAGE GENERATION COST CALCULATION ###
+            # fix size to match naming convention
+            if "x" in size and "-x-" not in size:
+                size = size.replace("x", "-x-")
             image_gen_model_name = f"{size}/{model}"
             image_gen_model_name_with_quality = image_gen_model_name
             if quality is not None:
                 image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"
             size = size.split("-x-")
-            height = int(size[0])
+            height = int(size[0])  # if it's 1024-x-1024 vs. 1024x1024
             width = int(size[1])
             verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}")
             verbose_logger.debug(
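
A worked sketch of the size normalization above (variable names mirror the diff; values are illustrative):

    size = "1024x1024"  # OpenAI-style size string
    if "x" in size and "-x-" not in size:
        size = size.replace("x", "-x-")  # -> "1024-x-1024", the cost-table convention

    model = "dall-e-2"
    quality = "standard"
    image_gen_model_name = f"{size}/{model}"  # "1024-x-1024/dall-e-2"
    image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"

    height, width = (int(d) for d in size.split("-x-"))  # 1024, 1024
    print(image_gen_model_name_with_quality, height, width)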
@@ -5968,73 +5983,6 @@ def convert_to_model_response_object(
         raise Exception(f"Invalid response object {e}")


-# NOTE: DEPRECATING this in favor of using success_handler() in Logging:
-def handle_success(args, kwargs, result, start_time, end_time):
-    global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
-    try:
-        model = args[0] if len(args) > 0 else kwargs["model"]
-        input = (
-            args[1]
-            if len(args) > 1
-            else kwargs.get("messages", kwargs.get("input", None))
-        )
-        success_handler = additional_details.pop("success_handler", None)
-        failure_handler = additional_details.pop("failure_handler", None)
-        additional_details["Event_Name"] = additional_details.pop(
-            "successful_event_name", "litellm.succes_query"
-        )
-        for callback in litellm.success_callback:
-            try:
-                if callback == "posthog":
-                    ph_obj = {}
-                    for detail in additional_details:
-                        ph_obj[detail] = additional_details[detail]
-                    event_name = additional_details["Event_Name"]
-                    if "user_id" in additional_details:
-                        posthog.capture(
-                            additional_details["user_id"], event_name, ph_obj
-                        )
-                    else:  # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
-                        unique_id = str(uuid.uuid4())
-                        posthog.capture(unique_id, event_name, ph_obj)
-                    pass
-                elif callback == "slack":
-                    slack_msg = ""
-                    for detail in additional_details:
-                        slack_msg += f"{detail}: {additional_details[detail]}\n"
-                    slack_app.client.chat_postMessage(
-                        channel=alerts_channel, text=slack_msg
-                    )
-                elif callback == "aispend":
-                    print_verbose("reaches aispend for logging!")
-                    model = args[0] if len(args) > 0 else kwargs["model"]
-                    aispendLogger.log_event(
-                        model=model,
-                        response_obj=result,
-                        start_time=start_time,
-                        end_time=end_time,
-                        print_verbose=print_verbose,
-                    )
-            except Exception as e:
-                # LOGGING
-                exception_logging(logger_fn=user_logger_fn, exception=e)
-                print_verbose(
-                    f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
-                )
-                pass
-        if success_handler and callable(success_handler):
-            success_handler(args, kwargs)
-        pass
-    except Exception as e:
-        # LOGGING
-        exception_logging(logger_fn=user_logger_fn, exception=e)
-        print_verbose(
-            f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
-        )
-        pass


 def acreate(*args, **kwargs):  ## Thin client to handle the acreate langchain call
     return litellm.acompletion(*args, **kwargs)
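
handle_success is deleted in favor of the Logging class's success_handler; the supported way to hook success events is a callback object. A brief sketch using litellm's CustomLogger base class (method name per litellm's custom-callback interface):

    import litellm
    from litellm.integrations.custom_logger import CustomLogger

    class MyLogger(CustomLogger):
        def log_success_event(self, kwargs, response_obj, start_time, end_time):
            # fires once per successful completion / embedding / image generation
            print(f"success: model={kwargs.get('model')}")

    litellm.callbacks = [MyLogger()]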