(feat) Add basic logging support for /batches endpoints (#7381)

* add basic logging for `create_batch`

* add `create_batch` as a call type

* add basic DD logging for batches

* basic batch creation logging on DD
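
For context, the change routes batch creation through the same async-detection path used by the other async call types. Below is a minimal, standalone sketch of that check; it mirrors the `_is_async_request` helper added in the diff further down, and the `is_async_request` name and the example calls here are illustrative, not part of the commit:

    from typing import Optional

    # Flags the wrapper inspects to decide whether a call is async.
    # "acreate_batch" is the flag this commit adds for /batches logging.
    _ASYNC_FLAGS = (
        "acompletion",
        "aembedding",
        "aimg_generation",
        "amoderation",
        "atext_completion",
        "atranscription",
        "arerank",
        "_arealtime",
        "acreate_batch",  # newly recognized async call type
    )

    def is_async_request(kwargs: Optional[dict]) -> bool:
        if kwargs is None:
            return False
        return any(kwargs.get(flag, False) is True for flag in _ASYNC_FLAGS)

    # Illustrative usage: an async batch-create call is now detected as async,
    # so it takes the async logging path in the `client` wrapper.
    print(is_async_request({"acreate_batch": True}))   # True
    print(is_async_request({"model": "gpt-4o"}))       # False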
Ishaan Jaff, 2024-12-23 17:45:03 -08:00 (committed by GitHub)
parent 6f6c651ee0
commit 87f19d6f13
5 changed files with 81 additions and 43 deletions

@@ -712,16 +712,8 @@ def client(original_function): # noqa: PLR0915
     def wrapper(*args, **kwargs): # noqa: PLR0915
         # DO NOT MOVE THIS. It always needs to run first
         # Check if this is an async function. If so only execute the async function
-        if (
-            kwargs.get("acompletion", False) is True
-            or kwargs.get("aembedding", False) is True
-            or kwargs.get("aimg_generation", False) is True
-            or kwargs.get("amoderation", False) is True
-            or kwargs.get("atext_completion", False) is True
-            or kwargs.get("atranscription", False) is True
-            or kwargs.get("arerank", False) is True
-            or kwargs.get("_arealtime", False) is True
-        ):
+        call_type = original_function.__name__
+        if _is_async_request(kwargs):
             # [OPTIONAL] CHECK MAX RETRIES / REQUEST
             if litellm.num_retries_per_request is not None:
                 # check if previous_models passed in as ['litellm_params']['metadata]['previous_models']
@@ -759,20 +751,10 @@ def client(original_function): # noqa: PLR0915
         )
         # only set litellm_call_id if its not in kwargs
-        call_type = original_function.__name__
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
-        model: Optional[str] = None
-        try:
-            model = args[0] if len(args) > 0 else kwargs["model"]
-        except Exception:
-            model = None
-            if (
-                call_type != CallTypes.image_generation.value
-                and call_type != CallTypes.text_completion.value
-            ):
-                raise ValueError("model param not passed in.")
+        model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)
         try:
             if logging_obj is None:
@@ -1022,16 +1004,7 @@ def client(original_function): # noqa: PLR0915
         if "litellm_call_id" not in kwargs:
             kwargs["litellm_call_id"] = str(uuid.uuid4())
-        model = ""
-        try:
-            model = args[0] if len(args) > 0 else kwargs["model"]
-        except Exception:
-            if (
-                call_type != CallTypes.aimage_generation.value # model optional
-                and call_type != CallTypes.atext_completion.value # can also be engine
-                and call_type != CallTypes.amoderation.value
-            ):
-                raise ValueError("model param not passed in.")
+        model: Optional[str] = args[0] if len(args) > 0 else kwargs.get("model", None)
         try:
             if logging_obj is None:
@@ -1054,7 +1027,7 @@ def client(original_function): # noqa: PLR0915
             )
             _caching_handler_response: CachingHandlerResponse = (
                 await _llm_caching_handler._async_get_cache(
-                    model=model,
+                    model=model or "",
                     original_function=original_function,
                     logging_obj=logging_obj,
                     start_time=start_time,
@@ -1100,7 +1073,7 @@ def client(original_function): # noqa: PLR0915
                 "id", None
             )
             result._hidden_params["api_base"] = get_api_base(
-                model=model,
+                model=model or "",
                 optional_params=kwargs,
             )
             result._hidden_params["response_cost"] = (
@@ -1230,6 +1203,29 @@ def client(original_function): # noqa: PLR0915
     return wrapper


+def _is_async_request(kwargs: Optional[dict]) -> bool:
+    """
+    Returns True if the call type is an internal async request.
+
+    eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
+    """
+    if kwargs is None:
+        return False
+    if (
+        kwargs.get("acompletion", False) is True
+        or kwargs.get("aembedding", False) is True
+        or kwargs.get("aimg_generation", False) is True
+        or kwargs.get("amoderation", False) is True
+        or kwargs.get("atext_completion", False) is True
+        or kwargs.get("atranscription", False) is True
+        or kwargs.get("arerank", False) is True
+        or kwargs.get("_arealtime", False) is True
+        or kwargs.get("acreate_batch", False) is True
+    ):
+        return True
+    return False
+
+
 @lru_cache(maxsize=128)
 def _select_tokenizer(
     model: str,
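
Beyond the new async flag, the hunks above also relax model extraction in both the sync and async wrappers: a missing `model` no longer raises "model param not passed in.", and downstream call sites guard with `model or ""` (caching and `get_api_base`), presumably to accommodate call types like batch creation that do not pass a model. A small sketch of that pattern, assuming a hypothetical `_extract_model` helper that is not part of this commit:

    from typing import Optional

    def _extract_model(args: tuple, kwargs: dict) -> Optional[str]:
        # Mirrors the relaxed extraction above: no ValueError when `model`
        # is absent; the value is simply None.
        return args[0] if len(args) > 0 else kwargs.get("model", None)

    model = _extract_model((), {"acreate_batch": True})  # -> None
    safe_model = model or ""  # downstream guard, as used for caching / get_api_base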