Mirror of https://github.com/BerriAI/litellm.git
(feat) Add basic logging support for /batches endpoints (#7381)

* add basic logging for `create_batch`
* add `create_batch` as a call type
* add basic DD (Datadog) logging for batches
* basic batch creation logging on DD
This commit is contained in: parent 6f6c651ee0 · commit 87f19d6f13
5 changed files with 81 additions and 43 deletions
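The net effect of this change: batch creation calls now flow through litellm's standard logging pipeline, so registered callbacks (including the Datadog integration) receive a standard logging payload for `create_batch` / `acreate_batch`. A minimal usage sketch, mirroring the test updated in this commit; the input file id is a placeholder and not from the diff:

```python
import asyncio

import litellm
from litellm.integrations.custom_logger import CustomLogger


class BatchLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # populated by litellm's @client wrapper once the call succeeds
        slo = kwargs.get("standard_logging_object")
        print("batch call logged:", slo)


litellm.callbacks = [BatchLogger(), "datadog"]  # "datadog" ships the same event to DD


async def main():
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",  # placeholder: id returned by a prior file-create call
        custom_llm_provider="openai",
    )
    print("created batch:", batch.id)
    await asyncio.sleep(2)  # give the async success callback time to fire


asyncio.run(main())
```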
@@ -25,7 +25,7 @@ from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
 from litellm.types.router import GenericLiteLLMParams
-from litellm.utils import supports_httpx_timeout
+from litellm.utils import client, supports_httpx_timeout

 ####### ENVIRONMENT VARIABLES ###################
 openai_batches_instance = OpenAIBatchesAPI()
@@ -34,6 +34,7 @@ vertex_ai_batches_instance = VertexAIBatchPrediction(gcs_bucket_name="")
 #################################################


+@client
 async def acreate_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
@@ -80,6 +81,7 @@ async def acreate_batch(
        raise e


+@client
 def create_batch(
     completion_window: Literal["24h"],
     endpoint: Literal["/v1/chat/completions", "/v1/embeddings", "/v1/completions"],
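Decorating `acreate_batch` and `create_batch` with `@client` is what routes these calls through the shared wrapper imported from `litellm.utils` (its diff appears further down): the wrapper assigns a `litellm_call_id`, resolves the model, and fires the success/failure callbacks. As a rough illustration of the pattern only, not litellm's actual wrapper, a decorator like this captures call metadata around the wrapped function:

```python
import functools
import uuid


def logging_client(original_function):
    """Illustrative sketch of a decorator-based logging wrapper; not litellm's implementation."""

    @functools.wraps(original_function)
    def wrapper(*args, **kwargs):
        call_id = kwargs.get("litellm_call_id") or str(uuid.uuid4())
        call_type = original_function.__name__  # e.g. "create_batch" or "acreate_batch"
        try:
            result = original_function(*args, **kwargs)
            print(f"success callback: call_type={call_type} call_id={call_id}")
            return result
        except Exception:
            print(f"failure callback: call_type={call_type} call_id={call_id}")
            raise

    return wrapper
```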
@@ -34,7 +34,11 @@ from litellm.litellm_core_utils.redact_messages import (
     redact_message_input_output_from_custom_logger,
     redact_message_input_output_from_logging,
 )
-from litellm.types.llms.openai import AllMessageValues, HttpxBinaryResponseContent
+from litellm.types.llms.openai import (
+    AllMessageValues,
+    Batch,
+    HttpxBinaryResponseContent,
+)
 from litellm.types.rerank import RerankResponse
 from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
 from litellm.types.utils import (
@@ -749,6 +753,7 @@ class Logging(LiteLLMLoggingBaseClass):
             TextCompletionResponse,
             HttpxBinaryResponseContent,
             RerankResponse,
+            Batch,
         ],
         cache_hit: Optional[bool] = None,
     ) -> Optional[float]:
@@ -865,6 +870,7 @@ class Logging(LiteLLMLoggingBaseClass):
                 or isinstance(result, TextCompletionResponse)
                 or isinstance(result, HttpxBinaryResponseContent)  # tts
                 or isinstance(result, RerankResponse)
+                or isinstance(result, Batch)
             ):
                 ## RESPONSE COST ##
                 self.model_call_details["response_cost"] = (
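With `Batch` added to the accepted response types, a successful batch creation now reaches the response-cost branch of the logging code above. A hedged sketch of how a callback could surface that; the `response_cost` and `call_type` key names are assumptions based on litellm's `StandardLoggingPayload`:

```python
from litellm.integrations.custom_logger import CustomLogger


class BatchCostLogger(CustomLogger):
    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        payload = kwargs.get("standard_logging_object") or {}
        # key names assumed from StandardLoggingPayload; fall back gracefully if absent
        if payload.get("call_type") in ("create_batch", "acreate_batch"):
            print("batch logged, response_cost:", payload.get("response_cost"))
```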
@@ -169,6 +169,8 @@ class CallTypes(Enum):
     rerank = "rerank"
     arerank = "arerank"
     arealtime = "_arealtime"
+    create_batch = "create_batch"
+    acreate_batch = "acreate_batch"
     pass_through = "pass_through_endpoint"


@@ -190,6 +192,9 @@ CallTypesLiteral = Literal[
    "rerank",
    "arerank",
    "_arealtime",
+    "create_batch",
+    "acreate_batch",
+    "pass_through_endpoint",
 ]


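Registering `create_batch` / `acreate_batch` as call types means downstream logging code can branch on them like any other operation. A small helper built on the new enum members, assuming `CallTypes` is importable from `litellm.types.utils` as the hunk above suggests:

```python
from litellm.types.utils import CallTypes


def is_batch_creation(call_type: str) -> bool:
    # the two members added in this commit
    return call_type in (CallTypes.create_batch.value, CallTypes.acreate_batch.value)


assert is_batch_creation("create_batch")
assert is_batch_creation("acreate_batch")
assert not is_batch_creation("completion")
```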
@@ -712,16 +712,8 @@ def client(original_function):  # noqa: PLR0915
     def wrapper(*args, **kwargs):  # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
         # Check if this is an async function. If so only execute the async function
-        if (
-            kwargs.get("acompletion", False) is True
-            or kwargs.get("aembedding", False) is True
-            or kwargs.get("aimg_generation", False) is True
-            or kwargs.get("amoderation", False) is True
-            or kwargs.get("atext_completion", False) is True
-            or kwargs.get("atranscription", False) is True
-            or kwargs.get("arerank", False) is True
-            or kwargs.get("_arealtime", False) is True
-        ):
+        call_type = original_function.__name__
+        if _is_async_request(kwargs):
             # [OPTIONAL] CHECK MAX RETRIES / REQUEST
             if litellm.num_retries_per_request is not None:
                 # check if previous_models passed in as ['litellm_params']['metadata]['previous_models']
@@ -759,20 +751,10 @@ def client(original_function):  # noqa: PLR0915
            )

            # only set litellm_call_id if its not in kwargs
-            call_type = original_function.__name__
            if "litellm_call_id" not in kwargs:
                kwargs["litellm_call_id"] = str(uuid.uuid4())

-            model: Optional[str] = None
-            try:
-                model = args[0] if len(args) > 0 else kwargs["model"]
-            except Exception:
-                model = None
-                if (
-                    call_type != CallTypes.image_generation.value
-                    and call_type != CallTypes.text_completion.value
-                ):
-                    raise ValueError("model param not passed in.")
+            model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)

            try:
                if logging_obj is None:
@@ -1022,16 +1004,7 @@ def client(original_function):  # noqa: PLR0915
            if "litellm_call_id" not in kwargs:
                kwargs["litellm_call_id"] = str(uuid.uuid4())

-            model = ""
-            try:
-                model = args[0] if len(args) > 0 else kwargs["model"]
-            except Exception:
-                if (
-                    call_type != CallTypes.aimage_generation.value  # model optional
-                    and call_type != CallTypes.atext_completion.value  # can also be engine
-                    and call_type != CallTypes.amoderation.value
-                ):
-                    raise ValueError("model param not passed in.")
+            model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)

            try:
                if logging_obj is None:
@@ -1054,7 +1027,7 @@ def client(original_function):  # noqa: PLR0915
                )
                _caching_handler_response: CachingHandlerResponse = (
                    await _llm_caching_handler._async_get_cache(
-                        model=model,
+                        model=model or "",
                        original_function=original_function,
                        logging_obj=logging_obj,
                        start_time=start_time,
@@ -1100,7 +1073,7 @@ def client(original_function):  # noqa: PLR0915
                        "id", None
                    )
                    result._hidden_params["api_base"] = get_api_base(
-                        model=model,
+                        model=model or "",
                        optional_params=kwargs,
                    )
                    result._hidden_params["response_cost"] = (
@@ -1230,6 +1203,29 @@ def client(original_function):  # noqa: PLR0915
     return wrapper


+def _is_async_request(kwargs: Optional[dict]) -> bool:
+    """
+    Returns True if the call type is an internal async request.
+
+    eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
+    """
+    if kwargs is None:
+        return False
+    if (
+        kwargs.get("acompletion", False) is True
+        or kwargs.get("aembedding", False) is True
+        or kwargs.get("aimg_generation", False) is True
+        or kwargs.get("amoderation", False) is True
+        or kwargs.get("atext_completion", False) is True
+        or kwargs.get("atranscription", False) is True
+        or kwargs.get("arerank", False) is True
+        or kwargs.get("_arealtime", False) is True
+        or kwargs.get("acreate_batch", False) is True
+    ):
+        return True
+    return False
+
+
 @lru_cache(maxsize=128)
 def _select_tokenizer(
     model: str,
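`_is_async_request` now also returns True when `acreate_batch=True` is passed through kwargs, which is how the sync wrapper detects the async batch path and hands off to the async wrapper. A quick illustrative check of the helper's behavior (not part of the shipped test suite; the import path assumes the helper stays module-level in `litellm.utils`):

```python
from litellm.utils import _is_async_request

assert _is_async_request({"acreate_batch": True}) is True
assert _is_async_request({"acompletion": True}) is True
assert _is_async_request({"acompletion": False}) is False
assert _is_async_request({}) is False
assert _is_async_request(None) is False
```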
@@ -16,7 +16,7 @@ import logging
 import time

 import pytest
+from typing import Optional
 import litellm
 from litellm import create_batch, create_file
 from litellm._logging import verbose_logger
@@ -24,14 +24,36 @@ from test_gcs_bucket import load_vertex_ai_credentials

 verbose_logger.setLevel(logging.DEBUG)

+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import StandardLoggingPayload
+
+
+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.standard_logging_object: Optional[StandardLoggingPayload] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(
+            "Success event logged with kwargs=",
+            kwargs,
+            "and response_obj=",
+            response_obj,
+        )
+        self.standard_logging_object = kwargs["standard_logging_object"]
+
+
 @pytest.mark.parametrize("provider", ["openai"])  # , "azure"
-def test_create_batch(provider):
+@pytest.mark.asyncio
+async def test_create_batch(provider):
     """
     1. Create File for Batch completion
     2. Create Batch Request
     3. Retrieve the specific batch
     """
+    custom_logger = TestCustomLogger()
+    litellm.callbacks = [custom_logger, "datadog"]

     if provider == "azure":
         # Don't have anymore Azure Quota
         return
@@ -39,7 +61,7 @@ def test_create_batch(provider):
     _current_dir = os.path.dirname(os.path.abspath(__file__))
     file_path = os.path.join(_current_dir, file_name)

-    file_obj = litellm.create_file(
+    file_obj = await litellm.acreate_file(
         file=open(file_path, "rb"),
         purpose="batch",
         custom_llm_provider=provider,
@@ -51,8 +73,8 @@ def test_create_batch(provider):
         batch_input_file_id is not None
     ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"

-    time.sleep(5)
-    create_batch_response = litellm.create_batch(
+    await asyncio.sleep(1)
+    create_batch_response = await litellm.acreate_batch(
         completion_window="24h",
         endpoint="/v1/chat/completions",
         input_file_id=batch_input_file_id,
@@ -61,7 +83,14 @@ def test_create_batch(provider):
     )

     print("response from litellm.create_batch=", create_batch_response)
+    await asyncio.sleep(6)

+    # Assert that the create batch event is logged on CustomLogger
+    assert custom_logger.standard_logging_object is not None
+    print(
+        "standard_logging_object=",
+        json.dumps(custom_logger.standard_logging_object, indent=4),
+    )
     assert (
         create_batch_response.id is not None
     ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
@@ -73,7 +102,7 @@ def test_create_batch(provider):
         create_batch_response.input_file_id == batch_input_file_id
     ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"

-    retrieved_batch = litellm.retrieve_batch(
+    retrieved_batch = await litellm.aretrieve_batch(
         batch_id=create_batch_response.id, custom_llm_provider=provider
     )
     print("retrieved batch=", retrieved_batch)
@@ -82,10 +111,10 @@ def test_create_batch(provider):
     assert retrieved_batch.id == create_batch_response.id

     # list all batches
-    list_batches = litellm.list_batches(custom_llm_provider=provider, limit=2)
+    list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2)
     print("list_batches=", list_batches)

-    file_content = litellm.file_content(
+    file_content = await litellm.afile_content(
         file_id=batch_input_file_id, custom_llm_provider=provider
     )

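Once the test's assertions pass, the payload captured by `TestCustomLogger` is a plain dict (`StandardLoggingPayload` is a TypedDict), so the fields that will be shipped to Datadog can be spot-checked. A short illustrative snippet; the specific key names are assumptions about that payload, not taken from the diff:

```python
slo = custom_logger.standard_logging_object
assert slo is not None
# key names below are assumed from litellm's StandardLoggingPayload TypedDict
print("call_type:", slo.get("call_type"))
print("response_cost:", slo.get("response_cost"))
```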