forked from phoenix/litellm-mirror

fix(utils.py): support image gen logging to langfuse

parent: 5f9e141d1e
commit: f57483ea70

5 changed files with 58 additions and 91 deletions
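What this enables, as a minimal usage sketch (illustrative, not part of the diff; assumes Langfuse credentials are exported as LANGFUSE_PUBLIC_KEY / LANGFUSE_SECRET_KEY and the model name is an example):

import os
import litellm

# assumes Langfuse credentials are already set in the environment
litellm.success_callback = ["langfuse"]

# with this commit, image generation calls are logged to Langfuse:
# the prompt becomes the trace input, response_obj["data"] the output
response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="dall-e-3",
)
print(response.data)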
@@ -110,9 +110,16 @@ class LangFuseLogger:
             ):
                 input = prompt
                 output = response_obj["data"]
-            elif response_obj is not None:
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ModelResponse
+            ):
                 input = prompt
                 output = response_obj["choices"][0]["message"].json()
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ImageResponse
+            ):
+                input = prompt
+                output = response_obj["data"]
             print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
             if self._is_langfuse_v2():
                 self._log_langfuse_v2(
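For the new ImageResponse branch above, the trace input is the prompt string and the output is the response's data list, roughly (illustrative values, assuming an OpenAI-style image response):

# rough shape of what gets logged for image generation
input = "A cute baby sea otter"
output = [{"url": "https://example.com/otter.png"}]  # response_obj["data"]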
@@ -741,7 +741,7 @@ class AzureChatCompletion(BaseLLM):
             response = azure_client.images.generate(**data, timeout=timeout)  # type: ignore
             ## LOGGING
             logging_obj.post_call(
-                input=input,
+                input=prompt,
                 api_key=api_key,
                 additional_args={"complete_input_dict": data},
                 original_response=response,
@@ -3197,6 +3197,7 @@ def image_generation(
                 "preset_cache_key": None,
                 "stream_response": {},
             },
+            custom_llm_provider=custom_llm_provider,
         )

     if custom_llm_provider == "azure":
@@ -3,6 +3,7 @@
 import sys, os, time, inspect, asyncio, traceback
 from datetime import datetime
 import pytest
+from pydantic import BaseModel

 sys.path.insert(0, os.path.abspath("../.."))
 from typing import Optional, Literal, List, Union
@@ -94,7 +95,8 @@ class CompletionCustomHandler(
         assert isinstance(kwargs["api_key"], (str, type(None)))
         assert (
             isinstance(
-                kwargs["original_response"], (str, litellm.CustomStreamWrapper)
+                kwargs["original_response"],
+                (str, litellm.CustomStreamWrapper, BaseModel),
             )
             or inspect.iscoroutine(kwargs["original_response"])
             or inspect.isasyncgen(kwargs["original_response"])
@@ -471,7 +473,7 @@ async def test_async_chat_azure_stream():
         pytest.fail(f"An exception occurred: {str(e)}")


-asyncio.run(test_async_chat_azure_stream())
+# asyncio.run(test_async_chat_azure_stream())


 ## Test Bedrock + sync
@@ -556,6 +558,7 @@ async def test_async_chat_bedrock_stream():

 # asyncio.run(test_async_chat_bedrock_stream())

+
 ## Test Sagemaker + Async
 @pytest.mark.asyncio
 async def test_async_chat_sagemaker_stream():
@@ -769,14 +772,18 @@ async def test_async_completion_azure_caching():
     unique_time = time.time()
     response1 = await litellm.acompletion(
         model="azure/chatgpt-v-2",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)
     print(f"customHandler_caching.states pre-cache hit: {customHandler_caching.states}")
     response2 = await litellm.acompletion(
         model="azure/chatgpt-v-2",
-        messages=[{"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}],
+        messages=[
+            {"role": "user", "content": f"Hi 👋 - i'm async azure {unique_time}"}
+        ],
         caching=True,
     )
     await asyncio.sleep(1)  # success callbacks are done in parallel
@@ -825,21 +832,25 @@ def test_image_generation_openai():
     try:
         customHandler_success = CompletionCustomHandler()
         customHandler_failure = CompletionCustomHandler()
-        # litellm.callbacks = [customHandler_success]
+        litellm.callbacks = [customHandler_success]

-        # litellm.set_verbose = True
+        litellm.set_verbose = True

-        # response = litellm.image_generation(
-        # prompt="A cute baby sea otter", model="dall-e-3"
-        # )
+        response = litellm.image_generation(
+            prompt="A cute baby sea otter",
+            model="azure/",
+            api_base=os.getenv("AZURE_API_BASE"),
+            api_key=os.getenv("AZURE_API_KEY"),
+            api_version="2023-06-01-preview",
+        )

-        # print(f"response: {response}")
-        # assert len(response.data) > 0
+        print(f"response: {response}")
+        assert len(response.data) > 0

-        # print(f"customHandler_success.errors: {customHandler_success.errors}")
-        # print(f"customHandler_success.states: {customHandler_success.states}")
-        # assert len(customHandler_success.errors) == 0
-        # assert len(customHandler_success.states) == 3  # pre, post, success
+        print(f"customHandler_success.errors: {customHandler_success.errors}")
+        print(f"customHandler_success.states: {customHandler_success.states}")
+        assert len(customHandler_success.errors) == 0
+        assert len(customHandler_success.states) == 3  # pre, post, success
         # test failure callback
         litellm.callbacks = [customHandler_failure]
         try:
@@ -862,7 +873,7 @@ def test_image_generation_openai():
         pytest.fail(f"An exception occurred - {str(e)}")


-test_image_generation_openai()
+# test_image_generation_openai()
 ## Test OpenAI + Async

 ## Test Azure + Sync
@@ -44,9 +44,9 @@ except:
     filename = str(
         resources.files(litellm).joinpath("llms/tokenizers")  # for python 3.10
     )  # for python 3.10+
-    os.environ[
-        "TIKTOKEN_CACHE_DIR"
-    ] = filename  # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
+    os.environ["TIKTOKEN_CACHE_DIR"] = (
+        filename  # use local copy of tiktoken b/c of - https://github.com/BerriAI/litellm/issues/1071
+    )

 encoding = tiktoken.get_encoding("cl100k_base")
 import importlib.metadata
@@ -1110,6 +1110,9 @@ class Logging:
                             completion_response=result,
                             model=self.model,
                             call_type=self.call_type,
+                            custom_llm_provider=self.model_call_details.get(
+                                "custom_llm_provider", None
+                            ),  # set for img gen models
                         )
                     )
                 else:
@@ -1789,14 +1792,14 @@ class Logging:

                     input = self.model_call_details["input"]

-                    type = (
+                    _type = (
                         "embed"
                         if self.call_type == CallTypes.embedding.value
                         else "llm"
                     )

                     llmonitorLogger.log_event(
-                        type=type,
+                        type=_type,
                         event="error",
                         user_id=self.model_call_details.get("user", "default"),
                         model=model,
@@ -3512,6 +3515,15 @@ def completion_cost(
     - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
     """
     try:
+
+        if (
+            (call_type == "aimage_generation" or call_type == "image_generation")
+            and model is not None
+            and isinstance(model, str)
+            and len(model) == 0
+            and custom_llm_provider == "azure"
+        ):
+            model = "dall-e-2"  # for dall-e-2, azure expects an empty model name
         # Handle Inputs to completion_cost
         prompt_tokens = 0
         completion_tokens = 0
@@ -3565,12 +3577,15 @@ def completion_cost(
             or call_type == CallTypes.aimage_generation.value
         ):
             ### IMAGE GENERATION COST CALCULATION ###
+            # fix size to match naming convention
+            if "x" in size and "-x-" not in size:
+                size = size.replace("x", "-x-")
             image_gen_model_name = f"{size}/{model}"
             image_gen_model_name_with_quality = image_gen_model_name
             if quality is not None:
                 image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"
             size = size.split("-x-")
-            height = int(size[0])
+            height = int(size[0])  # if it's 1024-x-1024 vs. 1024x1024
             width = int(size[1])
             verbose_logger.debug(f"image_gen_model_name: {image_gen_model_name}")
             verbose_logger.debug(
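A small worked example of the size normalization added above (illustrative values, not part of the diff): a caller-supplied size like "1024x1024" is rewritten to the "-x-" form used for the cost lookup, then combined with model and quality:

# hypothetical inputs
size, model, quality = "1024x1024", "dall-e-2", "standard"

if "x" in size and "-x-" not in size:
    size = size.replace("x", "-x-")           # "1024-x-1024"

image_gen_model_name = f"{size}/{model}"       # "1024-x-1024/dall-e-2"
image_gen_model_name_with_quality = f"{quality}/{image_gen_model_name}"
# -> "standard/1024-x-1024/dall-e-2"; height/width then come from size.split("-x-")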
@@ -5968,73 +5983,6 @@ def convert_to_model_response_object(
         raise Exception(f"Invalid response object {e}")


-# NOTE: DEPRECATING this in favor of using success_handler() in Logging:
-def handle_success(args, kwargs, result, start_time, end_time):
-    global heliconeLogger, aispendLogger, supabaseClient, liteDebuggerClient, llmonitorLogger
-    try:
-        model = args[0] if len(args) > 0 else kwargs["model"]
-        input = (
-            args[1]
-            if len(args) > 1
-            else kwargs.get("messages", kwargs.get("input", None))
-        )
-        success_handler = additional_details.pop("success_handler", None)
-        failure_handler = additional_details.pop("failure_handler", None)
-        additional_details["Event_Name"] = additional_details.pop(
-            "successful_event_name", "litellm.succes_query"
-        )
-        for callback in litellm.success_callback:
-            try:
-                if callback == "posthog":
-                    ph_obj = {}
-                    for detail in additional_details:
-                        ph_obj[detail] = additional_details[detail]
-                    event_name = additional_details["Event_Name"]
-                    if "user_id" in additional_details:
-                        posthog.capture(
-                            additional_details["user_id"], event_name, ph_obj
-                        )
-                    else:  # PostHog calls require a unique id to identify a user - https://posthog.com/docs/libraries/python
-                        unique_id = str(uuid.uuid4())
-                        posthog.capture(unique_id, event_name, ph_obj)
-                    pass
-                elif callback == "slack":
-                    slack_msg = ""
-                    for detail in additional_details:
-                        slack_msg += f"{detail}: {additional_details[detail]}\n"
-                    slack_app.client.chat_postMessage(
-                        channel=alerts_channel, text=slack_msg
-                    )
-                elif callback == "aispend":
-                    print_verbose("reaches aispend for logging!")
-                    model = args[0] if len(args) > 0 else kwargs["model"]
-                    aispendLogger.log_event(
-                        model=model,
-                        response_obj=result,
-                        start_time=start_time,
-                        end_time=end_time,
-                        print_verbose=print_verbose,
-                    )
-            except Exception as e:
-                # LOGGING
-                exception_logging(logger_fn=user_logger_fn, exception=e)
-                print_verbose(
-                    f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
-                )
-                pass
-
-        if success_handler and callable(success_handler):
-            success_handler(args, kwargs)
-        pass
-    except Exception as e:
-        # LOGGING
-        exception_logging(logger_fn=user_logger_fn, exception=e)
-        print_verbose(
-            f"[Non-Blocking] Success Callback Error - {traceback.format_exc()}"
-        )
-        pass
-
-
 def acreate(*args, **kwargs):  ## Thin client to handle the acreate langchain call
     return litellm.acompletion(*args, **kwargs)
