test(test_braintrust.py): add testing for braintrust integration

Krrish Dholakia 2024-07-22 18:05:11 -07:00
parent dd6d58d29b
commit 92b1262caa
2 changed files with 187 additions and 10 deletions

litellm/integrations/braintrust_logging.py

@@ -11,7 +11,6 @@ from typing import Literal, Optional
import dotenv
import httpx
from braintrust import Span, SpanTypeAttribute, init, start_span
import litellm
from litellm import verbose_logger
@@ -20,6 +19,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.utils import get_formatted_prompt
global_braintrust_http_handler = AsyncHTTPHandler()
global_braintrust_sync_http_handler = HTTPHandler()
API_BASE = "https://api.braintrustdata.com/v1"
@@ -107,11 +107,143 @@ class BraintrustLogger(CustomLogger):
        self.default_project_id = project_dict["id"]

    def create_sync_default_project_and_experiment(self):
        project = global_braintrust_sync_http_handler.post(
            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
        )
        project_dict = project.json()
        self.default_project_id = project_dict["id"]

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            project_id = kwargs.get("project_id", None)
            if project_id is None:
                if self.default_project_id is None:
                    self.create_sync_default_project_and_experiment()
                project_id = self.default_project_id

            prompt = {"messages": kwargs.get("messages")}

            if response_obj is not None and (
                kwargs.get("call_type", None) == "embedding"
                or isinstance(response_obj, litellm.EmbeddingResponse)
            ):
                input = prompt
                output = None
            elif response_obj is not None and isinstance(
                response_obj, litellm.ModelResponse
            ):
                input = prompt
                output = response_obj["choices"][0]["message"].json()
            elif response_obj is not None and isinstance(
                response_obj, litellm.TextCompletionResponse
            ):
                input = prompt
                output = response_obj.choices[0].text
            elif response_obj is not None and isinstance(
                response_obj, litellm.ImageResponse
            ):
                input = prompt
                output = response_obj["data"]

            litellm_params = kwargs.get("litellm_params", {})
            metadata = (
                litellm_params.get("metadata", {}) or {}
            )  # if litellm_params['metadata'] == None
            metadata = self.add_metadata_from_header(litellm_params, metadata)
            clean_metadata = {}
            try:
                metadata = copy.deepcopy(
                    metadata
                )  # Avoid modifying the original metadata
            except:
                new_metadata = {}
                for key, value in metadata.items():
                    if (
                        isinstance(value, list)
                        or isinstance(value, dict)
                        or isinstance(value, str)
                        or isinstance(value, int)
                        or isinstance(value, float)
                    ):
                        new_metadata[key] = copy.deepcopy(value)
                metadata = new_metadata

            tags = []
            if isinstance(metadata, dict):
                for key, value in metadata.items():
                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
                    if (
                        litellm._langfuse_default_tags is not None
                        and isinstance(litellm._langfuse_default_tags, list)
                        and key in litellm._langfuse_default_tags
                    ):
                        tags.append(f"{key}:{value}")

                    # clean litellm metadata before logging
                    if key in [
                        "headers",
                        "endpoint",
                        "caching_groups",
                        "previous_models",
                    ]:
                        continue
                    else:
                        clean_metadata[key] = value

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost

            metrics: Optional[dict] = None
            if (
                response_obj is not None
                and hasattr(response_obj, "usage")
                and isinstance(response_obj.usage, litellm.Usage)
            ):
                generation_id = litellm.utils.get_logging_id(start_time, response_obj)
                metrics = {
                    "prompt_tokens": response_obj.usage.prompt_tokens,
                    "completion_tokens": response_obj.usage.completion_tokens,
                    "total_tokens": response_obj.usage.total_tokens,
                    "total_cost": cost,
                }

            request_data = {
                "id": litellm_call_id,
                "input": prompt,
                "output": output,
                "metadata": clean_metadata,
                "tags": tags,
            }

            if metrics is not None:
                request_data["metrics"] = metrics

            try:
                global_braintrust_sync_http_handler.post(
                    url=f"{self.api_base}/project_logs/{project_id}/insert",
                    json={"events": [request_data]},
                    headers=self.headers,
                )
            except httpx.HTTPStatusError as e:
                raise Exception(e.response.text)
        except Exception as e:
            verbose_logger.error(
                "Error logging to braintrust - Exception received - {}\n{}".format(
                    str(e), traceback.format_exc()
                )
            )
            raise e

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
        try:
            litellm_call_id = kwargs.get("litellm_call_id")
            trace_id = kwargs.get("trace_id", litellm_call_id)
            project_id = kwargs.get("project_id", None)
            if project_id is None:
                if self.default_project_id is None:
@@ -188,14 +320,6 @@ class BraintrustLogger(CustomLogger):
                    else:
                        clean_metadata[key] = value

            session_id = clean_metadata.pop("session_id", None)
            trace_name = clean_metadata.pop("trace_name", None)
            trace_id = clean_metadata.pop("trace_id", litellm_call_id)
            existing_trace_id = clean_metadata.pop("existing_trace_id", None)
            update_trace_keys = clean_metadata.pop("update_trace_keys", [])
            debug = clean_metadata.pop("debug_langfuse", None)
            mask_input = clean_metadata.pop("mask_input", False)
            mask_output = clean_metadata.pop("mask_output", False)

            cost = kwargs.get("response_cost", None)
            if cost is not None:
                clean_metadata["litellm_response_cost"] = cost
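
For context, here is a minimal usage sketch (not part of this commit) of how the new sync logging path gets triggered. Only the "braintrust" callback name, the optional project_id kwarg, and the metadata handling come from the diff above; the concrete metadata key is hypothetical.

# Minimal usage sketch, not part of this commit. Registering the "braintrust"
# callback makes litellm.completion() invoke BraintrustLogger.log_success_event,
# which POSTs a single event to {api_base}/project_logs/{project_id}/insert
# through global_braintrust_sync_http_handler.
import litellm

litellm.callbacks = ["braintrust"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hello"}],
    metadata={"run_label": "demo"},  # hypothetical key; ends up in clean_metadata
)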

test_braintrust.py

@@ -0,0 +1,53 @@
# What is this?
## This tests the braintrust integration

import asyncio
import os
import random
import sys
import time
import traceback
from datetime import datetime

from dotenv import load_dotenv
from fastapi import Request

load_dotenv()
import os

sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import asyncio
import logging
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

import litellm
from litellm.llms.custom_httpx.http_handler import HTTPHandler


def test_braintrust_logging():
    import litellm

    http_client = HTTPHandler()

    setattr(
        litellm.integrations.braintrust_logging,
        "global_braintrust_sync_http_handler",
        http_client,
    )

    with patch.object(http_client, "post", new=MagicMock()) as mock_client:
        # set braintrust as a callback, litellm will send the data to braintrust
        litellm.callbacks = ["braintrust"]

        # openai call
        response = litellm.completion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
        )

        mock_client.assert_called()
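
The new test only exercises the sync handler. A possible async counterpart, sketched below and not part of this commit, would patch the async handler instead and go through litellm.acompletion; it assumes async_log_success_event posts through global_braintrust_http_handler the same way the sync path uses global_braintrust_sync_http_handler.

# Hypothetical async counterpart (a sketch, not in this commit).
@pytest.mark.asyncio
async def test_braintrust_logging_async():
    import litellm
    from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler

    http_client = AsyncHTTPHandler()
    # assumption: the async success callback posts via global_braintrust_http_handler
    setattr(
        litellm.integrations.braintrust_logging,
        "global_braintrust_http_handler",
        http_client,
    )

    with patch.object(http_client, "post", new=AsyncMock()) as mock_client:
        litellm.callbacks = ["braintrust"]

        await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
        )
        await asyncio.sleep(1)  # give the async logging task time to run

        mock_client.assert_called()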