Merge pull request #3581 from BerriAI/litellm_log_metadata_langfuse_traces

[Feat] - log metadata on traces + allow users to log metadata when `existing_trace_id` exists
This commit is contained in:
Ishaan Jaff 2024-05-11 14:19:48 -07:00 committed by GitHub
commit b9b8bf52f3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 16 additions and 1 deletions

View file

@ -136,6 +136,7 @@ response = completion(
"existing_trace_id": "trace-id22",
"trace_metadata": {"key": "updated_trace_value"}, # The new value to use for the langfuse Trace Metadata
"update_trace_keys": ["input", "output", "trace_metadata"], # Updates the trace input & output to be this generations input & output also updates the Trace Metadata to match the passed in value
"debug_langfuse": True, # Will log the exact metadata sent to litellm for the trace/generation as `metadata_passed_to_litellm`
},
)

View file

@ -323,6 +323,7 @@ class LangFuseLogger:
trace_id = clean_metadata.pop("trace_id", None)
existing_trace_id = clean_metadata.pop("existing_trace_id", None)
update_trace_keys = clean_metadata.pop("update_trace_keys", [])
debug = clean_metadata.pop("debug_langfuse", None)
if trace_name is None and existing_trace_id is None:
# just log `litellm-{call_type}` as the trace name
@ -376,6 +377,13 @@ class LangFuseLogger:
else:
trace_params["output"] = output
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
if "metadata" in trace_params:
# log the raw_metadata in the trace
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
else:
trace_params["metadata"] = {"metadata_passed_to_litellm": metadata}
cost = kwargs.get("response_cost", None)
print_verbose(f"trace: {cost}")
@ -426,7 +434,6 @@ class LangFuseLogger:
"url": url,
"headers": clean_headers,
}
trace = self.Langfuse.trace(**trace_params)
generation_id = None

View file

@ -339,6 +339,13 @@ async def test_langfuse_logging_metadata(langfuse_client):
for generation_id, generation in zip(generation_ids, generations):
assert generation.id == generation_id
assert generation.trace_id == trace_id
print(
"common keys in trace",
set(generation.metadata.keys()).intersection(
expected_filtered_metadata_keys
),
)
assert set(generation.metadata.keys()).isdisjoint(
expected_filtered_metadata_keys
)