From c3bc9e6b12b29414e0bb23e10f5e41952c7df914 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Mon, 11 Nov 2024 13:58:06 -0800 Subject: [PATCH] (Feat) Add langsmith key based logging (#6682) * add langsmith_api_key to StandardCallbackDynamicParams * create a file for langsmith types * langsmith add key / team based logging * add key based logging for langsmith * fix langsmith key based logging * fix linting langsmith * remove NOQA violation * add unit test coverage for all helpers in test langsmith * test_langsmith_key_based_logging * docs langsmith key based logging * run langsmith tests in logging callback tests * fix logging testing * test_langsmith_key_based_logging * test_add_callback_via_key_litellm_pre_call_utils_langsmith * add debug statement langsmith key based logging * test_langsmith_key_based_logging --- .circleci/config.yml | 1 + docs/my-website/docs/proxy/team_logging.md | 45 ++ litellm/integrations/langsmith.py | 293 +++++++------ litellm/proxy/proxy_config.yaml | 2 + litellm/types/integrations/langsmith.py | 61 +++ litellm/types/utils.py | 5 + tests/local_testing/test_langsmith.py | 55 --- .../test_langsmith_unit_test.py | 394 ++++++++++++++++++ tests/proxy_unit_tests/test_proxy_server.py | 133 ++++++ 9 files changed, 810 insertions(+), 179 deletions(-) create mode 100644 litellm/types/integrations/langsmith.py create mode 100644 tests/logging_callback_tests/test_langsmith_unit_test.py diff --git a/.circleci/config.yml b/.circleci/config.yml index 88e83fa7f..7961cfddb 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -686,6 +686,7 @@ jobs: pip install "pytest-retry==1.6.3" pip install "pytest-cov==5.0.0" pip install "pytest-asyncio==0.21.1" + pip install pytest-mock pip install "respx==0.21.1" pip install "google-generativeai==0.3.2" pip install "google-cloud-aiplatform==1.43.0" diff --git a/docs/my-website/docs/proxy/team_logging.md b/docs/my-website/docs/proxy/team_logging.md index e2fcfa4b5..8286ac449 100644 --- a/docs/my-website/docs/proxy/team_logging.md +++ b/docs/my-website/docs/proxy/team_logging.md @@ -281,6 +281,51 @@ curl -X POST 'http://0.0.0.0:4000/key/generate' \ }' ``` + + + + +1. Create Virtual Key to log to a specific Langsmith Project + + ```bash + curl -X POST 'http://0.0.0.0:4000/key/generate' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{ + "metadata": { + "logging": [{ + "callback_name": "langsmith", # "otel", "gcs_bucket" + "callback_type": "success", # "success", "failure", "success_and_failure" + "callback_vars": { + "langsmith_api_key": "os.environ/LANGSMITH_API_KEY", # API Key for Langsmith logging + "langsmith_project": "pr-brief-resemblance-72", # project name on langsmith + "langsmith_base_url": "https://api.smith.langchain.com" + } + }] + } + }' + + ``` + +2. 
Test it - `/chat/completions` request
+
+   Use the virtual key from step 1 to make a `/chat/completions` request.
+
+   You should see your logs in your Langsmith project on a successful request.
+
+   ```shell
+   curl -i http://localhost:4000/v1/chat/completions \
+     -H "Content-Type: application/json" \
+     -H "Authorization: Bearer sk-Fxq5XSyWKeXDKfPdqXZhPg" \
+     -d '{
+       "model": "fake-openai-endpoint",
+       "messages": [
+         {"role": "user", "content": "Hello, Claude"}
+       ],
+       "user": "hello"
+     }'
+   ```
+
diff --git a/litellm/integrations/langsmith.py b/litellm/integrations/langsmith.py
index 951393445..4abd2a2c3 100644
--- a/litellm/integrations/langsmith.py
+++ b/litellm/integrations/langsmith.py
@@ -23,34 +23,8 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
     httpxSpecialProvider,
 )
-from litellm.types.utils import StandardLoggingPayload
-
-
-class LangsmithInputs(BaseModel):
-    model: Optional[str] = None
-    messages: Optional[List[Any]] = None
-    stream: Optional[bool] = None
-    call_type: Optional[str] = None
-    litellm_call_id: Optional[str] = None
-    completion_start_time: Optional[datetime] = None
-    temperature: Optional[float] = None
-    max_tokens: Optional[int] = None
-    custom_llm_provider: Optional[str] = None
-    input: Optional[List[Any]] = None
-    log_event_type: Optional[str] = None
-    original_response: Optional[Any] = None
-    response_cost: Optional[float] = None
-
-    # LiteLLM Virtual Key specific fields
-    user_api_key: Optional[str] = None
-    user_api_key_user_id: Optional[str] = None
-    user_api_key_team_alias: Optional[str] = None
-
-
-class LangsmithCredentialsObject(TypedDict):
-    LANGSMITH_API_KEY: str
-    LANGSMITH_PROJECT: str
-    LANGSMITH_BASE_URL: str
+from litellm.types.integrations.langsmith import *
+from litellm.types.utils import StandardCallbackDynamicParams, StandardLoggingPayload


 def is_serializable(value):
