Update support for langfuse metadata

- Added ability to set trace release, version, metadata
- Added ability to update fields during a trace continuation
- Added ability to update input and output during a trace continuation
- Wrote new test for verifying metadata is set correctly
- Small improvement to setting secret boolean, prevent unnecessary literal_eval
- Small improvements to langfuse tests
This commit is contained in:
Alex Epstein 2024-05-04 23:10:04 -04:00
parent d45328dda6
commit b82162832a
4 changed files with 224 additions and 67 deletions

View file

@ -262,6 +262,7 @@ class LangFuseLogger:
try:
tags = []
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
supports_tags = Version(langfuse.version.__version__) >= Version("2.6.3")
supports_prompt = Version(langfuse.version.__version__) >= Version("2.7.3")
supports_costs = Version(langfuse.version.__version__) >= Version("2.7.3")
@ -272,35 +273,9 @@ class LangFuseLogger:
print_verbose(f"Langfuse Layer Logging - logging to langfuse v2 ")
if supports_tags:
metadata_tags = metadata.get("tags", [])
metadata_tags = metadata.pop("tags", [])
tags = metadata_tags
trace_name = metadata.get("trace_name", None)
trace_id = metadata.get("trace_id", None)
existing_trace_id = metadata.get("existing_trace_id", None)
if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
if existing_trace_id is not None:
trace_params = {"id": existing_trace_id}
else: # don't overwrite an existing trace
trace_params = {
"name": trace_name,
"input": input,
"user_id": metadata.get("trace_user_id", user_id),
"id": trace_id,
"session_id": metadata.get("session_id", None),
}
if level == "ERROR":
trace_params["status_message"] = output
else:
trace_params["output"] = output
cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}")
# Clean Metadata before logging - never log raw metadata
# the raw metadata can contain circular references which leads to infinite recursion
@ -328,6 +303,58 @@ class LangFuseLogger:
else:
clean_metadata[key] = value
session_id = clean_metadata.pop("session_id", None)
trace_name = clean_metadata.pop("trace_name", None)
trace_id = clean_metadata.pop("trace_id", None)
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name
## DO NOT SET TRACE_NAME if trace-id set. this can lead to overwriting of past traces.
trace_name = f"litellm-{kwargs.get('call_type', 'completion')}"
if existing_trace_id is not None:
trace_params = {"id": existing_trace_id}
# Update the following keys for this trace
for metadata_param_key in update_trace_keys:
trace_param_key = metadata_param_key.replace("trace_", "")
if trace_param_key not in trace_params:
updated_trace_value = clean_metadata.pop(metadata_param_key, None)
if updated_trace_value is not None:
trace_params[trace_param_key] = updated_trace_value
# Pop the trace specific keys that would have been popped if there were a new trace
for key in list(filter(lambda key: key.startswith("trace_"), clean_metadata.keys())):
clean_metadata.pop(key, None)
# Special keys that are found in the function arguments and not the metadata
if "input" in update_trace_keys:
trace_params["input"] = input
if "output" in update_trace_keys:
trace_params["output"] = output
else: # don't overwrite an existing trace
trace_params = {
"id": trace_id,
"name": trace_name,
"session_id": session_id,
"input": input,
"version": clean_metadata.pop("trace_version", clean_metadata.get("version", None)), # If provided just version, it will applied to the trace as well, if applied a trace version it will take precedence
}
for key in list(filter(lambda key: key.startswith("trace_"), clean_metadata.keys())):
trace_params[key.replace("trace_", "")] = clean_metadata.pop(key, None)
if level == "ERROR":
trace_params["status_message"] = output
else:
trace_params["output"] = output
cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}")
if (
litellm._langfuse_default_tags is not None
and isinstance(litellm._langfuse_default_tags, list)
@ -387,7 +414,7 @@ class LangFuseLogger:
"completion_tokens": response_obj["usage"]["completion_tokens"],
"total_cost": cost if supports_costs else None,
}
generation_name = metadata.get("generation_name", None)
generation_name = clean_metadata.pop("generation_name", None)
if generation_name is None:
# just log `litellm-{call_type}` as the generation name
generation_name = f"litellm-{kwargs.get('call_type', 'completion')}"
@ -402,7 +429,7 @@ class LangFuseLogger:
generation_params = {
"name": generation_name,
"id": metadata.get("generation_id", generation_id),
"id": clean_metadata.pop("generation_id", generation_id),
"start_time": start_time,
"end_time": end_time,
"model": kwargs["model"],
@ -412,10 +439,11 @@ class LangFuseLogger:
"usage": usage,
"metadata": clean_metadata,
"level": level,
"version": clean_metadata.pop("version", None),
}
if supports_prompt:
generation_params["prompt"] = metadata.get("prompt", None)
generation_params["prompt"] = clean_metadata.pop("prompt", None)
if output is not None and isinstance(output, str) and level == "ERROR":
generation_params["status_message"] = output