Merge pull request #5638 from BerriAI/litellm_langsmith_perf

[Langsmith Perf Improvement] Use /batch for Langsmith Logging
This commit is contained in:
Ishaan Jaff 2024-09-11 17:43:26 -07:00 committed by GitHub
commit f55318de47
8 changed files with 244 additions and 54 deletions
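
The change replaces the previous one-request-per-run POST to the LangSmith /runs endpoint with a queued POST to /runs/batch, flushed either when the queue reaches a batch size or on a periodic timer. A minimal sketch of the batched request shape, assuming the default LangSmith API host and an API key in the environment (the run dicts are illustrative; the real logger builds them in _prepare_log_data):

import os
import httpx

queued_runs = [
    {"name": "LLMRun", "run_type": "llm", "inputs": {"prompt": "hi"}},  # illustrative run dict
]

# One request for the whole queue, mirroring the {"post": [...]} payload used in async_send_batch.
resp = httpx.post(
    "https://api.smith.langchain.com/runs/batch",  # assumed default; the logger reads langsmith_base_url
    json={"post": queued_runs},
    headers={"x-api-key": os.getenv("LANGSMITH_API_KEY", "")},
)
print(resp.status_code)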

View file

@@ -165,11 +165,11 @@ jobs:
pip install pytest
pip install tiktoken
pip install aiohttp
pip install openai
pip install click
pip install "boto3==1.34.34"
pip install jinja2
pip install tokenizers
pip install openai
pip install jsonschema
- run:
name: Run tests

View file

@@ -55,6 +55,7 @@ _known_custom_logger_compatible_callbacks: List = list(
)
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
langfuse_default_tags: Optional[List[str]] = None
langsmith_batch_size: Optional[int] = None
_async_input_callback: List[Callable] = (
[]
) # internal variable - async custom callbacks are routed here.
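
Per the LangsmithLogger.__init__ diff further down, the batch size can be set either through a LANGSMITH_BATCH_SIZE environment variable (checked first) or through this new module-level setting. A small sketch:

import os

import litellm

litellm.langsmith_batch_size = 100          # new module-level setting added in this PR
# or via the environment; as written in the logger, the env var wins when both are set
os.environ["LANGSMITH_BATCH_SIZE"] = "100"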

View file

@@ -0,0 +1,53 @@
"""
Custom Logger that handles batching logic
Use this if you want your logs to be stored in memory and flushed periodically
"""
import asyncio
import time
from typing import List, Literal, Optional
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
DEFAULT_BATCH_SIZE = 512
DEFAULT_FLUSH_INTERVAL_SECONDS = 5
class CustomBatchLogger(CustomLogger):
def __init__(self, flush_lock: Optional[asyncio.Lock] = None, **kwargs) -> None:
"""
Args:
flush_lock (Optional[asyncio.Lock], optional): Lock to use when flushing the queue. Defaults to None. Only used for custom loggers that do batching
"""
self.log_queue: List = []
self.flush_interval = DEFAULT_FLUSH_INTERVAL_SECONDS # 10 seconds
self.batch_size = DEFAULT_BATCH_SIZE
self.last_flush_time = time.time()
self.flush_lock = flush_lock
super().__init__(**kwargs)
pass
async def periodic_flush(self):
while True:
await asyncio.sleep(self.flush_interval)
verbose_logger.debug(
f"CustomLogger periodic flush after {self.flush_interval} seconds"
)
await self.flush_queue()
async def flush_queue(self):
async with self.flush_lock:
if self.log_queue:
verbose_logger.debug(
"CustomLogger: Flushing batch of %s events", self.batch_size
)
await self.async_send_batch()
self.log_queue.clear()
self.last_flush_time = time.time()
async def async_send_batch(self):
pass
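
CustomBatchLogger is meant to be subclassed: the child appends events to self.log_queue and implements async_send_batch, while the base class provides the periodic flush loop. A hypothetical subclass sketch following the same pattern LangsmithLogger uses below (MyBatchLogger is not part of this PR):

import asyncio

from litellm._logging import verbose_logger
from litellm.integrations.custom_batch_logger import CustomBatchLogger


class MyBatchLogger(CustomBatchLogger):
    def __init__(self, **kwargs):
        self.flush_lock = asyncio.Lock()
        # needs a running event loop, same as LangsmithLogger.__init__ below
        asyncio.create_task(self.periodic_flush())
        super().__init__(**kwargs, flush_lock=self.flush_lock)

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.log_queue.append({"model": kwargs.get("model")})
        if len(self.log_queue) >= self.batch_size:
            await self.flush_queue()

    async def async_send_batch(self):
        # ship self.log_queue to your backend; flush_queue() clears it afterwards
        verbose_logger.debug("MyBatchLogger sending %s events", len(self.log_queue))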

View file

@@ -3,9 +3,11 @@
import asyncio
import os
import random
import time
import traceback
import types
from datetime import datetime
import uuid
from datetime import datetime, timezone
from typing import Any, List, Optional, Union
import dotenv # type: ignore
@@ -15,7 +17,7 @@ from pydantic import BaseModel # type: ignore
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
get_async_httpx_client,
@@ -54,9 +56,8 @@ def is_serializable(value):
return not isinstance(value, non_serializable_types)
class LangsmithLogger(CustomLogger):
# Class variables or attributes
def __init__(self):
class LangsmithLogger(CustomBatchLogger):
def __init__(self, **kwargs):
self.langsmith_api_key = os.getenv("LANGSMITH_API_KEY")
self.langsmith_project = os.getenv("LANGSMITH_PROJECT", "litellm-completion")
self.langsmith_default_run_name = os.getenv(
@@ -68,6 +69,14 @@ class LangsmithLogger(CustomLogger):
self.async_httpx_client = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
_batch_size = (
os.getenv("LANGSMITH_BATCH_SIZE", None) or litellm.langsmith_batch_size
)
if _batch_size:
self.batch_size = int(_batch_size)
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
super().__init__(**kwargs, flush_lock=self.flush_lock)
def _prepare_log_data(self, kwargs, response_obj, start_time, end_time):
import datetime
@@ -170,52 +179,44 @@ class LangsmithLogger(CustomLogger):
if dotted_order:
data["dotted_order"] = dotted_order
if "id" not in data or data["id"] is None:
"""
for /batch langsmith requires id, trace_id and dotted_order passed as params
"""
run_id = uuid.uuid4()
data["id"] = str(run_id)
data["trace_id"] = str(run_id)
data["dotted_order"] = self.make_dot_order(run_id=run_id)
verbose_logger.debug("Langsmith Logging data on langsmith: %s", data)
return data
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = (
float(os.getenv("LANGSMITH_SAMPLING_RATE"))
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit()
else 1.0
)
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
sampling_rate, random_sample
)
)
return # Skip logging
verbose_logger.debug(
"Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
url = f"{self.langsmith_base_url}/runs"
verbose_logger.debug(f"Langsmith Logging - About to send data to {url} ...")
def _send_batch(self):
if not self.log_queue:
return
url = f"{self.langsmith_base_url}/runs/batch"
headers = {"x-api-key": self.langsmith_api_key}
response = await self.async_httpx_client.post(
url=url, json=data, headers=headers
try:
response = requests.post(
url=url,
json=self.log_queue,
headers=headers,
)
if response.status_code >= 300:
verbose_logger.error(
f"Langmsith Error: {response.status_code} - {response.text}"
f"Langsmith Error: {response.status_code} - {response.text}"
)
else:
verbose_logger.debug(
"Run successfully created, response=%s", response.text
f"Batch of {len(self.log_queue)} runs successfully created"
)
verbose_logger.debug(
f"Langsmith Layer Logging - final response object: {response_obj}. Response text from langsmith={response.text}"
)
except:
self.log_queue.clear()
except Exception as e:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
@@ -240,25 +241,95 @@ class LangsmithLogger(CustomLogger):
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
url = f"{self.langsmith_base_url}/runs"
verbose_logger.debug(f"Langsmith Logging - About to send data to {url} ...")
response = requests.post(
url=url,
json=data,
headers={"x-api-key": self.langsmith_api_key},
)
if response.status_code >= 300:
verbose_logger.error(f"Error: {response.status_code} - {response.text}")
else:
verbose_logger.debug("Run successfully created")
self.log_queue.append(data)
verbose_logger.debug(
f"Langsmith Layer Logging - final response object: {response_obj}. Response text from langsmith={response.text}"
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
)
if len(self.log_queue) >= self.batch_size:
self._send_batch()
except:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}")
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = (
float(os.getenv("LANGSMITH_SAMPLING_RATE"))
if os.getenv("LANGSMITH_SAMPLING_RATE") is not None
and os.getenv("LANGSMITH_SAMPLING_RATE").strip().isdigit()
else 1.0
)
random_sample = random.random()
if random_sample > sampling_rate:
verbose_logger.info(
"Skipping Langsmith logging. Sampling rate={}, random_sample={}".format(
sampling_rate, random_sample
)
)
return # Skip logging
verbose_logger.debug(
"Langsmith Async Layer Logging - kwargs: %s, response_obj: %s",
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
self.batch_size,
)
if len(self.log_queue) >= self.batch_size:
await self.flush_queue()
except:
verbose_logger.error(f"Langsmith Layer Error - {traceback.format_exc()}")
async def async_send_batch(self):
"""
sends runs to /batch endpoint
Sends runs from self.log_queue
Returns: None
Raises: Does not raise an exception, will only verbose_logger.exception()
"""
import json
if not self.log_queue:
return
url = f"{self.langsmith_base_url}/runs/batch"
headers = {"x-api-key": self.langsmith_api_key}
try:
response = await self.async_httpx_client.post(
url=url,
json={
"post": self.log_queue,
},
headers=headers,
)
response.raise_for_status()
if response.status_code >= 300:
verbose_logger.error(
f"Langsmith Error: {response.status_code} - {response.text}"
)
else:
verbose_logger.debug(
f"Batch of {len(self.log_queue)} runs successfully created"
)
except httpx.HTTPStatusError as e:
verbose_logger.exception(
f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
)
except Exception as e:
verbose_logger.exception(
f"Langsmith Layer Error - {traceback.format_exc()}"
)
def get_run_by_id(self, run_id):
url = f"{self.langsmith_base_url}/runs/{run_id}"
@@ -268,3 +339,8 @@ class LangsmithLogger(CustomLogger):
)
return response.json()
def make_dot_order(self, run_id: str):
st = datetime.now(timezone.utc)
id_ = run_id
return st.strftime("%Y%m%dT%H%M%S%fZ") + str(id_)
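
For context on the id handling above: when a run has no id, the logger generates a single uuid4, reuses it as trace_id, and builds dotted_order from a UTC timestamp plus that id via make_dot_order. A rough illustration (values are made up):

import uuid
from datetime import datetime, timezone

run_id = uuid.uuid4()
dotted_order = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%S%fZ") + str(run_id)

run = {
    "id": str(run_id),
    "trace_id": str(run_id),       # the run is its own trace root here
    "dotted_order": dotted_order,  # e.g. "20240911T174326123456Z<uuid>"
}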

View file

@@ -14,3 +14,8 @@ model_list:
general_settings:
master_key: sk-1234
litellm_settings:
success_callback: ["langsmith", "prometheus"]
service_callback: ["prometheus_system"]
callbacks: ["otel"]

View file

@@ -52,6 +52,7 @@ VERTEX_MODELS_TO_NOT_TEST = [
"gemini-1.5-pro-preview-0215",
"gemini-pro-experimental",
"gemini-flash-experimental",
"gemini-1.5-flash-exp-0827",
"gemini-pro-flash",
]

View file

@@ -22,6 +22,61 @@ litellm.set_verbose = True
import time

@pytest.mark.asyncio
async def test_langsmith_queue_logging():
    try:
        # Initialize LangsmithLogger
        test_langsmith_logger = LangsmithLogger()

        litellm.callbacks = [test_langsmith_logger]
        test_langsmith_logger.batch_size = 6
        litellm.set_verbose = True

        # Make multiple calls to ensure we don't hit the batch size
        for _ in range(5):
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Test message"}],
                max_tokens=10,
                temperature=0.2,
                mock_response="This is a mock response",
            )

        await asyncio.sleep(3)

        # Check that logs are in the queue
        assert len(test_langsmith_logger.log_queue) == 5

        # Now make calls to exceed the batch size
        for _ in range(3):
            response = await litellm.acompletion(
                model="gpt-3.5-turbo",
                messages=[{"role": "user", "content": "Test message"}],
                max_tokens=10,
                temperature=0.2,
                mock_response="This is a mock response",
            )

        # Wait a short time for any asynchronous operations to complete
        await asyncio.sleep(1)

        print(
            "Length of langsmith log queue: {}".format(
                len(test_langsmith_logger.log_queue)
            )
        )

        # Check that the queue was flushed after exceeding batch size
        assert len(test_langsmith_logger.log_queue) < 5

        # Clean up
        for cb in litellm.callbacks:
            if isinstance(cb, LangsmithLogger):
                await cb.async_httpx_client.client.aclose()

    except Exception as e:
        pytest.fail(f"Error occurred: {e}")
@pytest.mark.skip(reason="Flaky test. covered by unit tests on custom logger.")
@pytest.mark.asyncio()
async def test_async_langsmith_logging():

View file

@@ -9,7 +9,6 @@ gunicorn==22.0.0 # server dep
boto3==1.34.34 # aws bedrock/sagemaker calls
redis==5.0.0 # caching
numpy==1.24.3 # semantic caching
pandas==2.1.1 # for viewing clickhouse spend analytics
prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
pynacl==1.5.0 # for encrypting keys