(feat) Add basic logging support for /batches endpoints (#7381)

* add basic logging for `create_batch`

* add create_batch as a call type

* add basic Datadog (DD) logging for batches

* basic batch creation logging on Datadog
Author: Ishaan Jaff, 2024-12-23 17:45:03 -08:00 (committed by GitHub)
Commit: 87f19d6f13 (parent 6f6c651ee0)
5 changed files with 81 additions and 43 deletions
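In short, this change makes batch creation go through LiteLLM's standard callback pipeline, so Datadog (or any configured callback) receives the success event. A minimal sketch of exercising the new path (not part of the diff; the file id is a placeholder and Datadog credentials are assumed to be configured via environment variables):

```python
import asyncio

import litellm


async def main():
    # Enable the Datadog integration by name; with this PR, batch creation
    # emits the same success/failure logging that other call types already get.
    litellm.callbacks = ["datadog"]

    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",  # placeholder: create via litellm.acreate_file first
        custom_llm_provider="openai",
    )
    print("created batch:", batch.id)


asyncio.run(main())
```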


@@ -25,7 +25,7 @@ from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
 from litellm.types.router import GenericLiteLLMParams
-from litellm.utils import supports_httpx_timeout
+from litellm.utils import client, supports_httpx_timeout
 
 ####### ENVIRONMENT VARIABLES ###################
 openai_batches_instance = OpenAIBatchesAPI()
@@ -34,6 +34,7 @@ vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
 #################################################
 
 
+@client
 async def acreate_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
@@ -80,6 +81,7 @@ async def acreate_batch(
         raise e
 
 
+@client
 def create_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
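Decorating `create_batch` / `acreate_batch` with `@client` is what routes them through LiteLLM's shared wrapper, which builds the logging object and fires the success/failure callbacks. As a conceptual sketch only (LiteLLM's real `@client` in litellm/utils.py additionally handles caching, retries, cost tracking, and async routing), the pattern looks roughly like this:

```python
import functools
import inspect


def client(original_function):
    """Conceptual sketch of a logging decorator; not LiteLLM's actual implementation."""

    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        call_type = original_function.__name__  # e.g. "create_batch"
        print("pre-call logging for", call_type)
        try:
            result = original_function(*args, **kwargs)
            print("success logging for", call_type)
            return result
        except Exception:
            print("failure logging for", call_type)
            raise

    @functools.wraps(original_function)
    async def async_wrapper(*args, **kwargs):
        call_type = original_function.__name__  # e.g. "acreate_batch"
        print("pre-call logging for", call_type)
        try:
            result = await original_function(*args, **kwargs)
            print("success logging for", call_type)
            return result
        except Exception:
            print("failure logging for", call_type)
            raise

    # Pick the sync or async wrapper depending on the decorated function.
    return async_wrapper if inspect.iscoroutinefunction(original_function) else wrapper
```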


@@ -34,7 +34,11 @@ from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_custom_logger,
     redact_message_input_output_from_logging,
 )
-from litellm.types.llms.openai import AllMessageValues, HttpxBinaryResponseContent
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    Batch,
+    HttpxBinaryResponseContent,
+)
 from litellm.types.rerank import RerankResponse
 from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
 from litellm.types.utils import (
@@ -749,6 +753,7 @@ class Logging(LiteLLMLoggingBaseClass):
             TextCompletionResponse,
             HttpxBinaryResponseContent,
             RerankResponse,
+            Batch,
         ],
         cache_hit: Optional[bool] = None,
     ) -> Optional[float]:
@@ -865,6 +870,7 @@ class Logging(LiteLLMLoggingBaseClass):
                 or isinstance(result, TextCompletionResponse)
                 or isinstance(result, HttpxBinaryResponseContent)  # tts
                 or isinstance(result, RerankResponse)
+                or isinstance(result, Batch)
             ):
                 ## RESPONSE COST ##
                 self.model_call_details["response_cost"] = (


@@ -169,6 +169,8 @@ class CallTypes(Enum):
     rerank = "rerank"
     arerank = "arerank"
     arealtime = "_arealtime"
+    create_batch = "create_batch"
+    acreate_batch = "acreate_batch"
     pass_through = "pass_through_endpoint"
@@ -190,6 +192,9 @@ CallTypesLiteral = Literal[
     "rerank",
     "arerank",
     "_arealtime",
+    "create_batch",
+    "acreate_batch",
+    "pass_through_endpoint",
 ]
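The new `create_batch` / `acreate_batch` call types let downstream callbacks tell batch-creation requests apart from other calls. For illustration, a custom logger could branch on the call type roughly like this (a sketch, assuming `CallTypes` is importable from litellm.types.utils and that the standard logging payload exposes a `call_type` key; both are assumptions, not confirmed by this diff):

```python
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import CallTypes


class BatchAwareLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        # Branch on the call types added in this PR; the "call_type" key is assumed.
        if payload.get("call_type") in (
            CallTypes.create_batch.value,
            CallTypes.acreate_batch.value,
        ):
            print("batch created:", getattr(response_obj, "id", None))
```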


@@ -712,16 +712,8 @@ def client(original_function):  # noqa: PLR0915
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
         # Check if this is an async function. If so only execute the async function
-        if (
-            kwargs.get("acompletion", False) is True
-            or kwargs.get("aembedding", False) is True
-            or kwargs.get("aimg_generation", False) is True
-            or kwargs.get("amoderation", False) is True
-            or kwargs.get("atext_completion", False) is True
-            or kwargs.get("atranscription", False) is True
-            or kwargs.get("arerank", False) is True
-            or kwargs.get("_arealtime", False) is True
-        ):
+        call_type = original_function.__name__
+        if _is_async_request(kwargs):
             # [OPTIONAL] CHECK MAX RETRIES / REQUEST
             if litellm.num_retries_per_request is not None:
                 # check if previous_models passed in as ['litellm_params']['metadata]['previous_models']
@@ -759,20 +751,10 @@ def client(original_function):  # noqa: PLR0915
         )
 
         # only set litellm_call_id if its not in kwargs
-        call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
 
-        model: Optional[str] = None
-        try:
-            model = args[0] if len(args) > 0 else kwargs["model"]
-        except Exception:
-            model = None
-            if (
-                call_type != CallTypes.image_generation.value
-                and call_type != CallTypes.text_completion.value
-            ):
-                raise ValueError("model param not passed in.")
+        model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)
 
         try:
             if logging_obj is None:
@@ -1022,16 +1004,7 @@ def client(original_function):  # noqa: PLR0915
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
 
-        model = ""
-        try:
-            model = args[0] if len(args) > 0 else kwargs["model"]
-        except Exception:
-            if (
-                call_type != CallTypes.aimage_generation.value  # model optional
-                and call_type != CallTypes.atext_completion.value  # can also be engine
-                and call_type != CallTypes.amoderation.value
-            ):
-                raise ValueError("model param not passed in.")
+        model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)
 
         try:
             if logging_obj is None:
@@ -1054,7 +1027,7 @@ def client(original_function):  # noqa: PLR0915
             )
             _caching_handler_response: CachingHandlerResponse = (
                 await _llm_caching_handler._async_get_cache(
-                    model=model,
+                    model=model or "",
                     original_function=original_function,
                     logging_obj=logging_obj,
                     start_time=start_time,
@@ -1100,7 +1073,7 @@ def client(original_function):  # noqa: PLR0915
                         "id", None
                     )
                     result._hidden_params["api_base"] = get_api_base(
-                        model=model,
+                        model=model or "",
                         optional_params=kwargs,
                     )
                     result._hidden_params["response_cost"] = (
@@ -1230,6 +1203,29 @@ def client(original_function):  # noqa: PLR0915
     return wrapper
 
 
+def _is_async_request(kwargs: Optional[dict]) -> bool:
+    """
+    Returns True if the call type is an internal async request.
+
+    eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
+    """
+    if kwargs is None:
+        return False
+
+    if (
+        kwargs.get("acompletion", False) is True
+        or kwargs.get("aembedding", False) is True
+        or kwargs.get("aimg_generation", False) is True
+        or kwargs.get("amoderation", False) is True
+        or kwargs.get("atext_completion", False) is True
+        or kwargs.get("atranscription", False) is True
+        or kwargs.get("arerank", False) is True
+        or kwargs.get("_arealtime", False) is True
+        or kwargs.get("acreate_batch", False) is True
+    ):
+        return True
+    return False
+
+
 @lru_cache(maxsize=128)
 def _select_tokenizer(
     model: str,


@@ -16,7 +16,7 @@ import logging
 import time
 
 import pytest
+from typing import Optional
 
 import litellm
 from litellm import create_batch, create_file
 from litellm._logging import verbose_logger
@@ -24,14 +24,36 @@ from test_gcs_bucket import load_vertex_ai_credentials
 verbose_logger.setLevel(logging.DEBUG)
 
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import StandardLoggingPayload
+
+
+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.standard_logging_object: Optional[StandardLoggingPayload] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(
+            "Success event logged with kwargs=",
+            kwargs,
+            "and response_obj=",
+            response_obj,
+        )
+        self.standard_logging_object = kwargs["standard_logging_object"]
+
+
 @pytest.mark.parametrize("provider", ["openai"])  # , "azure"
-def test_create_batch(provider):
+@pytest.mark.asyncio
+async def test_create_batch(provider):
     """
     1. Create File for Batch completion
     2. Create Batch Request
     3. Retrieve the specific batch
     """
+    custom_logger = TestCustomLogger()
+    litellm.callbacks = [custom_logger, "datadog"]
     if provider == "azure":
         # Don't have anymore Azure Quota
         return
@@ -39,7 +61,7 @@ def test_create_batch(provider):
     _current_dir = os.path.dirname(os.path.abspath(__file__))
     file_path = os.path.join(_current_dir, file_name)
 
-    file_obj = litellm.create_file(
+    file_obj = await litellm.acreate_file(
         file=open(file_path, "rb"),
         purpose="batch",
         custom_llm_provider=provider,
@@ -51,8 +73,8 @@ def test_create_batch(provider):
         batch_input_file_id is not None
     ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
 
-    time.sleep(5)
-    create_batch_response = litellm.create_batch(
+    await asyncio.sleep(1)
+    create_batch_response = await litellm.acreate_batch(
         completion_window="24h",
         endpoint="/v1/chat/completions",
         input_file_id=batch_input_file_id,
@@ -61,7 +83,14 @@ def test_create_batch(provider):
     )
 
     print("response from litellm.create_batch=", create_batch_response)
+    await asyncio.sleep(6)
+
+    # Assert that the create batch event is logged on CustomLogger
+    assert custom_logger.standard_logging_object is not None
+    print(
+        "standard_logging_object=",
+        json.dumps(custom_logger.standard_logging_object, indent=4),
+    )
 
     assert (
         create_batch_response.id is not None
     ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
@@ -73,7 +102,7 @@ def test_create_batch(provider):
         create_batch_response.input_file_id == batch_input_file_id
     ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
 
-    retrieved_batch = litellm.retrieve_batch(
+    retrieved_batch = await litellm.aretrieve_batch(
         batch_id=create_batch_response.id, custom_llm_provider=provider
     )
     print("retrieved batch=", retrieved_batch)
@@ -82,10 +111,10 @@ def test_create_batch(provider):
     assert retrieved_batch.id == create_batch_response.id
 
     # list all batches
-    list_batches = litellm.list_batches(custom_llm_provider=provider, limit=2)
+    list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2)
    print("list_batches=", list_batches)
 
-    file_content = litellm.file_content(
+    file_content = await litellm.afile_content(
         file_id=batch_input_file_id, custom_llm_provider=provider
     )