fix(langfuse.py): don't overwrite trace details if existing trace id passed in

This commit is contained in:
Krrish Dholakia 2024-05-01 08:15:03 -07:00
parent fc5a845838
commit abdae87ba2
2 changed files with 217 additions and 12 deletions

View file

@ -79,7 +79,7 @@ class LangFuseLogger:
print_verbose,
level="DEFAULT",
status_message=None,
):
) -> dict:
# Method definition
try:
@ -111,6 +111,7 @@ class LangFuseLogger:
pass
# end of processing langfuse ########################
print(f"response obj type: {type(response_obj)}")
if (
level == "ERROR"
and status_message is not None
@ -140,8 +141,11 @@ class LangFuseLogger:
input = prompt
output = response_obj["data"]
print_verbose(f"OUTPUT IN LANGFUSE: {output}; original: {response_obj}")
trace_id = None
generation_id = None
if self._is_langfuse_v2():
self._log_langfuse_v2(
print("INSIDE V2 LANGFUSE")
trace_id, generation_id = self._log_langfuse_v2(
user_id,
metadata,
litellm_params,
@ -171,10 +175,12 @@ class LangFuseLogger:
f"Langfuse Layer Logging - final response object: {response_obj}"
)
verbose_logger.info(f"Langfuse Layer Logging - logging success")
return {"trace_id": trace_id, "generation_id": generation_id}
except:
traceback.print_exc()
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
pass
return {"trace_id": None, "generation_id": None}
async def _async_log_event(
self, kwargs, response_obj, start_time, end_time, user_id, print_verbose
@ -246,7 +252,7 @@ class LangFuseLogger:
response_obj,
level,
print_verbose,
):
) -> tuple:
import langfuse
try:
@ -272,13 +278,16 @@ class LangFuseLogger:
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
trace_params = {
"name": trace_name,
"input": input,
"user_id": metadata.get("trace_user_id", user_id),
"id": trace_id or existing_trace_id,
"session_id": metadata.get("session_id", None),
}
if existing_trace_id is not None:
trace_params = {"trace_id": existing_trace_id}
else: # don't overwrite an existing trace
trace_params = {
"name": trace_name,
"input": input,
"user_id": metadata.get("trace_user_id", user_id),
"id": trace_id,
"session_id": metadata.get("session_id", None),
}
if level == "ERROR":
trace_params["status_message"] = output
@ -414,6 +423,10 @@ class LangFuseLogger:
print_verbose(f"generation_params: {generation_params}")
trace.generation(**generation_params)
generation_client = trace.generation(**generation_params)
print(f"LANGFUSE TRACE ID - {generation_client.trace_id}")
return generation_client.trace_id, generation_id
except Exception as e:
verbose_logger.debug(f"Langfuse Layer Error - {traceback.format_exc()}")
return None, None