diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py
index 8bd813b69..0f27bb102 100644
--- a/litellm/integrations/braintrust_logging.py
+++ b/litellm/integrations/braintrust_logging.py
@@ -11,7 +11,6 @@ from typing import Literal, Optional
 
 import dotenv
 import httpx
-from braintrust import Span, SpanTypeAttribute, init, start_span
 
 import litellm
 from litellm import verbose_logger
@@ -20,6 +19,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import get_formatted_prompt
 
 global_braintrust_http_handler = AsyncHTTPHandler()
+global_braintrust_sync_http_handler = HTTPHandler()
 
 API_BASE = "https://api.braintrustdata.com/v1"
 
@@ -107,11 +107,143 @@ class BraintrustLogger(CustomLogger):
 
         self.default_project_id = project_dict["id"]
 
+    def create_sync_default_project_and_experiment(self):
+        project = global_braintrust_sync_http_handler.post(
+            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
+        )
+
+        project_dict = project.json()
+
+        self.default_project_id = project_dict["id"]
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
+        try:
+            litellm_call_id = kwargs.get("litellm_call_id")
+            project_id = kwargs.get("project_id", None)
+            if project_id is None:
+                if self.default_project_id is None:
+                    self.create_sync_default_project_and_experiment()
+                project_id = self.default_project_id
+
+            prompt = {"messages": kwargs.get("messages")}
+
+            if response_obj is not None and (
+                kwargs.get("call_type", None) == "embedding"
+                or isinstance(response_obj, litellm.EmbeddingResponse)
+            ):
+                input = prompt
+                output = None
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ModelResponse
+            ):
+                input = prompt
+                output = response_obj["choices"][0]["message"].json()
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.TextCompletionResponse
+            ):
+                input = prompt
+                output = response_obj.choices[0].text
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ImageResponse
+            ):
+                input = prompt
+                output = response_obj["data"]
+
+            litellm_params = kwargs.get("litellm_params", {})
+            metadata = (
+                litellm_params.get("metadata", {}) or {}
+            )  # if litellm_params['metadata'] == None
+            metadata = self.add_metadata_from_header(litellm_params, metadata)
+            clean_metadata = {}
+            try:
+                metadata = copy.deepcopy(
+                    metadata
+                )  # Avoid modifying the original metadata
+            except:
+                new_metadata = {}
+                for key, value in metadata.items():
+                    if (
+                        isinstance(value, list)
+                        or isinstance(value, dict)
+                        or isinstance(value, str)
+                        or isinstance(value, int)
+                        or isinstance(value, float)
+                    ):
+                        new_metadata[key] = copy.deepcopy(value)
+                metadata = new_metadata
+
+            tags = []
+            if isinstance(metadata, dict):
+                for key, value in metadata.items():
+
+                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
+                    if (
+                        litellm._langfuse_default_tags is not None
+                        and isinstance(litellm._langfuse_default_tags, list)
+                        and key in litellm._langfuse_default_tags
+                    ):
+                        tags.append(f"{key}:{value}")
+
+                    # clean litellm metadata before logging
+                    if key in [
+                        "headers",
+                        "endpoint",
+                        "caching_groups",
+                        "previous_models",
+                    ]:
+                        continue
+                    else:
+                        clean_metadata[key] = value
+
+            cost = kwargs.get("response_cost", None)
+            if cost is not None:
+                clean_metadata["litellm_response_cost"] = cost
+
+            metrics: Optional[dict] = None
+            if (
+                response_obj is not None
+                and hasattr(response_obj, "usage")
+                and isinstance(response_obj.usage, litellm.Usage)
+            ):
+                generation_id = litellm.utils.get_logging_id(start_time, response_obj)
+                metrics = {
+                    "prompt_tokens": response_obj.usage.prompt_tokens,
+                    "completion_tokens": response_obj.usage.completion_tokens,
+                    "total_tokens": response_obj.usage.total_tokens,
+                    "total_cost": cost,
+                }
+
+            request_data = {
+                "id": litellm_call_id,
+                "input": prompt,
+                "output": output,
+                "metadata": clean_metadata,
+                "tags": tags,
+            }
+            if metrics is not None:
+                request_data["metrics"] = metrics
+
+            try:
+                global_braintrust_sync_http_handler.post(
+                    url=f"{self.api_base}/project_logs/{project_id}/insert",
+                    json={"events": [request_data]},
+                    headers=self.headers,
+                )
+            except httpx.HTTPStatusError as e:
+                raise Exception(e.response.text)
+        except Exception as e:
+            verbose_logger.error(
+                "Error logging to braintrust - Exception received - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            raise e
+
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
         try:
             litellm_call_id = kwargs.get("litellm_call_id")
-            trace_id = kwargs.get("trace_id", litellm_call_id)
             project_id = kwargs.get("project_id", None)
             if project_id is None:
                 if self.default_project_id is None:
@@ -188,14 +320,6 @@ class BraintrustLogger(CustomLogger):
                     else:
                         clean_metadata[key] = value
 
-            session_id = clean_metadata.pop("session_id", None)
-            trace_name = clean_metadata.pop("trace_name", None)
-            trace_id = clean_metadata.pop("trace_id", litellm_call_id)
-            existing_trace_id = clean_metadata.pop("existing_trace_id", None)
-            update_trace_keys = clean_metadata.pop("update_trace_keys", [])
-            debug = clean_metadata.pop("debug_langfuse", None)
-            mask_input = clean_metadata.pop("mask_input", False)
-            mask_output = clean_metadata.pop("mask_output", False)
             cost = kwargs.get("response_cost", None)
             if cost is not None:
                 clean_metadata["litellm_response_cost"] = cost
diff --git a/litellm/tests/test_braintrust.py b/litellm/tests/test_braintrust.py
new file mode 100644
index 000000000..7792a0841
--- /dev/null
+++ b/litellm/tests/test_braintrust.py
@@ -0,0 +1,53 @@
+# What is this?
+## This tests the braintrust integration
+
+import asyncio
+import os
+import random
+import sys
+import time
+import traceback
+from datetime import datetime
+
+from dotenv import load_dotenv
+from fastapi import Request
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+import logging
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+def test_braintrust_logging():
+    import litellm
+
+    http_client = HTTPHandler()
+
+    setattr(
+        litellm.integrations.braintrust_logging,
+        "global_braintrust_sync_http_handler",
+        http_client,
+    )
+
+    with patch.object(http_client, "post", new=MagicMock()) as mock_client:
+
+        # set braintrust as a callback, litellm will send the data to braintrust
+        litellm.callbacks = ["braintrust"]
+
+        # openai call
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
+        )
+
+        mock_client.assert_called()
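
For context, a minimal sketch (not part of the diff) of how the new synchronous logging path might be exercised end to end, assuming BRAINTRUST_API_KEY and OPENAI_API_KEY are set in the environment; the model name and message content are illustrative only.

# Hedged sketch: exercising BraintrustLogger.log_success_event without mocks.
# Assumes BRAINTRUST_API_KEY and OPENAI_API_KEY are set in the environment.
import litellm

# Registering the callback routes synchronous completion logs through
# global_braintrust_sync_http_handler to the Braintrust project_logs endpoint.
litellm.callbacks = ["braintrust"]

response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model choice
    messages=[{"role": "user", "content": "Hello from a synchronous call"}],
)
print(response.choices[0].message.content)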