@@ -93,15 +67,16 @@ class LangsmithLogger(CustomBatchLogger):
         )
         if _batch_size:
             self.batch_size = int(_batch_size)
+        self.log_queue: List[LangsmithQueueObject] = []
         asyncio.create_task(self.periodic_flush())
         self.flush_lock = asyncio.Lock()
         super().__init__(**kwargs, flush_lock=self.flush_lock)

     def get_credentials_from_env(
         self,
-        langsmith_api_key: Optional[str],
-        langsmith_project: Optional[str],
-        langsmith_base_url: Optional[str],
+        langsmith_api_key: Optional[str] = None,
+        langsmith_project: Optional[str] = None,
+        langsmith_base_url: Optional[str] = None,
     ) -> LangsmithCredentialsObject:

         _credentials_api_key = langsmith_api_key or os.getenv("LANGSMITH_API_KEY")
@@ -132,42 +107,19 @@ class LangsmithLogger(CustomBatchLogger):
             LANGSMITH_PROJECT=_credentials_project,
         )

-    def _prepare_log_data(  # noqa: PLR0915
-        self, kwargs, response_obj, start_time, end_time
+    def _prepare_log_data(
+        self,
+        kwargs,
+        response_obj,
+        start_time,
+        end_time,
+        credentials: LangsmithCredentialsObject,
     ):
-        import json
-        from datetime import datetime as dt
-
         try:
             _litellm_params = kwargs.get("litellm_params", {}) or {}
             metadata = _litellm_params.get("metadata", {}) or {}
-            new_metadata = {}
-            for key, value in metadata.items():
-                if (
-                    isinstance(value, list)
-                    or isinstance(value, str)
-                    or isinstance(value, int)
-                    or isinstance(value, float)
-                ):
-                    new_metadata[key] = value
-                elif isinstance(value, BaseModel):
-                    new_metadata[key] = value.model_dump_json()
-                elif isinstance(value, dict):
-                    for k, v in value.items():
-                        if isinstance(v, dt):
-                            value[k] = v.isoformat()
-                    new_metadata[key] = value
-
-            metadata = new_metadata
-
-            kwargs["user_api_key"] = 
metadata.get("user_api_key", None) - kwargs["user_api_key_user_id"] = metadata.get("user_api_key_user_id", None) - kwargs["user_api_key_team_alias"] = metadata.get( - "user_api_key_team_alias", None - ) - project_name = metadata.get( - "project_name", self.default_credentials["LANGSMITH_PROJECT"] + "project_name", credentials["LANGSMITH_PROJECT"] ) run_name = metadata.get("run_name", self.langsmith_default_run_name) run_id = metadata.get("id", None) @@ -175,16 +127,10 @@ class LangsmithLogger(CustomBatchLogger): trace_id = metadata.get("trace_id", None) session_id = metadata.get("session_id", None) dotted_order = metadata.get("dotted_order", None) - tags = metadata.get("tags", []) or [] verbose_logger.debug( f"Langsmith Logging - project_name: {project_name}, run_name {run_name}" ) - # filter out kwargs to not include any dicts, langsmith throws an erros when trying to log kwargs - # logged_kwargs = LangsmithInputs(**kwargs) - # kwargs = logged_kwargs.model_dump() - - # new_kwargs = {} # Ensure everything in the payload is converted to str payload: Optional[StandardLoggingPayload] = kwargs.get( "standard_logging_object", None @@ -193,7 +139,6 @@ class LangsmithLogger(CustomBatchLogger): if payload is None: raise Exception("Error logging request payload. Payload=none.") - new_kwargs = payload metadata = payload[ "metadata" ] # ensure logged metadata is json serializable @@ -201,12 +146,12 @@ class LangsmithLogger(CustomBatchLogger): data = { "name": run_name, "run_type": "llm", # this should always be llm, since litellm always logs llm calls. Langsmith allow us to log "chain" - "inputs": new_kwargs, - "outputs": new_kwargs["response"], + "inputs": payload, + "outputs": payload["response"], "session_name": project_name, - "start_time": new_kwargs["startTime"], - "end_time": new_kwargs["endTime"], - "tags": tags, + "start_time": payload["startTime"], + "end_time": payload["endTime"], + "tags": payload["request_tags"], "extra": metadata, } @@ -243,37 +188,6 @@ class LangsmithLogger(CustomBatchLogger): except Exception: raise - def _send_batch(self): - if not self.log_queue: - return - - langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] - langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] - - url = f"{langsmith_api_base}/runs/batch" - - headers = {"x-api-key": langsmith_api_key} - - try: - response = requests.post( - url=url, - json=self.log_queue, - headers=headers, - ) - - if response.status_code >= 300: - verbose_logger.error( - f"Langsmith Error: {response.status_code} - {response.text}" - ) - else: - verbose_logger.debug( - f"Batch of {len(self.log_queue)} runs successfully created" - ) - - self.log_queue.clear() - except Exception: - verbose_logger.exception("Langsmith Layer Error - Error sending batch.") - def log_success_event(self, kwargs, response_obj, start_time, end_time): try: sampling_rate = ( @@ -295,8 +209,20 @@ class LangsmithLogger(CustomBatchLogger): kwargs, response_obj, ) - data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) - self.log_queue.append(data) + credentials = self._get_credentials_to_use_for_request(kwargs=kwargs) + data = self._prepare_log_data( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + credentials=credentials, + ) + self.log_queue.append( + LangsmithQueueObject( + data=data, + credentials=credentials, + ) + ) verbose_logger.debug( f"Langsmith, event added to queue. Will flush in {self.flush_interval} seconds..." 
) @@ -323,8 +249,20 @@ class LangsmithLogger(CustomBatchLogger): kwargs, response_obj, ) - data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) - self.log_queue.append(data) + credentials = self._get_credentials_to_use_for_request(kwargs=kwargs) + data = self._prepare_log_data( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + credentials=credentials, + ) + self.log_queue.append( + LangsmithQueueObject( + data=data, + credentials=credentials, + ) + ) verbose_logger.debug( "Langsmith logging: queue length %s, batch size %s", len(self.log_queue), @@ -349,8 +287,20 @@ class LangsmithLogger(CustomBatchLogger): return # Skip logging verbose_logger.info("Langsmith Failure Event Logging!") try: - data = self._prepare_log_data(kwargs, response_obj, start_time, end_time) - self.log_queue.append(data) + credentials = self._get_credentials_to_use_for_request(kwargs=kwargs) + data = self._prepare_log_data( + kwargs=kwargs, + response_obj=response_obj, + start_time=start_time, + end_time=end_time, + credentials=credentials, + ) + self.log_queue.append( + LangsmithQueueObject( + data=data, + credentials=credentials, + ) + ) verbose_logger.debug( "Langsmith logging: queue length %s, batch size %s", len(self.log_queue), @@ -365,31 +315,58 @@ class LangsmithLogger(CustomBatchLogger): async def async_send_batch(self): """ - sends runs to /batch endpoint + Handles sending batches of runs to Langsmith - Sends runs from self.log_queue + self.log_queue contains LangsmithQueueObjects + Each LangsmithQueueObject has the following: + - "credentials" - credentials to use for the request (langsmith_api_key, langsmith_project, langsmith_base_url) + - "data" - data to log on to langsmith for the request + + + This function + - groups the queue objects by credentials + - loops through each unique credentials and sends batches to Langsmith + + + This was added to support key/team based logging on langsmith + """ + if not self.log_queue: + return + + batch_groups = self._group_batches_by_credentials() + for batch_group in batch_groups.values(): + await self._log_batch_on_langsmith( + credentials=batch_group.credentials, + queue_objects=batch_group.queue_objects, + ) + + async def _log_batch_on_langsmith( + self, + credentials: LangsmithCredentialsObject, + queue_objects: List[LangsmithQueueObject], + ): + """ + Logs a batch of runs to Langsmith + sends runs to /batch endpoint for the given credentials + + Args: + credentials: LangsmithCredentialsObject + queue_objects: List[LangsmithQueueObject] Returns: None Raises: Does not raise an exception, will only verbose_logger.exception() """ - if not self.log_queue: - return - - langsmith_api_base = self.default_credentials["LANGSMITH_BASE_URL"] - + langsmith_api_base = credentials["LANGSMITH_BASE_URL"] + langsmith_api_key = credentials["LANGSMITH_API_KEY"] url = f"{langsmith_api_base}/runs/batch" - - langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] - headers = {"x-api-key": langsmith_api_key} + elements_to_log = [queue_object["data"] for queue_object in queue_objects] try: response = await self.async_httpx_client.post( url=url, - json={ - "post": self.log_queue, - }, + json={"post": elements_to_log}, headers=headers, ) response.raise_for_status() @@ -411,6 +388,74 @@ class LangsmithLogger(CustomBatchLogger): f"Langsmith Layer Error - {traceback.format_exc()}" ) + def _group_batches_by_credentials(self) -> Dict[CredentialsKey, BatchGroup]: + """Groups queue objects by credentials using a proper 
key structure""" + log_queue_by_credentials: Dict[CredentialsKey, BatchGroup] = {} + + for queue_object in self.log_queue: + credentials = queue_object["credentials"] + key = CredentialsKey( + api_key=credentials["LANGSMITH_API_KEY"], + project=credentials["LANGSMITH_PROJECT"], + base_url=credentials["LANGSMITH_BASE_URL"], + ) + + if key not in log_queue_by_credentials: + log_queue_by_credentials[key] = BatchGroup( + credentials=credentials, queue_objects=[] + ) + + log_queue_by_credentials[key].queue_objects.append(queue_object) + + return log_queue_by_credentials + + def _get_credentials_to_use_for_request( + self, kwargs: Dict[str, Any] + ) -> LangsmithCredentialsObject: + """ + Handles key/team based logging + + If standard_callback_dynamic_params are provided, use those credentials. + + Otherwise, use the default credentials. + """ + standard_callback_dynamic_params: Optional[StandardCallbackDynamicParams] = ( + kwargs.get("standard_callback_dynamic_params", None) + ) + if standard_callback_dynamic_params is not None: + credentials = self.get_credentials_from_env( + langsmith_api_key=standard_callback_dynamic_params.get( + "langsmith_api_key", None + ), + langsmith_project=standard_callback_dynamic_params.get( + "langsmith_project", None + ), + langsmith_base_url=standard_callback_dynamic_params.get( + "langsmith_base_url", None + ), + ) + else: + credentials = self.default_credentials + return credentials + + def _send_batch(self): + """Calls async_send_batch in an event loop""" + if not self.log_queue: + return + + try: + # Try to get the existing event loop + loop = asyncio.get_event_loop() + if loop.is_running(): + # If we're already in an event loop, create a task + asyncio.create_task(self.async_send_batch()) + else: + # If no event loop is running, run the coroutine directly + loop.run_until_complete(self.async_send_batch()) + except RuntimeError: + # If we can't get an event loop, create a new one + asyncio.run(self.async_send_batch()) + def get_run_by_id(self, run_id): langsmith_api_key = self.default_credentials["LANGSMITH_API_KEY"] diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index b4a18baa4..29d14c910 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -6,5 +6,7 @@ model_list: api_base: https://exampleopenaiendpoint-production.up.railway.app/ + litellm_settings: callbacks: ["gcs_bucket"] + diff --git a/litellm/types/integrations/langsmith.py b/litellm/types/integrations/langsmith.py new file mode 100644 index 000000000..48c8e2e0a --- /dev/null +++ b/litellm/types/integrations/langsmith.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass +from datetime import datetime +from typing import Any, Dict, List, NamedTuple, Optional, TypedDict + +from pydantic import BaseModel + + +class LangsmithInputs(BaseModel): + model: Optional[str] = None + messages: Optional[List[Any]] = None + stream: Optional[bool] = None + call_type: Optional[str] = None + litellm_call_id: Optional[str] = None + completion_start_time: Optional[datetime] = None + temperature: Optional[float] = None + max_tokens: Optional[int] = None + custom_llm_provider: Optional[str] = None + input: Optional[List[Any]] = None + log_event_type: Optional[str] = None + original_response: Optional[Any] = None + response_cost: Optional[float] = None + + # LiteLLM Virtual Key specific fields + user_api_key: Optional[str] = None + user_api_key_user_id: Optional[str] = None + user_api_key_team_alias: Optional[str] = None + + +class 
LangsmithCredentialsObject(TypedDict): + LANGSMITH_API_KEY: str + LANGSMITH_PROJECT: str + LANGSMITH_BASE_URL: str + + +class LangsmithQueueObject(TypedDict): + """ + Langsmith Queue Object - this is what gets stored in the internal system queue before flushing to Langsmith + + We need to store: + - data[Dict] - data that should get logged on langsmith + - credentials[LangsmithCredentialsObject] - credentials to use for logging to langsmith + """ + + data: Dict + credentials: LangsmithCredentialsObject + + +class CredentialsKey(NamedTuple): + """Immutable key for grouping credentials""" + + api_key: str + project: str + base_url: str + + +@dataclass +class BatchGroup: + """Groups credentials with their associated queue objects""" + + credentials: LangsmithCredentialsObject + queue_objects: List[LangsmithQueueObject] diff --git a/litellm/types/utils.py b/litellm/types/utils.py index a2b62f9cc..e3df357be 100644 --- a/litellm/types/utils.py +++ b/litellm/types/utils.py @@ -1595,3 +1595,8 @@ class StandardCallbackDynamicParams(TypedDict, total=False): # GCS dynamic params gcs_bucket_name: Optional[str] gcs_path_service_account: Optional[str] + + # Langsmith dynamic params + langsmith_api_key: Optional[str] + langsmith_project: Optional[str] + langsmith_base_url: Optional[str] diff --git a/tests/local_testing/test_langsmith.py b/tests/local_testing/test_langsmith.py index 6a98f244d..ab387e444 100644 --- a/tests/local_testing/test_langsmith.py +++ b/tests/local_testing/test_langsmith.py @@ -22,61 +22,6 @@ litellm.set_verbose = True import time -@pytest.mark.asyncio -async def test_langsmith_queue_logging(): - try: - # Initialize LangsmithLogger - test_langsmith_logger = LangsmithLogger() - - litellm.callbacks = [test_langsmith_logger] - test_langsmith_logger.batch_size = 6 - litellm.set_verbose = True - - # Make multiple calls to ensure we don't hit the batch size - for _ in range(5): - response = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Test message"}], - max_tokens=10, - temperature=0.2, - mock_response="This is a mock response", - ) - - await asyncio.sleep(3) - - # Check that logs are in the queue - assert len(test_langsmith_logger.log_queue) == 5 - - # Now make calls to exceed the batch size - for _ in range(3): - response = await litellm.acompletion( - model="gpt-3.5-turbo", - messages=[{"role": "user", "content": "Test message"}], - max_tokens=10, - temperature=0.2, - mock_response="This is a mock response", - ) - - # Wait a short time for any asynchronous operations to complete - await asyncio.sleep(1) - - print( - "Length of langsmith log queue: {}".format( - len(test_langsmith_logger.log_queue) - ) - ) - # Check that the queue was flushed after exceeding batch size - assert len(test_langsmith_logger.log_queue) < 5 - - # Clean up - for cb in litellm.callbacks: - if isinstance(cb, LangsmithLogger): - await cb.async_httpx_client.client.aclose() - - except Exception as e: - pytest.fail(f"Error occurred: {e}") - - # test_langsmith_logging() diff --git a/tests/logging_callback_tests/test_langsmith_unit_test.py b/tests/logging_callback_tests/test_langsmith_unit_test.py new file mode 100644 index 000000000..3e106666f --- /dev/null +++ b/tests/logging_callback_tests/test_langsmith_unit_test.py @@ -0,0 +1,394 @@ +import io +import os +import sys + + +sys.path.insert(0, os.path.abspath("../..")) + +import asyncio +import gzip +import json +import logging +import time +from unittest.mock import AsyncMock, patch, MagicMock +import pytest +from 
datetime import datetime, timezone +from litellm.integrations.langsmith import ( + LangsmithLogger, + LangsmithQueueObject, + CredentialsKey, + BatchGroup, +) + +import litellm + + +# Test get_credentials_from_env +@pytest.mark.asyncio +async def test_get_credentials_from_env(): + # Test with direct parameters + logger = LangsmithLogger( + langsmith_api_key="test-key", + langsmith_project="test-project", + langsmith_base_url="http://test-url", + ) + + credentials = logger.get_credentials_from_env( + langsmith_api_key="custom-key", + langsmith_project="custom-project", + langsmith_base_url="http://custom-url", + ) + + assert credentials["LANGSMITH_API_KEY"] == "custom-key" + assert credentials["LANGSMITH_PROJECT"] == "custom-project" + assert credentials["LANGSMITH_BASE_URL"] == "http://custom-url" + + # assert that the default api base is used if not provided + credentials = logger.get_credentials_from_env() + assert credentials["LANGSMITH_BASE_URL"] == "https://api.smith.langchain.com" + + +@pytest.mark.asyncio +async def test_group_batches_by_credentials(): + + logger = LangsmithLogger(langsmith_api_key="test-key") + + # Create test queue objects + queue_obj1 = LangsmithQueueObject( + data={"test": "data1"}, + credentials={ + "LANGSMITH_API_KEY": "key1", + "LANGSMITH_PROJECT": "proj1", + "LANGSMITH_BASE_URL": "url1", + }, + ) + + queue_obj2 = LangsmithQueueObject( + data={"test": "data2"}, + credentials={ + "LANGSMITH_API_KEY": "key1", + "LANGSMITH_PROJECT": "proj1", + "LANGSMITH_BASE_URL": "url1", + }, + ) + + logger.log_queue = [queue_obj1, queue_obj2] + + grouped = logger._group_batches_by_credentials() + + # Check grouping + assert len(grouped) == 1 # Should have one group since credentials are same + key = list(grouped.keys())[0] + assert isinstance(key, CredentialsKey) + assert len(grouped[key].queue_objects) == 2 + + +@pytest.mark.asyncio +async def test_group_batches_by_credentials_multiple_credentials(): + + # Test with multiple different credentials + logger = LangsmithLogger(langsmith_api_key="test-key") + + queue_obj1 = LangsmithQueueObject( + data={"test": "data1"}, + credentials={ + "LANGSMITH_API_KEY": "key1", + "LANGSMITH_PROJECT": "proj1", + "LANGSMITH_BASE_URL": "url1", + }, + ) + + queue_obj2 = LangsmithQueueObject( + data={"test": "data2"}, + credentials={ + "LANGSMITH_API_KEY": "key2", # Different API key + "LANGSMITH_PROJECT": "proj1", + "LANGSMITH_BASE_URL": "url1", + }, + ) + + queue_obj3 = LangsmithQueueObject( + data={"test": "data3"}, + credentials={ + "LANGSMITH_API_KEY": "key1", + "LANGSMITH_PROJECT": "proj2", # Different project + "LANGSMITH_BASE_URL": "url1", + }, + ) + + logger.log_queue = [queue_obj1, queue_obj2, queue_obj3] + + grouped = logger._group_batches_by_credentials() + + # Check grouping + assert len(grouped) == 3 # Should have three groups since credentials differ + for key, batch_group in grouped.items(): + assert isinstance(key, CredentialsKey) + assert len(batch_group.queue_objects) == 1 # Each group should have one object + + +# Test make_dot_order +@pytest.mark.asyncio +async def test_make_dot_order(): + logger = LangsmithLogger(langsmith_api_key="test-key") + run_id = "729cff0e-f30c-4336-8b79-45d6b61c64b4" + dot_order = logger.make_dot_order(run_id) + + print("dot_order=", dot_order) + + # Check format: YYYYMMDDTHHMMSSfffZ + run_id + # Check the timestamp portion (first 23 characters) + timestamp_part = dot_order[:-36] # 36 is length of run_id + assert len(timestamp_part) == 22 + assert timestamp_part[8] == "T" # Check T separator + 
assert timestamp_part[-1] == "Z" # Check Z suffix + + # Verify timestamp format + try: + # Parse the timestamp portion (removing the Z) + datetime.strptime(timestamp_part[:-1], "%Y%m%dT%H%M%S%f") + except ValueError: + pytest.fail("Timestamp portion is not in correct format") + + # Verify run_id portion + assert dot_order[-36:] == run_id + + +# Test is_serializable +@pytest.mark.asyncio +async def test_is_serializable(): + from litellm.integrations.langsmith import is_serializable + from pydantic import BaseModel + + # Test basic types + assert is_serializable("string") is True + assert is_serializable(123) is True + assert is_serializable({"key": "value"}) is True + + # Test non-serializable types + async def async_func(): + pass + + assert is_serializable(async_func) is False + + class TestModel(BaseModel): + field: str + + assert is_serializable(TestModel(field="test")) is False + + +@pytest.mark.asyncio +async def test_async_send_batch(): + logger = LangsmithLogger(langsmith_api_key="test-key") + + # Mock the httpx client + mock_response = AsyncMock() + mock_response.status_code = 200 + logger.async_httpx_client = AsyncMock() + logger.async_httpx_client.post.return_value = mock_response + + # Add test data to queue + logger.log_queue = [ + LangsmithQueueObject( + data={"test": "data"}, credentials=logger.default_credentials + ) + ] + + await logger.async_send_batch() + + # Verify the API call + logger.async_httpx_client.post.assert_called_once() + call_args = logger.async_httpx_client.post.call_args + assert "runs/batch" in call_args[1]["url"] + assert "x-api-key" in call_args[1]["headers"] + + +@pytest.mark.asyncio +async def test_langsmith_key_based_logging(mocker): + """ + In key based logging langsmith_api_key and langsmith_project are passed directly to litellm.acompletion + """ + try: + # Mock the httpx post request + mock_post = mocker.patch( + "litellm.llms.custom_httpx.http_handler.AsyncHTTPHandler.post" + ) + mock_post.return_value.status_code = 200 + mock_post.return_value.raise_for_status = lambda: None + litellm.set_verbose = True + + litellm.callbacks = [LangsmithLogger()] + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Test message"}], + max_tokens=10, + temperature=0.2, + mock_response="This is a mock response", + langsmith_api_key="fake_key_project2", + langsmith_project="fake_project2", + ) + print("Waiting for logs to be flushed to Langsmith.....") + await asyncio.sleep(15) + + print("done sleeping 15 seconds...") + + # Verify the post request was made with correct parameters + mock_post.assert_called_once() + call_args = mock_post.call_args + + print("call_args", call_args) + + # Check URL contains /runs/batch + assert "/runs/batch" in call_args[1]["url"] + + # Check headers contain the correct API key + assert call_args[1]["headers"]["x-api-key"] == "fake_key_project2" + + # Verify the request body contains the expected data + request_body = call_args[1]["json"] + assert "post" in request_body + assert len(request_body["post"]) == 1 # Should contain one run + + # EXPECTED BODY + expected_body = { + "post": [ + { + "name": "LLMRun", + "run_type": "llm", + "inputs": { + "id": "chatcmpl-82699ee4-7932-4fc0-9585-76abc8caeafa", + "call_type": "acompletion", + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Test message"}], + "model_parameters": { + "temperature": 0.2, + "max_tokens": 10, + "extra_body": {}, + }, + }, + "outputs": { + "id": "chatcmpl-82699ee4-7932-4fc0-9585-76abc8caeafa", + 
"model": "gpt-3.5-turbo", + "choices": [ + { + "finish_reason": "stop", + "index": 0, + "message": { + "content": "This is a mock response", + "role": "assistant", + "tool_calls": None, + "function_call": None, + }, + } + ], + "usage": { + "completion_tokens": 20, + "prompt_tokens": 10, + "total_tokens": 30, + }, + }, + "session_name": "fake_project2", + } + ] + } + + # Print both bodies for debugging + actual_body = call_args[1]["json"] + print("\nExpected body:") + print(json.dumps(expected_body, indent=2)) + print("\nActual body:") + print(json.dumps(actual_body, indent=2)) + + assert len(actual_body["post"]) == 1 + + # Assert only the critical parts we care about + assert actual_body["post"][0]["name"] == expected_body["post"][0]["name"] + assert ( + actual_body["post"][0]["run_type"] == expected_body["post"][0]["run_type"] + ) + assert ( + actual_body["post"][0]["inputs"]["messages"] + == expected_body["post"][0]["inputs"]["messages"] + ) + assert ( + actual_body["post"][0]["inputs"]["model_parameters"] + == expected_body["post"][0]["inputs"]["model_parameters"] + ) + assert ( + actual_body["post"][0]["outputs"]["choices"] + == expected_body["post"][0]["outputs"]["choices"] + ) + assert ( + actual_body["post"][0]["outputs"]["usage"]["completion_tokens"] + == expected_body["post"][0]["outputs"]["usage"]["completion_tokens"] + ) + assert ( + actual_body["post"][0]["outputs"]["usage"]["prompt_tokens"] + == expected_body["post"][0]["outputs"]["usage"]["prompt_tokens"] + ) + assert ( + actual_body["post"][0]["outputs"]["usage"]["total_tokens"] + == expected_body["post"][0]["outputs"]["usage"]["total_tokens"] + ) + assert ( + actual_body["post"][0]["session_name"] + == expected_body["post"][0]["session_name"] + ) + + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +@pytest.mark.asyncio +async def test_langsmith_queue_logging(): + try: + # Initialize LangsmithLogger + test_langsmith_logger = LangsmithLogger() + + litellm.callbacks = [test_langsmith_logger] + test_langsmith_logger.batch_size = 6 + litellm.set_verbose = True + + # Make multiple calls to ensure we don't hit the batch size + for _ in range(5): + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Test message"}], + max_tokens=10, + temperature=0.2, + mock_response="This is a mock response", + ) + + await asyncio.sleep(3) + + # Check that logs are in the queue + assert len(test_langsmith_logger.log_queue) == 5 + + # Now make calls to exceed the batch size + for _ in range(3): + response = await litellm.acompletion( + model="gpt-3.5-turbo", + messages=[{"role": "user", "content": "Test message"}], + max_tokens=10, + temperature=0.2, + mock_response="This is a mock response", + ) + + # Wait a short time for any asynchronous operations to complete + await asyncio.sleep(1) + + print( + "Length of langsmith log queue: {}".format( + len(test_langsmith_logger.log_queue) + ) + ) + # Check that the queue was flushed after exceeding batch size + assert len(test_langsmith_logger.log_queue) < 5 + + # Clean up + for cb in litellm.callbacks: + if isinstance(cb, LangsmithLogger): + await cb.async_httpx_client.client.aclose() + + except Exception as e: + pytest.fail(f"Error occurred: {e}") diff --git a/tests/proxy_unit_tests/test_proxy_server.py b/tests/proxy_unit_tests/test_proxy_server.py index 76cdf1a54..5588d0414 100644 --- a/tests/proxy_unit_tests/test_proxy_server.py +++ b/tests/proxy_unit_tests/test_proxy_server.py @@ -1632,6 +1632,139 @@ async def 
test_add_callback_via_key_litellm_pre_call_utils_gcs_bucket( assert new_data["failure_callback"] == expected_failure_callbacks +@pytest.mark.asyncio +@pytest.mark.parametrize( + "callback_type, expected_success_callbacks, expected_failure_callbacks", + [ + ("success", ["langsmith"], []), + ("failure", [], ["langsmith"]), + ("success_and_failure", ["langsmith"], ["langsmith"]), + ], +) +async def test_add_callback_via_key_litellm_pre_call_utils_langsmith( + prisma_client, callback_type, expected_success_callbacks, expected_failure_callbacks +): + import json + + from fastapi import HTTPException, Request, Response + from starlette.datastructures import URL + + from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request + + setattr(litellm.proxy.proxy_server, "prisma_client", prisma_client) + setattr(litellm.proxy.proxy_server, "master_key", "sk-1234") + await litellm.proxy.proxy_server.prisma_client.connect() + + proxy_config = getattr(litellm.proxy.proxy_server, "proxy_config") + + request = Request(scope={"type": "http", "method": "POST", "headers": {}}) + request._url = URL(url="/chat/completions") + + test_data = { + "model": "azure/chatgpt-v-2", + "messages": [ + {"role": "user", "content": "write 1 sentence poem"}, + ], + "max_tokens": 10, + "mock_response": "Hello world", + "api_key": "my-fake-key", + } + + json_bytes = json.dumps(test_data).encode("utf-8") + + request._body = json_bytes + + data = { + "data": { + "model": "azure/chatgpt-v-2", + "messages": [{"role": "user", "content": "write 1 sentence poem"}], + "max_tokens": 10, + "mock_response": "Hello world", + "api_key": "my-fake-key", + }, + "request": request, + "user_api_key_dict": UserAPIKeyAuth( + token=None, + key_name=None, + key_alias=None, + spend=0.0, + max_budget=None, + expires=None, + models=[], + aliases={}, + config={}, + user_id=None, + team_id=None, + max_parallel_requests=None, + metadata={ + "logging": [ + { + "callback_name": "langsmith", + "callback_type": callback_type, + "callback_vars": { + "langsmith_api_key": "ls-1234", + "langsmith_project": "pr-brief-resemblance-72", + "langsmith_base_url": "https://api.smith.langchain.com", + }, + } + ] + }, + tpm_limit=None, + rpm_limit=None, + budget_duration=None, + budget_reset_at=None, + allowed_cache_controls=[], + permissions={}, + model_spend={}, + model_max_budget={}, + soft_budget_cooldown=False, + litellm_budget_table=None, + org_id=None, + team_spend=None, + team_alias=None, + team_tpm_limit=None, + team_rpm_limit=None, + team_max_budget=None, + team_models=[], + team_blocked=False, + soft_budget=None, + team_model_aliases=None, + team_member_spend=None, + team_metadata=None, + end_user_id=None, + end_user_tpm_limit=None, + end_user_rpm_limit=None, + end_user_max_budget=None, + last_refreshed_at=None, + api_key=None, + user_role=None, + allowed_model_region=None, + parent_otel_span=None, + ), + "proxy_config": proxy_config, + "general_settings": {}, + "version": "0.0.0", + } + + new_data = await add_litellm_data_to_request(**data) + print("NEW DATA: {}".format(new_data)) + + assert "langsmith_api_key" in new_data + assert new_data["langsmith_api_key"] == "ls-1234" + assert "langsmith_project" in new_data + assert new_data["langsmith_project"] == "pr-brief-resemblance-72" + assert "langsmith_base_url" in new_data + assert new_data["langsmith_base_url"] == "https://api.smith.langchain.com" + + if expected_success_callbacks: + assert "success_callback" in new_data + assert new_data["success_callback"] == expected_success_callbacks + + if 
expected_failure_callbacks: + assert "failure_callback" in new_data + assert new_data["failure_callback"] == expected_failure_callbacks + + @pytest.mark.asyncio async def test_gemini_pass_through_endpoint(): from starlette.datastructures import URL