From e4ab50e1a1e94e9b311d2a8b8eb15b304ecc2e36 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 22 Jul 2024 17:04:55 -0700 Subject: [PATCH] feat(braintrust_logging.py): working braintrust logging for successful calls --- litellm/__init__.py | 8 +- litellm/integrations/braintrust.py | 0 litellm/integrations/braintrust_logging.py | 245 ++++++++++++++++++ litellm/litellm_core_utils/litellm_logging.py | 12 + litellm/proxy/_new_secret_config.yaml | 4 + litellm/proxy/common_utils/init_callbacks.py | 1 + 6 files changed, 269 insertions(+), 1 deletion(-) delete mode 100644 litellm/integrations/braintrust.py create mode 100644 litellm/integrations/braintrust_logging.py diff --git a/litellm/__init__.py b/litellm/__init__.py index 7dcc934a6..bf3f77385 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -38,7 +38,13 @@ success_callback: List[Union[str, Callable]] = [] failure_callback: List[Union[str, Callable]] = [] service_callback: List[Union[str, Callable]] = [] _custom_logger_compatible_callbacks_literal = Literal[ - "lago", "openmeter", "logfire", "dynamic_rate_limiter", "langsmith", "galileo" + "lago", + "openmeter", + "logfire", + "dynamic_rate_limiter", + "langsmith", + "galileo", + "braintrust", ] callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = [] _langfuse_default_tags: Optional[ diff --git a/litellm/integrations/braintrust.py b/litellm/integrations/braintrust.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py new file mode 100644 index 000000000..8bd813b69 --- /dev/null +++ b/litellm/integrations/braintrust_logging.py @@ -0,0 +1,245 @@ +# What is this? +## Log success + failure events to Braintrust + +import copy +import json +import os +import threading +import traceback +import uuid +from typing import Literal, Optional + +import dotenv +import httpx +from braintrust import Span, SpanTypeAttribute, init, start_span + +import litellm +from litellm import verbose_logger +from litellm.integrations.custom_logger import CustomLogger +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import get_formatted_prompt + +global_braintrust_http_handler = AsyncHTTPHandler() +API_BASE = "https://api.braintrustdata.com/v1" + + +def get_utc_datetime(): + import datetime as dt + from datetime import datetime + + if hasattr(dt, "UTC"): + return datetime.now(dt.UTC) # type: ignore + else: + return datetime.utcnow() # type: ignore + + +class BraintrustLogger(CustomLogger): + def __init__( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> None: + super().__init__() + self.validate_environment(api_key=api_key) + self.api_base = api_base or API_BASE + self.default_project_id = None + self.api_key: str = api_key or os.getenv("BRAINTRUST_API_KEY") # type: ignore + self.headers = { + "Authorization": "Bearer " + self.api_key, + "Content-Type": "application/json", + } + + def validate_environment(self, api_key: Optional[str]): + """ + Expects + BRAINTRUST_API_KEY + + in the environment + """ + missing_keys = [] + if api_key is None and os.getenv("BRAINTRUST_API_KEY", None) is None: + missing_keys.append("BRAINTRUST_API_KEY") + + if len(missing_keys) > 0: + raise Exception("Missing keys={} in environment.".format(missing_keys)) + + @staticmethod + def add_metadata_from_header(litellm_params: dict, metadata: dict) -> dict: + """ + Adds metadata from proxy request headers to Langfuse logging if keys start with "langfuse_" + and overwrites litellm_params.metadata if already included. + + For example if you want to append your trace to an existing `trace_id` via header, send + `headers: { ..., langfuse_existing_trace_id: your-existing-trace-id }` via proxy request. + """ + if litellm_params is None: + return metadata + + if litellm_params.get("proxy_server_request") is None: + return metadata + + if metadata is None: + metadata = {} + + proxy_headers = ( + litellm_params.get("proxy_server_request", {}).get("headers", {}) or {} + ) + + for metadata_param_key in proxy_headers: + if metadata_param_key.startswith("braintrust"): + trace_param_key = metadata_param_key.replace("braintrust", "", 1) + if trace_param_key in metadata: + verbose_logger.warning( + f"Overwriting Braintrust `{trace_param_key}` from request header" + ) + else: + verbose_logger.debug( + f"Found Braintrust `{trace_param_key}` in request header" + ) + metadata[trace_param_key] = proxy_headers.get(metadata_param_key) + + return metadata + + async def create_default_project_and_experiment(self): + project = await global_braintrust_http_handler.post( + f"{self.api_base}/project", headers=self.headers, json={"name": "litellm"} + ) + + project_dict = project.json() + + self.default_project_id = project_dict["id"] + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + verbose_logger.debug("REACHES BRAINTRUST SUCCESS") + try: + litellm_call_id = kwargs.get("litellm_call_id") + trace_id = kwargs.get("trace_id", litellm_call_id) + project_id = kwargs.get("project_id", None) + if project_id is None: + if self.default_project_id is None: + await self.create_default_project_and_experiment() + project_id = self.default_project_id + + prompt = {"messages": kwargs.get("messages")} + + if response_obj is not None and ( + kwargs.get("call_type", None) == "embedding" + or isinstance(response_obj, litellm.EmbeddingResponse) + ): + input = prompt + output = None + elif response_obj is not None and isinstance( + response_obj, litellm.ModelResponse + ): + input = prompt + output = response_obj["choices"][0]["message"].json() + elif response_obj is not None and isinstance( + response_obj, litellm.TextCompletionResponse + ): + input = prompt + output = response_obj.choices[0].text + elif response_obj is not None and isinstance( + response_obj, litellm.ImageResponse + ): + input = prompt + output = response_obj["data"] + + litellm_params = kwargs.get("litellm_params", {}) + metadata = ( + litellm_params.get("metadata", {}) or {} + ) # if litellm_params['metadata'] == None + metadata = self.add_metadata_from_header(litellm_params, metadata) + clean_metadata = {} + try: + metadata = copy.deepcopy( + metadata + ) # Avoid modifying the original metadata + except: + new_metadata = {} + for key, value in metadata.items(): + if ( + isinstance(value, list) + or isinstance(value, dict) + or isinstance(value, str) + or isinstance(value, int) + or isinstance(value, float) + ): + new_metadata[key] = copy.deepcopy(value) + metadata = new_metadata + + tags = [] + if isinstance(metadata, dict): + for key, value in metadata.items(): + + # generate langfuse tags - Default Tags sent to Langfuse from LiteLLM Proxy + if ( + litellm._langfuse_default_tags is not None + and isinstance(litellm._langfuse_default_tags, list) + and key in litellm._langfuse_default_tags + ): + tags.append(f"{key}:{value}") + + # clean litellm metadata before logging + if key in [ + "headers", + "endpoint", + "caching_groups", + "previous_models", + ]: + continue + else: + clean_metadata[key] = value + + session_id = clean_metadata.pop("session_id", None) + trace_name = clean_metadata.pop("trace_name", None) + trace_id = clean_metadata.pop("trace_id", litellm_call_id) + existing_trace_id = clean_metadata.pop("existing_trace_id", None) + update_trace_keys = clean_metadata.pop("update_trace_keys", []) + debug = clean_metadata.pop("debug_langfuse", None) + mask_input = clean_metadata.pop("mask_input", False) + mask_output = clean_metadata.pop("mask_output", False) + cost = kwargs.get("response_cost", None) + if cost is not None: + clean_metadata["litellm_response_cost"] = cost + + metrics: Optional[dict] = None + if ( + response_obj is not None + and hasattr(response_obj, "usage") + and isinstance(response_obj.usage, litellm.Usage) + ): + generation_id = litellm.utils.get_logging_id(start_time, response_obj) + metrics = { + "prompt_tokens": response_obj.usage.prompt_tokens, + "completion_tokens": response_obj.usage.completion_tokens, + "total_tokens": response_obj.usage.total_tokens, + "total_cost": cost, + } + + request_data = { + "id": litellm_call_id, + "input": prompt, + "output": output, + "metadata": clean_metadata, + "tags": tags, + } + + if metrics is not None: + request_data["metrics"] = metrics + + try: + await global_braintrust_http_handler.post( + url=f"{self.api_base}/project_logs/{project_id}/insert", + json={"events": [request_data]}, + headers=self.headers, + ) + except httpx.HTTPStatusError as e: + raise Exception(e.response.text) + except Exception as e: + verbose_logger.error( + "Error logging to braintrust - Exception received - {}\n{}".format( + str(e), traceback.format_exc() + ) + ) + raise e + + def log_failure_event(self, kwargs, response_obj, start_time, end_time): + return super().log_failure_event(kwargs, response_obj, start_time, end_time) diff --git a/litellm/litellm_core_utils/litellm_logging.py b/litellm/litellm_core_utils/litellm_logging.py index 32633960f..17837c41e 100644 --- a/litellm/litellm_core_utils/litellm_logging.py +++ b/litellm/litellm_core_utils/litellm_logging.py @@ -53,6 +53,7 @@ from litellm.utils import ( from ..integrations.aispend import AISpendLogger from ..integrations.athina import AthinaLogger from ..integrations.berrispend import BerriSpendLogger +from ..integrations.braintrust_logging import BraintrustLogger from ..integrations.clickhouse import ClickhouseLogger from ..integrations.custom_logger import CustomLogger from ..integrations.datadog import DataDogLogger @@ -1945,7 +1946,14 @@ def _init_custom_logger_compatible_class( _openmeter_logger = OpenMeterLogger() _in_memory_loggers.append(_openmeter_logger) return _openmeter_logger # type: ignore + elif logging_integration == "braintrust": + for callback in _in_memory_loggers: + if isinstance(callback, BraintrustLogger): + return callback # type: ignore + braintrust_logger = BraintrustLogger() + _in_memory_loggers.append(braintrust_logger) + return braintrust_logger # type: ignore elif logging_integration == "langsmith": for callback in _in_memory_loggers: if isinstance(callback, LangsmithLogger): @@ -2019,6 +2027,10 @@ def get_custom_logger_compatible_class( for callback in _in_memory_loggers: if isinstance(callback, OpenMeterLogger): return callback + elif logging_integration == "braintrust": + for callback in _in_memory_loggers: + if isinstance(callback, BraintrustLogger): + return callback elif logging_integration == "galileo": for callback in _in_memory_loggers: if isinstance(callback, GalileoObserve): diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index 81244f0fa..7a35650e5 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -3,3 +3,7 @@ model_list: litellm_params: model: groq/llama3-groq-70b-8192-tool-use-preview api_key: os.environ/GROQ_API_KEY + + +litellm_settings: + callbacks: ["braintrust"] diff --git a/litellm/proxy/common_utils/init_callbacks.py b/litellm/proxy/common_utils/init_callbacks.py index 489f9b3a6..2fcceaa29 100644 --- a/litellm/proxy/common_utils/init_callbacks.py +++ b/litellm/proxy/common_utils/init_callbacks.py @@ -27,6 +27,7 @@ def initialize_callbacks_on_proxy( get_args(litellm._custom_logger_compatible_callbacks_literal) ) for callback in value: # ["presidio", ] + if isinstance(callback, str) and callback in known_compatible_callbacks: imported_list.append(callback) elif isinstance(callback, str) and callback == "otel":