forked from phoenix/litellm-mirror
test(test_braintrust.py): add testing for braintrust integration
parent dd6d58d29b
commit 92b1262caa
2 changed files with 187 additions and 10 deletions
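
For context, this is roughly how the integration under test is driven from user code; a minimal sketch mirroring the new test below, assuming OPENAI_API_KEY and the Braintrust credentials are set in the environment (which environment variable the Braintrust logger reads is not shown in this diff):

# Minimal usage sketch: registering the "braintrust" callback makes LiteLLM
# send each successful completion to Braintrust.
import litellm

litellm.callbacks = ["braintrust"]

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
)
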
litellm/integrations/braintrust_logging.py
@@ -11,7 +11,6 @@ from typing import Literal, Optional
 
 import dotenv
 import httpx
-from braintrust import Span, SpanTypeAttribute, init, start_span
 
 import litellm
 from litellm import verbose_logger
@@ -20,6 +19,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import get_formatted_prompt
 
 global_braintrust_http_handler = AsyncHTTPHandler()
+global_braintrust_sync_http_handler = HTTPHandler()
 API_BASE = "https://api.braintrustdata.com/v1"
 
 
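The new global_braintrust_sync_http_handler exists because log_success_event (added below) runs on LiteLLM's synchronous callback path and cannot await the existing AsyncHTTPHandler; a sketch of the resulting split, with names taken from this diff:

# Two module-level clients, one per callback path (as introduced above):
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler

global_braintrust_http_handler = AsyncHTTPHandler()  # awaited by async_log_success_event
global_braintrust_sync_http_handler = HTTPHandler()  # blocking, used by log_success_event
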
@@ -107,11 +107,143 @@ class BraintrustLogger(CustomLogger):
         self.default_project_id = project_dict["id"]
 
+    def create_sync_default_project_and_experiment(self):
+        project = global_braintrust_sync_http_handler.post(
+            f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"}
+        )
+
+        project_dict = project.json()
+
+        self.default_project_id = project_dict["id"]
+
+    def log_success_event(self, kwargs, response_obj, start_time, end_time):
+        verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
+        try:
+            litellm_call_id = kwargs.get("litellm_call_id")
+            project_id = kwargs.get("project_id", None)
+            if project_id is None:
+                if self.default_project_id is None:
+                    self.create_sync_default_project_and_experiment()
+                project_id = self.default_project_id
+
+            prompt = {"messages": kwargs.get("messages")}
+
+            if response_obj is not None and (
+                kwargs.get("call_type", None) == "embedding"
+                or isinstance(response_obj, litellm.EmbeddingResponse)
+            ):
+                input = prompt
+                output = None
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ModelResponse
+            ):
+                input = prompt
+                output = response_obj["choices"][0]["message"].json()
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.TextCompletionResponse
+            ):
+                input = prompt
+                output = response_obj.choices[0].text
+            elif response_obj is not None and isinstance(
+                response_obj, litellm.ImageResponse
+            ):
+                input = prompt
+                output = response_obj["data"]
+
+            litellm_params = kwargs.get("litellm_params", {})
+            metadata = (
+                litellm_params.get("metadata", {}) or {}
+            )  # if litellm_params['metadata'] == None
+            metadata = self.add_metadata_from_header(litellm_params, metadata)
+            clean_metadata = {}
+            try:
+                metadata = copy.deepcopy(
+                    metadata
+                )  # Avoid modifying the original metadata
+            except:
+                new_metadata = {}
+                for key, value in metadata.items():
+                    if (
+                        isinstance(value, list)
+                        or isinstance(value, dict)
+                        or isinstance(value, str)
+                        or isinstance(value, int)
+                        or isinstance(value, float)
+                    ):
+                        new_metadata[key] = copy.deepcopy(value)
+                metadata = new_metadata
+
+            tags = []
+            if isinstance(metadata, dict):
+                for key, value in metadata.items():
+
+                    # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy
+                    if (
+                        litellm._langfuse_default_tags is not None
+                        and isinstance(litellm._langfuse_default_tags, list)
+                        and key in litellm._langfuse_default_tags
+                    ):
+                        tags.append(f"{key}:{value}")
+
+                    # clean litellm metadata before logging
+                    if key in [
+                        "headers",
+                        "endpoint",
+                        "caching_groups",
+                        "previous_models",
+                    ]:
+                        continue
+                    else:
+                        clean_metadata[key] = value
+
+            cost = kwargs.get("response_cost", None)
+            if cost is not None:
+                clean_metadata["litellm_response_cost"] = cost
+
+            metrics: Optional[dict] = None
+            if (
+                response_obj is not None
+                and hasattr(response_obj, "usage")
+                and isinstance(response_obj.usage, litellm.Usage)
+            ):
+                generation_id = litellm.utils.get_logging_id(start_time, response_obj)
+                metrics = {
+                    "prompt_tokens": response_obj.usage.prompt_tokens,
+                    "completion_tokens": response_obj.usage.completion_tokens,
+                    "total_tokens": response_obj.usage.total_tokens,
+                    "total_cost": cost,
+                }
+
+            request_data = {
+                "id": litellm_call_id,
+                "input": prompt,
+                "output": output,
+                "metadata": clean_metadata,
+                "tags": tags,
+            }
+            if metrics is not None:
+                request_data["metrics"] = metrics
+
+            try:
+                global_braintrust_sync_http_handler.post(
+                    url=f"{self.api_base}/project_logs/{project_id}/insert",
+                    json={"events": [request_data]},
+                    headers=self.headers,
+                )
+            except httpx.HTTPStatusError as e:
+                raise Exception(e.response.text)
+        except Exception as e:
+            verbose_logger.error(
+                "Error logging to braintrust - Exception received - {}\n{}".format(
+                    str(e), traceback.format_exc()
+                )
+            )
+            raise e
+
     async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
         verbose_logger.debug("REACHES BRAINTRUST SUCCESS")
         try:
             litellm_call_id = kwargs.get("litellm_call_id")
-            trace_id = kwargs.get("trace_id", litellm_call_id)
             project_id = kwargs.get("project_id", None)
             if project_id is None:
                 if self.default_project_id is None:
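To make the new sync path concrete, this is the rough shape of the event that log_success_event assembles and sends to {api_base}/project_logs/{project_id}/insert; the keys mirror request_data in the diff above, while the values here are purely illustrative:

# Illustrative event payload (made-up values; keys follow request_data above).
request_data = {
    "id": "litellm-call-id-123",  # litellm_call_id
    "input": {"messages": [{"role": "user", "content": "Hi"}]},
    "output": {"role": "assistant", "content": "Hello!"},
    "metadata": {"litellm_response_cost": 0.000123},  # cleaned metadata + cost
    "tags": [],
    "metrics": {  # only attached when usage info is present on the response
        "prompt_tokens": 10,
        "completion_tokens": 5,
        "total_tokens": 15,
        "total_cost": 0.000123,
    },
}
payload = {"events": [request_data]}  # request body for the insert endpoint
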
@@ -188,14 +320,6 @@ class BraintrustLogger(CustomLogger):
                     else:
                         clean_metadata[key] = value
 
-            session_id = clean_metadata.pop("session_id", None)
-            trace_name = clean_metadata.pop("trace_name", None)
-            trace_id = clean_metadata.pop("trace_id", litellm_call_id)
-            existing_trace_id = clean_metadata.pop("existing_trace_id", None)
-            update_trace_keys = clean_metadata.pop("update_trace_keys", [])
-            debug = clean_metadata.pop("debug_langfuse", None)
-            mask_input = clean_metadata.pop("mask_input", False)
-            mask_output = clean_metadata.pop("mask_output", False)
             cost = kwargs.get("response_cost", None)
             if cost is not None:
                 clean_metadata["litellm_response_cost"] = cost
litellm/tests/test_braintrust.py (new file, 53 lines)
@@ -0,0 +1,53 @@
+# What is this?
+## This tests the braintrust integration
+
+import asyncio
+import os
+import random
+import sys
+import time
+import traceback
+from datetime import datetime
+
+from dotenv import load_dotenv
+from fastapi import Request
+
+load_dotenv()
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import asyncio
+import logging
+from unittest.mock import AsyncMock, MagicMock, patch
+
+import pytest
+
+import litellm
+from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+
+def test_braintrust_logging():
+    import litellm
+
+    http_client = HTTPHandler()
+
+    setattr(
+        litellm.integrations.braintrust_logging,
+        "global_braintrust_sync_http_handler",
+        http_client,
+    )
+
+    with patch.object(http_client, "post", new=MagicMock()) as mock_client:
+
+        # set braintrust as a callback, litellm will send the data to braintrust
+        litellm.callbacks = ["braintrust"]
+
+        # openai call
+        response = litellm.completion(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hi 👋 - i'm openai"}],
+        )
+
+        mock_client.assert_called()
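
The test only asserts that the patched client was invoked. A natural tightening, hypothetical and not part of this commit, would be to also inspect the call using the keyword arguments the logger passes to post() in the diff above:

# Hypothetical extra assertions (not in this commit): verify endpoint and payload shape.
call_kwargs = mock_client.call_args.kwargs
assert "/project_logs/" in call_kwargs["url"]
assert "events" in call_kwargs["json"]

Note that only the Braintrust HTTP client is mocked here, so running the test with pytest litellm/tests/test_braintrust.py still issues a real gpt-3.5-turbo call and needs a valid OPENAI_API_KEY.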