(Feat) Add langsmith key based logging (#6682)

* add langsmith_api_key to StandardCallbackDynamicParams

* create a file for langsmith types

* langsmith add key / team based logging

* add key based logging for langsmith

* fix langsmith key based logging

* fix linting langsmith

* remove NOQA violation

* add unit test coverage for all helpers in test langsmith

* test_langsmith_key_based_logging

* docs langsmith key based logging

* run langsmith tests in logging callback tests

* fix logging testing

* test_langsmith_key_based_logging

* test_add_callback_via_key_litellm_pre_call_utils_langsmith

* add debug statement langsmith key based logging

* test_langsmith_key_based_logging
Ishaan Jaff 2024-11-11 13:58:06 -08:00 committed by GitHub
parent 1e2ba3e045
commit c3bc9e6b12
9 changed files with 810 additions and 179 deletions


@@ -686,6 +686,7 @@ jobs:
pip install "pytest-retry==1.6.3"
pip install "pytest-cov==5.0.0"
pip install "pytest-asyncio==0.21.1"
pip install pytest-mock
pip install "respx==0.21.1"
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"


@@ -281,6 +281,51 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \
}'
```
</TabItem>
<TabItem label="Langsmith" value="langsmith">
1. Create Virtual Key to log to a specific Langsmith Project
```bash
curl -X POST 'http://0.0.0.0:4000/key/generate' \
-H 'Authorization: Bearer sk-1234' \
-H 'Content-Type: application/json' \
-d '{
"metadata": {
"logging": [{
"callback_name": "langsmith", # "otel", "gcs_bucket"
"callback_type": "success", # "success", "failure", "success_and_failure"
"callback_vars": {
"langsmith_api_key": "os.environ/LANGSMITH_API_KEY", # API Key for Langsmith logging
"langsmith_project": "pr-brief-resemblance-72", # project name on langsmith
"langsmith_base_url": "https://api.smith.langchain.com"
}
}]
}
}'
```
2. Test it - `/chat/completions` request
Use the virtual key from step 1 to make a `/chat/completions` request.
On a successful request, you should see the logs in your Langsmith project. (A Python equivalent of this request is sketched below.)
```shell
curl -i http://localhost:4000/v1/chat/completions \
-H "Content-Type: application/json" \
-H "Authorization: Bearer sk-Fxq5XSyWKeXDKfPdqXZhPg" \
-d '{
"model": "fake-openai-endpoint",
"messages": [
{"role": "user", "content": "Hello, Claude"}
],
"user": "hello",
}'
```
</TabItem>
</Tabs>
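For reference, the same request can also be made from Python. The sketch below is not part of this change; it assumes the LiteLLM proxy is running on `http://localhost:4000`, uses the OpenAI SDK, and substitutes the virtual key generated in step 1 (the key shown is a placeholder).

```python
# Minimal sketch, assuming the proxy runs locally and the virtual key below
# is replaced with the one returned by /key/generate in step 1.
from openai import OpenAI

client = OpenAI(
    api_key="sk-Fxq5XSyWKeXDKfPdqXZhPg",  # placeholder virtual key
    base_url="http://localhost:4000",     # LiteLLM proxy base URL (assumed)
)

response = client.chat.completions.create(
    model="fake-openai-endpoint",
    messages=[{"role": "user", "content": "Hello, Claude"}],
    user="hello",
)
print(response.choices[0].message.content)
```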


@@ -23,34 +23,8 @@ from litellm.llms.custom_httpx.http_handler import (
get_async_httpx_client,
httpxSpecialProvider,
)
from litellm.types.utils import StandardLoggingPayload
class LangsmithInputs(BaseModel):
model: Optional[str] = None
messages: Optional[List[Any]] = None
stream: Optional[bool] = None
call_type: Optional[str] = None
litellm_call_id: Optional[str] = None
completion_start_time: Optional[datetime] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
custom_llm_provider: Optional[str] = None
input: Optional[List[Any]] = None
log_event_type: Optional[str] = None
original_response: Optional[Any] = None
response_cost: Optional[float] = None
# LiteLLM Virtual Key specific fields
user_api_key: Optional[str] = None
user_api_key_user_id: Optional[str] = None
user_api_key_team_alias: Optional[str] = None
class LangsmithCredentialsObject(TypedDict):
LANGSMITH_API_KEY: str
LANGSMITH_PROJECT: str
LANGSMITH_BASE_URL: str
from litellm.types.integrations.langsmith import *
from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload
def is_serializable(value):
@@ -93,15 +67,16 @@ class LangsmithLogger(CustomBatchLogger):
)
if _batch_size:
self.batch_size = int(_batch_size)
self.log_queue: List[LangsmithQueueObject] = []
asyncio.create_task(self.periodic_flush())
self.flush_lock = asyncio.Lock()
super().__init__(**kwargs, flush_lock=self.flush_lock)
def get_credentials_from_env(
self,
langsmith_api_key: Optional[str],
langsmith_project: Optional[str],
langsmith_base_url: Optional[str],
langsmith_api_key: Optional[str] = None,
langsmith_project: Optional[str] = None,
langsmith_base_url: Optional[str] = None,
) -> LangsmithCredentialsObject:
_credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
@@ -132,42 +107,19 @@ class LangsmithLogger(CustomBatchLogger):
LANGSMITH_PROJECT=_credentials_project,
)
def _prepare_log_data( # noqa: PLR0915
self, kwargs, response_obj, start_time, end_time
def _prepare_log_data(
self,
kwargs,
response_obj,
start_time,
end_time,
credentials: LangsmithCredentialsObject,
):
import json
from datetime import datetime as dt
try:
_litellm_params = kwargs.get("litellm_params", {}) or {}
metadata = _litellm_params.get("metadata", {}) or {}
new_metadata = {}
for key, value in metadata.items():
if (
isinstance(value, list)
or isinstance(value, str)
or isinstance(value, int)
or isinstance(value, float)
):
new_metadata[key] = value
elif isinstance(value, BaseModel):
new_metadata[key] = value.model_dump_json()
elif isinstance(value, dict):
for k, v in value.items():
if isinstance(v, dt):
value[k] = v.isoformat()
new_metadata[key] = value
metadata = new_metadata
kwargs["user_api_key"] = metadata.get("user_api_key", None)
kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None)
kwargs["user_api_key_team_alias"] = metadata.get(
"user_api_key_team_alias", None
)
project_name = metadata.get(
"project_name", self.default_credentials["LANGSMITH_PROJECT"]
"project_name", credentials["LANGSMITH_PROJECT"]
)
run_name = metadata.get("run_name", self.langsmith_default_run_name)
run_id = metadata.get("id", None)
@@ -175,16 +127,10 @@ class LangsmithLogger(CustomBatchLogger):
trace_id = metadata.get("trace_id", None)
session_id = metadata.get("session_id", None)
dotted_order = metadata.get("dotted_order", None)
tags = metadata.get("tags", []) or []
verbose_logger.debug(
f"Langsmith Logging - project_name: {project_name}, run_name {run_name}"
)
# filter out kwargs to not include any dicts, langsmith throws an error when trying to log kwargs
# logged_kwargs = LangsmithInputs(**kwargs)
# kwargs = logged_kwargs.model_dump()
# new_kwargs = {}
# Ensure everything in the payload is converted to str
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object", None
@@ -193,7 +139,6 @@ class LangsmithLogger(CustomBatchLogger):
if payload is None:
raise Exception("Error logging request payload. Payload=none.")
new_kwargs = payload
metadata = payload[
"metadata"
] # ensure logged metadata is json serializable
@@ -201,12 +146,12 @@ class LangsmithLogger(CustomBatchLogger):
data = {
"name": run_name,
"run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain"
"inputs": new_kwargs,
"outputs": new_kwargs["response"],
"inputs": payload,
"outputs": payload["response"],
"session_name": project_name,
"start_time": new_kwargs["startTime"],
"end_time": new_kwargs["endTime"],
"tags": tags,
"start_time": payload["startTime"],
"end_time": payload["endTime"],
"tags": payload["request_tags"],
"extra": metadata,
}
@@ -243,37 +188,6 @@ class LangsmithLogger(CustomBatchLogger):
except Exception:
raise
def _send_batch(self):
if not self.log_queue:
return
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
url = f"{langsmith_api_base}/runs/batch"
headers = {"x-api-key": langsmith_api_key}
try:
response = requests.post(
url=url,
json=self.log_queue,
headers=headers,
)
if response.status_code >= 300:
verbose_logger.error(
f"Langsmith Error: {response.status_code} - {response.text}"
)
else:
verbose_logger.debug(
f"Batch of {len(self.log_queue)} runs successfully created"
)
self.log_queue.clear()
except Exception:
verbose_logger.exception("Langsmith Layer Error - Error sending batch.")
def log_success_event(self, kwargs, response_obj, start_time, end_time):
try:
sampling_rate = (
@@ -295,8 +209,20 @@ class LangsmithLogger(CustomBatchLogger):
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
data = self._prepare_log_data(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
credentials=credentials,
)
self.log_queue.append(
LangsmithQueueObject(
data=data,
credentials=credentials,
)
)
verbose_logger.debug(
f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..."
)
@@ -323,8 +249,20 @@ class LangsmithLogger(CustomBatchLogger):
kwargs,
response_obj,
)
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
data = self._prepare_log_data(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
credentials=credentials,
)
self.log_queue.append(
LangsmithQueueObject(
data=data,
credentials=credentials,
)
)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
@@ -349,8 +287,20 @@ class LangsmithLogger(CustomBatchLogger):
return # Skip logging
verbose_logger.info("Langsmith Failure Event Logging!")
try:
data = self._prepare_log_data(kwargs, response_obj, start_time, end_time)
self.log_queue.append(data)
credentials = self._get_credentials_to_use_for_request(kwargs=kwargs)
data = self._prepare_log_data(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
credentials=credentials,
)
self.log_queue.append(
LangsmithQueueObject(
data=data,
credentials=credentials,
)
)
verbose_logger.debug(
"Langsmith logging: queue length %s, batch size %s",
len(self.log_queue),
@@ -365,31 +315,58 @@ class LangsmithLogger(CustomBatchLogger):
async def async_send_batch(self):
"""
sends runs to /batch endpoint
Handles sending batches of runs to Langsmith
Sends runs from self.log_queue
self.log_queue contains LangsmithQueueObjects
Each LangsmithQueueObject has the following:
- "credentials" - credentials to use for the request (langsmith_api_key, langsmith_project, langsmith_base_url)
- "data" - data to log on to langsmith for the request
This function
- groups the queue objects by credentials
- loops through each unique credentials and sends batches to Langsmith
This was added to support key/team based logging on langsmith
"""
if not self.log_queue:
return
batch_groups = self._group_batches_by_credentials()
for batch_group in batch_groups.values():
await self._log_batch_on_langsmith(
credentials=batch_group.credentials,
queue_objects=batch_group.queue_objects,
)
async def _log_batch_on_langsmith(
self,
credentials: LangsmithCredentialsObject,
queue_objects: List[LangsmithQueueObject],
):
"""
Logs a batch of runs to Langsmith
sends runs to /batch endpoint for the given credentials
Args:
credentials: LangsmithCredentialsObject
queue_objects: List[LangsmithQueueObject]
Returns: None
Raises: Does not raise an exception, will only verbose_logger.exception()
"""
if not self.log_queue:
return
langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"]
langsmith_api_base = credentials["LANGSMITH_BASE_URL"]
langsmith_api_key = credentials["LANGSMITH_API_KEY"]
url = f"{langsmith_api_base}/runs/batch"
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]
headers = {"x-api-key": langsmith_api_key}
elements_to_log = [queue_object["data"] for queue_object in queue_objects]
try:
response = await self.async_httpx_client.post(
url=url,
json={
"post": self.log_queue,
},
json={"post": elements_to_log},
headers=headers,
)
response.raise_for_status()
@@ -411,6 +388,74 @@ class LangsmithLogger(CustomBatchLogger):
f"Langsmith Layer Error - {traceback.format_exc()}"
)
def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]:
"""Groups queue objects by credentials using a proper key structure"""
log_queue_by_credentials: Dict[CredentialsKey, BatchGroup] = {}
for queue_object in self.log_queue:
credentials = queue_object["credentials"]
key = CredentialsKey(
api_key=credentials["LANGSMITH_API_KEY"],
project=credentials["LANGSMITH_PROJECT"],
base_url=credentials["LANGSMITH_BASE_URL"],
)
if key not in log_queue_by_credentials:
log_queue_by_credentials[key] = BatchGroup(
credentials=credentials, queue_objects=[]
)
log_queue_by_credentials[key].queue_objects.append(queue_object)
return log_queue_by_credentials
def _get_credentials_to_use_for_request(
self, kwargs: Dict[str, Any]
) -> LangsmithCredentialsObject:
"""
Handles key/team based logging
If standard_callback_dynamic_params are provided, use those credentials.
Otherwise, use the default credentials.
"""
standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = (
kwargs.get("standard_callback_dynamic_params", None)
)
if standard_callback_dynamic_params is not None:
credentials = self.get_credentials_from_env(
langsmith_api_key=standard_callback_dynamic_params.get(
"langsmith_api_key", None
),
langsmith_project=standard_callback_dynamic_params.get(
"langsmith_project", None
),
langsmith_base_url=standard_callback_dynamic_params.get(
"langsmith_base_url", None
),
)
else:
credentials = self.default_credentials
return credentials
def _send_batch(self):
"""Calls async_send_batch in an event loop"""
if not self.log_queue:
return
try:
# Try to get the existing event loop
loop = asyncio.get_event_loop()
if loop.is_running():
# If we're already in an event loop, create a task
asyncio.create_task(self.async_send_batch())
else:
# If no event loop is running, run the coroutine directly
loop.run_until_complete(self.async_send_batch())
except RuntimeError:
# If we can't get an event loop, create a new one
asyncio.run(self.async_send_batch())
def get_run_by_id(self, run_id):
langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"]


@@ -6,5 +6,7 @@ model_list:
api_base: https://exampleopenaiendpoint-production.up.railway.app/
litellm_settings:
callbacks: ["gcs_bucket"]


@@ -0,0 +1,61 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, Dict, List, NamedTuple, Optional, TypedDict
from pydantic import BaseModel
class LangsmithInputs(BaseModel):
model: Optional[str] = None
messages: Optional[List[Any]] = None
stream: Optional[bool] = None
call_type: Optional[str] = None
litellm_call_id: Optional[str] = None
completion_start_time: Optional[datetime] = None
temperature: Optional[float] = None
max_tokens: Optional[int] = None
custom_llm_provider: Optional[str] = None
input: Optional[List[Any]] = None
log_event_type: Optional[str] = None
original_response: Optional[Any] = None
response_cost: Optional[float] = None
# LiteLLM Virtual Key specific fields
user_api_key: Optional[str] = None
user_api_key_user_id: Optional[str] = None
user_api_key_team_alias: Optional[str] = None
class LangsmithCredentialsObject(TypedDict):
LANGSMITH_API_KEY: str
LANGSMITH_PROJECT: str
LANGSMITH_BASE_URL: str
class LangsmithQueueObject(TypedDict):
"""
Langsmith Queue Object - this is what gets stored in the internal system queue before flushing to Langsmith
We need to store:
- data[Dict] - data that should get logged on langsmith
- credentials[LangsmithCredentialsObject] - credentials to use for logging to langsmith
"""
data: Dict
credentials: LangsmithCredentialsObject
class CredentialsKey(NamedTuple):
"""Immutable key for grouping credentials"""
api_key: str
project: str
base_url: str
@dataclass
class BatchGroup:
"""Groups credentials with their associated queue objects"""
credentials: LangsmithCredentialsObject
queue_objects: List[LangsmithQueueObject]
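To make the relationship between these types concrete, here is a small illustrative sketch (sample values are made up) showing how a queued run and its credentials are combined, and how `CredentialsKey` serves as the hashable grouping key used by `_group_batches_by_credentials` in the logger above.

```python
# Illustrative sketch only; all values are placeholders.
from litellm.types.integrations.langsmith import (
    BatchGroup,
    CredentialsKey,
    LangsmithCredentialsObject,
    LangsmithQueueObject,
)

creds: LangsmithCredentialsObject = {
    "LANGSMITH_API_KEY": "ls-example-key",
    "LANGSMITH_PROJECT": "example-project",
    "LANGSMITH_BASE_URL": "https://api.smith.langchain.com",
}

queue_object: LangsmithQueueObject = {
    "data": {"name": "LLMRun", "run_type": "llm"},  # payload sent to Langsmith
    "credentials": creds,                           # key/team specific credentials
}

# CredentialsKey is a NamedTuple, hence hashable, so it can be used as a dict
# key when grouping queued runs per credential set before flushing.
key = CredentialsKey(
    api_key=creds["LANGSMITH_API_KEY"],
    project=creds["LANGSMITH_PROJECT"],
    base_url=creds["LANGSMITH_BASE_URL"],
)
batch = BatchGroup(credentials=creds, queue_objects=[queue_object])
```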


@@ -1595,3 +1595,8 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
# GCS dynamic params
gcs_bucket_name: Optional[str]
gcs_path_service_account: Optional[str]
# Langsmith dynamic params
langsmith_api_key: Optional[str]
langsmith_project: Optional[str]
langsmith_base_url: Optional[str]
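These new dynamic params are what the Langsmith logger reads in `_get_credentials_to_use_for_request`. As a rough sketch of SDK-side usage (mirroring the unit test added later in this commit, with placeholder credentials), they can be passed directly on a request:

```python
# Sketch based on the unit test in this commit; credentials are placeholders
# and mock_response avoids a real LLM call.
import asyncio

import litellm
from litellm.integrations.langsmith import LangsmithLogger


async def main():
    litellm.callbacks = [LangsmithLogger()]

    await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Test message"}],
        mock_response="This is a mock response",
        # per-request Langsmith credentials (StandardCallbackDynamicParams)
        langsmith_api_key="fake_key_project2",
        langsmith_project="fake_project2",
    )


asyncio.run(main())
```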


@@ -22,61 +22,6 @@ litellm.set_verbose = True
import time
@pytest.mark.asyncio
async def test_langsmith_queue_logging():
try:
# Initialize LangsmithLogger
test_langsmith_logger = LangsmithLogger()
litellm.callbacks = [test_langsmith_logger]
test_langsmith_logger.batch_size = 6
litellm.set_verbose = True
# Make multiple calls to ensure we don't hit the batch size
for _ in range(5):
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Test message"}],
max_tokens=10,
temperature=0.2,
mock_response="This is a mock response",
)
await asyncio.sleep(3)
# Check that logs are in the queue
assert len(test_langsmith_logger.log_queue) == 5
# Now make calls to exceed the batch size
for _ in range(3):
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Test message"}],
max_tokens=10,
temperature=0.2,
mock_response="This is a mock response",
)
# Wait a short time for any asynchronous operations to complete
await asyncio.sleep(1)
print(
"Length of langsmith log queue: {}".format(
len(test_langsmith_logger.log_queue)
)
)
# Check that the queue was flushed after exceeding batch size
assert len(test_langsmith_logger.log_queue) < 5
# Clean up
for cb in litellm.callbacks:
if isinstance(cb, LangsmithLogger):
await cb.async_httpx_client.client.aclose()
except Exception as e:
pytest.fail(f"Error occurred: {e}")
# test_langsmith_logging()


@@ -0,0 +1,394 @@
import io
import os
import sys
sys.path.insert(0, os.path.abspath("../.."))
import asyncio
import gzip
import json
import logging
import time
from unittest.mock import AsyncMock, patch, MagicMock
import pytest
from datetime import datetime, timezone
from litellm.integrations.langsmith import (
LangsmithLogger,
LangsmithQueueObject,
CredentialsKey,
BatchGroup,
)
import litellm
# Test get_credentials_from_env
@pytest.mark.asyncio
async def test_get_credentials_from_env():
# Test with direct parameters
logger = LangsmithLogger(
langsmith_api_key="test-key",
langsmith_project="test-project",
langsmith_base_url="http://test-url",
)
credentials = logger.get_credentials_from_env(
langsmith_api_key="custom-key",
langsmith_project="custom-project",
langsmith_base_url="http://custom-url",
)
assert credentials["LANGSMITH_API_KEY"] == "custom-key"
assert credentials["LANGSMITH_PROJECT"] == "custom-project"
assert credentials["LANGSMITH_BASE_URL"] == "http://custom-url"
# assert that the default api base is used if not provided
credentials = logger.get_credentials_from_env()
assert credentials["LANGSMITH_BASE_URL"] == "https://api.smith.langchain.com"
@pytest.mark.asyncio
async def test_group_batches_by_credentials():
logger = LangsmithLogger(langsmith_api_key="test-key")
# Create test queue objects
queue_obj1 = LangsmithQueueObject(
data={"test": "data1"},
credentials={
"LANGSMITH_API_KEY": "key1",
"LANGSMITH_PROJECT": "proj1",
"LANGSMITH_BASE_URL": "url1",
},
)
queue_obj2 = LangsmithQueueObject(
data={"test": "data2"},
credentials={
"LANGSMITH_API_KEY": "key1",
"LANGSMITH_PROJECT": "proj1",
"LANGSMITH_BASE_URL": "url1",
},
)
logger.log_queue = [queue_obj1, queue_obj2]
grouped = logger._group_batches_by_credentials()
# Check grouping
assert len(grouped) == 1 # Should have one group since credentials are same
key = list(grouped.keys())[0]
assert isinstance(key, CredentialsKey)
assert len(grouped[key].queue_objects) == 2
@pytest.mark.asyncio
async def test_group_batches_by_credentials_multiple_credentials():
# Test with multiple different credentials
logger = LangsmithLogger(langsmith_api_key="test-key")
queue_obj1 = LangsmithQueueObject(
data={"test": "data1"},
credentials={
"LANGSMITH_API_KEY": "key1",
"LANGSMITH_PROJECT": "proj1",
"LANGSMITH_BASE_URL": "url1",
},
)
queue_obj2 = LangsmithQueueObject(
data={"test": "data2"},
credentials={
"LANGSMITH_API_KEY": "key2", # Different API key
"LANGSMITH_PROJECT": "proj1",
"LANGSMITH_BASE_URL": "url1",
},
)
queue_obj3 = LangsmithQueueObject(
data={"test": "data3"},
credentials={
"LANGSMITH_API_KEY": "key1",
"LANGSMITH_PROJECT": "proj2", # Different project
"LANGSMITH_BASE_URL": "url1",
},
)
logger.log_queue = [queue_obj1, queue_obj2, queue_obj3]
grouped = logger._group_batches_by_credentials()
# Check grouping
assert len(grouped) == 3 # Should have three groups since credentials differ
for key, batch_group in grouped.items():
assert isinstance(key, CredentialsKey)
assert len(batch_group.queue_objects) == 1 # Each group should have one object
# Test make_dot_order
@pytest.mark.asyncio
async def test_make_dot_order():
logger = LangsmithLogger(langsmith_api_key="test-key")
run_id = "729cff0e-f30c-4336-8b79-45d6b61c64b4"
dot_order = logger.make_dot_order(run_id)
print("dot_order=", dot_order)
# Check format: YYYYMMDDTHHMMSSfffZ + run_id
# Check the timestamp portion (first 23 characters)
timestamp_part = dot_order[:-36] # 36 is length of run_id
assert len(timestamp_part) == 22
assert timestamp_part[8] == "T" # Check T separator
assert timestamp_part[-1] == "Z" # Check Z suffix
# Verify timestamp format
try:
# Parse the timestamp portion (removing the Z)
datetime.strptime(timestamp_part[:-1], "%Y%m%dT%H%M%S%f")
except ValueError:
pytest.fail("Timestamp portion is not in correct format")
# Verify run_id portion
assert dot_order[-36:] == run_id
# Test is_serializable
@pytest.mark.asyncio
async def test_is_serializable():
from litellm.integrations.langsmith import is_serializable
from pydantic import BaseModel
# Test basic types
assert is_serializable("string") is True
assert is_serializable(123) is True
assert is_serializable({"key": "value"}) is True
# Test non-serializable types
async def async_func():
pass
assert is_serializable(async_func) is False
class TestModel(BaseModel):
field: str
assert is_serializable(TestModel(field="test")) is False
@pytest.mark.asyncio
async def test_async_send_batch():
logger = LangsmithLogger(langsmith_api_key="test-key")
# Mock the httpx client
mock_response = AsyncMock()
mock_response.status_code = 200
logger.async_httpx_client = AsyncMock()
logger.async_httpx_client.post.return_value = mock_response
# Add test data to queue
logger.log_queue = [
LangsmithQueueObject(
data={"test": "data"}, credentials=logger.default_credentials
)
]
await logger.async_send_batch()
# Verify the API call
logger.async_httpx_client.post.assert_called_once()
call_args = logger.async_httpx_client.post.call_args
assert "runs/batch" in call_args[1]["url"]
assert "x-api-key" in call_args[1]["headers"]
@pytest.mark.asyncio
async def test_langsmith_key_based_logging(mocker):
"""
In key based logging langsmith_api_key and langsmith_project are passed directly to litellm.acompletion
"""
try:
# Mock the httpx post request
mock_post = mocker.patch(
"litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post"
)
mock_post.return_value.status_code = 200
mock_post.return_value.raise_for_status = lambda: None
litellm.set_verbose = True
litellm.callbacks = [LangsmithLogger()]
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Test message"}],
max_tokens=10,
temperature=0.2,
mock_response="This is a mock response",
langsmith_api_key="fake_key_project2",
langsmith_project="fake_project2",
)
print("Waiting for logs to be flushed to Langsmith.....")
await asyncio.sleep(15)
print("done sleeping 15 seconds...")
# Verify the post request was made with correct parameters
mock_post.assert_called_once()
call_args = mock_post.call_args
print("call_args", call_args)
# Check URL contains /runs/batch
assert "/runs/batch" in call_args[1]["url"]
# Check headers contain the correct API key
assert call_args[1]["headers"]["x-api-key"] == "fake_key_project2"
# Verify the request body contains the expected data
request_body = call_args[1]["json"]
assert "post" in request_body
assert len(request_body["post"]) == 1 # Should contain one run
# EXPECTED BODY
expected_body = {
"post": [
{
"name": "LLMRun",
"run_type": "llm",
"inputs": {
"id": "chatcmpl-82699ee4-7932-4fc0-9585-76abc8caeafa",
"call_type": "acompletion",
"model": "gpt-3.5-turbo",
"messages": [{"role": "user", "content": "Test message"}],
"model_parameters": {
"temperature": 0.2,
"max_tokens": 10,
"extra_body": {},
},
},
"outputs": {
"id": "chatcmpl-82699ee4-7932-4fc0-9585-76abc8caeafa",
"model": "gpt-3.5-turbo",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"message": {
"content": "This is a mock response",
"role": "assistant",
"tool_calls": None,
"function_call": None,
},
}
],
"usage": {
"completion_tokens": 20,
"prompt_tokens": 10,
"total_tokens": 30,
},
},
"session_name": "fake_project2",
}
]
}
# Print both bodies for debugging
actual_body = call_args[1]["json"]
print("\nExpected body:")
print(json.dumps(expected_body, indent=2))
print("\nActual body:")
print(json.dumps(actual_body, indent=2))
assert len(actual_body["post"]) == 1
# Assert only the critical parts we care about
assert actual_body["post"][0]["name"] == expected_body["post"][0]["name"]
assert (
actual_body["post"][0]["run_type"] == expected_body["post"][0]["run_type"]
)
assert (
actual_body["post"][0]["inputs"]["messages"]
== expected_body["post"][0]["inputs"]["messages"]
)
assert (
actual_body["post"][0]["inputs"]["model_parameters"]
== expected_body["post"][0]["inputs"]["model_parameters"]
)
assert (
actual_body["post"][0]["outputs"]["choices"]
== expected_body["post"][0]["outputs"]["choices"]
)
assert (
actual_body["post"][0]["outputs"]["usage"]["completion_tokens"]
== expected_body["post"][0]["outputs"]["usage"]["completion_tokens"]
)
assert (
actual_body["post"][0]["outputs"]["usage"]["prompt_tokens"]
== expected_body["post"][0]["outputs"]["usage"]["prompt_tokens"]
)
assert (
actual_body["post"][0]["outputs"]["usage"]["total_tokens"]
== expected_body["post"][0]["outputs"]["usage"]["total_tokens"]
)
assert (
actual_body["post"][0]["session_name"]
== expected_body["post"][0]["session_name"]
)
except Exception as e:
pytest.fail(f"Error occurred: {e}")
@pytest.mark.asyncio
async def test_langsmith_queue_logging():
try:
# Initialize LangsmithLogger
test_langsmith_logger = LangsmithLogger()
litellm.callbacks = [test_langsmith_logger]
test_langsmith_logger.batch_size = 6
litellm.set_verbose = True
# Make multiple calls to ensure we don't hit the batch size
for _ in range(5):
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Test message"}],
max_tokens=10,
temperature=0.2,
mock_response="This is a mock response",
)
await asyncio.sleep(3)
# Check that logs are in the queue
assert len(test_langsmith_logger.log_queue) == 5
# Now make calls to exceed the batch size
for _ in range(3):
response = await litellm.acompletion(
model="gpt-3.5-turbo",
messages=[{"role": "user", "content": "Test message"}],
max_tokens=10,
temperature=0.2,
mock_response="This is a mock response",
)
# Wait a short time for any asynchronous operations to complete
await asyncio.sleep(1)
print(
"Length of langsmith log queue: {}".format(
len(test_langsmith_logger.log_queue)
)
)
# Check that the queue was flushed after exceeding batch size
assert len(test_langsmith_logger.log_queue) < 5
# Clean up
for cb in litellm.callbacks:
if isinstance(cb, LangsmithLogger):
await cb.async_httpx_client.client.aclose()
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@@ -1632,6 +1632,139 @@ async def test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket(
assert new_data["failure_callback"] == expected_failure_callbacks
@pytest.mark.asyncio
@pytest.mark.parametrize(
"callback_type, expected_success_callbacks, expected_failure_callbacks",
[
("success", ["langsmith"], []),
("failure", [], ["langsmith"]),
("success_and_failure", ["langsmith"], ["langsmith"]),
],
)
async def test_add_callback_via_key_litellm_pre_call_utils_langsmith(
prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks
):
import json
from fastapi import HTTPException, Request, Response
from starlette.datastructures import URL
from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request
setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client)
setattr(litellm.proxy.proxy_server, "master_key", "sk-1234")
await litellm.proxy.proxy_server.prisma_client.connect()
proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config")
request = Request(scope={"type": "http", "method": "POST", "headers": {}})
request._url = URL(url="/chat/completions")
test_data = {
"model": "azure/chatgpt-v-2",
"messages": [
{"role": "user", "content": "write 1 sentence poem"},
],
"max_tokens": 10,
"mock_response": "Hello world",
"api_key": "my-fake-key",
}
json_bytes = json.dumps(test_data).encode("utf-8")
request._body = json_bytes
data = {
"data": {
"model": "azure/chatgpt-v-2",
"messages": [{"role": "user", "content": "write 1 sentence poem"}],
"max_tokens": 10,
"mock_response": "Hello world",
"api_key": "my-fake-key",
},
"request": request,
"user_api_key_dict": UserAPIKeyAuth(
token=None,
key_name=None,
key_alias=None,
spend=0.0,
max_budget=None,
expires=None,
models=[],
aliases={},
config={},
user_id=None,
team_id=None,
max_parallel_requests=None,
metadata={
"logging": [
{
"callback_name": "langsmith",
"callback_type": callback_type,
"callback_vars": {
"langsmith_api_key": "ls-1234",
"langsmith_project": "pr-brief-resemblance-72",
"langsmith_base_url": "https://api.smith.langchain.com",
},
}
]
},
tpm_limit=None,
rpm_limit=None,
budget_duration=None,
budget_reset_at=None,
allowed_cache_controls=[],
permissions={},
model_spend={},
model_max_budget={},
soft_budget_cooldown=False,
litellm_budget_table=None,
org_id=None,
team_spend=None,
team_alias=None,
team_tpm_limit=None,
team_rpm_limit=None,
team_max_budget=None,
team_models=[],
team_blocked=False,
soft_budget=None,
team_model_aliases=None,
team_member_spend=None,
team_metadata=None,
end_user_id=None,
end_user_tpm_limit=None,
end_user_rpm_limit=None,
end_user_max_budget=None,
last_refreshed_at=None,
api_key=None,
user_role=None,
allowed_model_region=None,
parent_otel_span=None,
),
"proxy_config": proxy_config,
"general_settings": {},
"version": "0.0.0",
}
new_data = await add_litellm_data_to_request(**data)
print("NEW DATA: {}".format(new_data))
assert "langsmith_api_key" in new_data
assert new_data["langsmith_api_key"] == "ls-1234"
assert "langsmith_project" in new_data
assert new_data["langsmith_project"] == "pr-brief-resemblance-72"
assert "langsmith_base_url" in new_data
assert new_data["langsmith_base_url"] == "https://api.smith.langchain.com"
if expected_success_callbacks:
assert "success_callback" in new_data
assert new_data["success_callback"] == expected_success_callbacks
if expected_failure_callbacks:
assert "failure_callback" in new_data
assert new_data["failure_callback"] == expected_failure_callbacks
@pytest.mark.asyncio
async def test_gemini_pass_through_endpoint():
from starlette.datastructures import URL