build(pyproject.toml): add new dev dependencies - for type checking (#9631)

* build(pyproject.toml): add new dev dependencies - for type checking

* build: reformat files to fit black

* ci: reformat to fit black

* ci(test-litellm.yml): make tests run clear

* build(pyproject.toml): add ruff

* fix: fix ruff checks

* build(mypy/): fix mypy linting errors

* fix(hashicorp_secret_manager.py): fix passing cert for tls auth

* build(mypy/): resolve all mypy errors

* test: update test

* fix: fix black formatting

* build(pre-commit-config.yaml): use poetry run black

* fix(proxy_server.py): fix linting error

* fix: fix ruff safe representation error
This commit is contained in:
Krish Dholakia 2025-03-29 11:02:13 -07:00 committed by GitHub
parent 72198737f8
commit d7b294dd0a
214 changed files with 1553 additions and 1433 deletions

View file

@ -4,7 +4,7 @@
import copy
import os
from datetime import datetime
from typing import Optional, Dict
from typing import Dict, Optional
import httpx
from pydantic import BaseModel
@ -19,7 +19,9 @@ from litellm.llms.custom_httpx.http_handler import (
)
from litellm.utils import print_verbose
global_braintrust_http_handler = get_async_httpx_client(llm_provider=httpxSpecialProvider.LoggingCallback)
global_braintrust_http_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
global_braintrust_sync_http_handler = HTTPHandler()
API_BASE = "https://api.braintrustdata.com/v1"
@ -35,7 +37,9 @@ def get_utc_datetime():
class BraintrustLogger(CustomLogger):
def __init__(self, api_key: Optional[str] = None, api_base: Optional[str] = None) -> None:
def __init__(
self, api_key: Optional[str] = None, api_base: Optional[str] = None
) -> None:
super().__init__()
self.validate_environment(api_key=api_key)
self.api_base = api_base or API_BASE
@ -45,7 +49,9 @@ class BraintrustLogger(CustomLogger):
"Authorization": "Bearer " + self.api_key,
"Content-Type": "application/json",
}
self._project_id_cache: Dict[str, str] = {} # Cache mapping project names to IDs
self._project_id_cache: Dict[
str, str
] = {} # Cache mapping project names to IDs
def validate_environment(self, api_key: Optional[str]):
"""
@ -71,7 +77,9 @@ class BraintrustLogger(CustomLogger):
try:
response = global_braintrust_sync_http_handler.post(
f"{self.api_base}/project", headers=self.headers, json={"name": project_name}
f"{self.api_base}/project",
headers=self.headers,
json={"name": project_name},
)
project_dict = response.json()
project_id = project_dict["id"]
@ -89,7 +97,9 @@ class BraintrustLogger(CustomLogger):
try:
response = await global_braintrust_http_handler.post(
f"{self.api_base}/project/register", headers=self.headers, json={"name": project_name}
f"{self.api_base}/project/register",
headers=self.headers,
json={"name": project_name},
)
project_dict = response.json()
project_id = project_dict["id"]
@ -116,15 +126,21 @@ class BraintrustLogger(CustomLogger):
if metadata is None:
metadata = {}
proxy_headers = litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
proxy_headers = (
litellm_params.get("proxy_server_request", {}).get("headers", {}) or {}
)
for metadata_param_key in proxy_headers:
if metadata_param_key.startswith("braintrust"):
trace_param_key = metadata_param_key.replace("braintrust", "", 1)
if trace_param_key in metadata:
verbose_logger.warning(f"Overwriting Braintrust `{trace_param_key}` from request header")
verbose_logger.warning(
f"Overwriting Braintrust `{trace_param_key}` from request header"
)
else:
verbose_logger.debug(f"Found Braintrust `{trace_param_key}` in request header")
verbose_logger.debug(
f"Found Braintrust `{trace_param_key}` in request header"
)
metadata[trace_param_key] = proxy_headers.get(metadata_param_key)
return metadata
@ -157,24 +173,35 @@ class BraintrustLogger(CustomLogger):
output = None
choices = []
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
output = None
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
output = response_obj["choices"][0]["message"].json()
choices = response_obj["choices"]
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
output = response_obj.choices[0].text
choices = response_obj.choices
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
metadata = self.add_metadata_from_header(litellm_params, metadata)
clean_metadata = {}
try:
metadata = copy.deepcopy(metadata) # Avoid modifying the original metadata
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except Exception:
new_metadata = {}
for key, value in metadata.items():
@ -192,7 +219,9 @@ class BraintrustLogger(CustomLogger):
project_id = metadata.get("project_id")
if project_id is None:
project_name = metadata.get("project_name")
project_id = self.get_project_id_sync(project_name) if project_name else None
project_id = (
self.get_project_id_sync(project_name) if project_name else None
)
if project_id is None:
if self.default_project_id is None:
@ -234,7 +263,8 @@ class BraintrustLogger(CustomLogger):
"completion_tokens": usage_obj.completion_tokens,
"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
"time_to_first_token": end_time.timestamp() - start_time.timestamp(),
"time_to_first_token": end_time.timestamp()
- start_time.timestamp(),
"start": start_time.timestamp(),
"end": end_time.timestamp(),
}
@ -255,7 +285,9 @@ class BraintrustLogger(CustomLogger):
request_data["metrics"] = metrics
try:
print_verbose(f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}")
print_verbose(
f"global_braintrust_sync_http_handler.post: {global_braintrust_sync_http_handler.post}"
)
global_braintrust_sync_http_handler.post(
url=f"{self.api_base}/project_logs/{project_id}/insert",
json={"events": [request_data]},
@ -276,20 +308,29 @@ class BraintrustLogger(CustomLogger):
output = None
choices = []
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" or isinstance(response_obj, litellm.EmbeddingResponse)
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
output = None
elif response_obj is not None and isinstance(response_obj, litellm.ModelResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
output = response_obj["choices"][0]["message"].json()
choices = response_obj["choices"]
elif response_obj is not None and isinstance(response_obj, litellm.TextCompletionResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
output = response_obj.choices[0].text
choices = response_obj.choices
elif response_obj is not None and isinstance(response_obj, litellm.ImageResponse):
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
metadata = litellm_params.get("metadata", {}) or {} # if litellm_params['metadata'] == None
metadata = (
litellm_params.get("metadata", {}) or {}
) # if litellm_params['metadata'] == None
metadata = self.add_metadata_from_header(litellm_params, metadata)
clean_metadata = {}
new_metadata = {}
@ -313,7 +354,11 @@ class BraintrustLogger(CustomLogger):
project_id = metadata.get("project_id")
if project_id is None:
project_name = metadata.get("project_name")
project_id = await self.get_project_id_async(project_name) if project_name else None
project_id = (
await self.get_project_id_async(project_name)
if project_name
else None
)
if project_id is None:
if self.default_project_id is None:
@ -362,8 +407,14 @@ class BraintrustLogger(CustomLogger):
api_call_start_time = kwargs.get("api_call_start_time")
completion_start_time = kwargs.get("completion_start_time")
if api_call_start_time is not None and completion_start_time is not None:
metrics["time_to_first_token"] = completion_start_time.timestamp() - api_call_start_time.timestamp()
if (
api_call_start_time is not None
and completion_start_time is not None
):
metrics["time_to_first_token"] = (
completion_start_time.timestamp()
- api_call_start_time.timestamp()
)
request_data = {
"id": litellm_call_id,