(feat) /batches - track user_api_key_alias, user_api_key_team_alias, etc. for /batch requests (#7401)

* run azure testing on ci/cd

* update docs on azure batches endpoints

* add input azure.jsonl

* refactor - use separate file for batches endpoints

* fixes for passing custom llm provider to /batch endpoints

* pass custom llm provider to files endpoints

* update azure batches doc

* add info for azure batches api

* update batches endpoints

* use simple helper for raising proxy exception

* update config.yml

* fix imports

* add type hints to get_litellm_params

* update get_litellm_params

* update get_litellm_params

* update get slp

* QOL - stop double logging create batch operations on custom loggers

* reuse slp (standard logging payload) from the original event

* _create_standard_logging_object_for_completed_batch

* fix linting errors

* reduce num changes in PR

* update BATCH_STATUS_POLL_MAX_ATTEMPTS
Ishaan Jaff, 2024-12-24 17:44:28 -08:00 (committed by GitHub)
commit e98f1d16fd, parent 0627450808
9 changed files with 72 additions and 29 deletions
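Net effect of the changes below: callers can attach litellm-specific metadata (user_api_key_alias, user_api_key_team_alias, ...) to /batch requests via the new litellm_metadata kwarg, and it is surfaced on the standard logging object handed to custom loggers. A minimal usage sketch (illustrative ids; mirrors the test at the end of this commit):

    import asyncio
    import litellm

    async def main():
        batch = await litellm.acreate_batch(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id="file-abc123",  # hypothetical input file id
            custom_llm_provider="openai",
            metadata={"key1": "value1"},  # provider-visible batch metadata
            # litellm-specific: lands in standard_logging_object["metadata"]
            litellm_metadata={
                "user_api_key_alias": "special_api_key_alias",
                "user_api_key_team_alias": "special_team_alias",
            },
        )
        print(batch.id)

    asyncio.run(main())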


@@ -277,17 +277,8 @@ def _create_standard_logging_object_for_completed_batch(
     """
     Create a standard logging object for a completed batch
     """
-    from litellm.litellm_core_utils.litellm_logging import (
-        get_standard_logging_object_payload,
-    )
-
-    standard_logging_object = get_standard_logging_object_payload(
-        kwargs=kwargs,
-        init_response_obj=None,
-        start_time=start_time,
-        end_time=end_time,
-        logging_obj=logging_obj,
-        status="success",
-    )
+    standard_logging_object = logging_obj.model_call_details.get(
+        "standard_logging_object", None
+    )
 
     if standard_logging_object is None:
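Note: this hunk is the "stop double logging" fix. Instead of building a second standard logging payload for the completed batch, the handler reuses the one the original create-batch event already attached to the logging object. A minimal sketch of the pattern (names from the hunk above; the bail-out behavior is an assumption):

    # Prefer the payload cached by the original event over rebuilding it,
    # so custom loggers see exactly one create-batch operation.
    standard_logging_object = logging_obj.model_call_details.get(
        "standard_logging_object", None
    )
    if standard_logging_object is None:
        return  # assumed: nothing cached; skip rather than emit a duplicate payload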


@@ -19,13 +19,14 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
 
 import httpx
 
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.azure.azure import AzureBatchesAPI
 from litellm.llms.openai.openai import OpenAIBatchesAPI
 from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
 from litellm.types.router import GenericLiteLLMParams
-from litellm.utils import client, supports_httpx_timeout
+from litellm.utils import client, get_litellm_params, supports_httpx_timeout
 
 from .batch_utils import batches_async_logging

@@ -114,9 +115,22 @@ def create_batch(
     try:
         optional_params = GenericLiteLLMParams(**kwargs)
         _is_async = kwargs.pop("acreate_batch", False) is True
+        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
         # set timeout for 10 minutes by default
+        litellm_params = get_litellm_params(
+            custom_llm_provider=custom_llm_provider,
+            litellm_call_id=kwargs.get("litellm_call_id", None),
+            litellm_trace_id=kwargs.get("litellm_trace_id"),
+            litellm_metadata=kwargs.get("litellm_metadata"),
+        )
+        litellm_logging_obj.update_environment_variables(
+            model=None,
+            user=None,
+            optional_params=optional_params.model_dump(),
+            litellm_params=litellm_params,
+            custom_llm_provider=custom_llm_provider,
+        )
 
         if (
             timeout is not None
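Note: a /batches request has no single model at this point in the call, which is why update_environment_variables (reworked in a later hunk) now accepts model=None. A minimal sketch of the new wiring, assuming a logging_obj is already in hand and with hypothetical values:

    litellm_params = get_litellm_params(
        custom_llm_provider="openai",
        litellm_call_id="call-123",  # hypothetical id
        litellm_metadata={"user_api_key_alias": "special_api_key_alias"},
    )
    logging_obj.update_environment_variables(
        model=None,  # no model for a batch create; the logger keeps its prior value
        user=None,
        optional_params={},
        litellm_params=litellm_params,
        custom_llm_provider="openai",
    )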


@@ -93,5 +93,5 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
     "optimize-prompt/",
 ]
 
-BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
-BATCH_STATUS_POLL_MAX_ATTEMPTS = 10
+BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600  # 1 hour
+BATCH_STATUS_POLL_MAX_ATTEMPTS = 24  # for 24 hours
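The new budget: 3600 s between polls x 24 attempts = 86,400 s, i.e. 24 hours of waiting for a batch to finish (the old values allowed only 10 s x 10 = 100 s). An illustrative loop showing how the constants bound the wait, not the library's actual retry code:

    import asyncio

    async def wait_for_batch(fetch_status) -> bool:
        # fetch_status is an assumed async callable returning the batch status
        for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):  # 24 attempts
            if await fetch_status() == "completed":
                return True
            await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)  # 1 hour each
        return False  # gave up after ~24 hours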


@@ -388,9 +388,15 @@ class Logging(LiteLLMLoggingBaseClass):
         return standard_callback_dynamic_params
 
     def update_environment_variables(
-        self, model, user, optional_params, litellm_params, **additional_params
+        self,
+        litellm_params: Dict,
+        optional_params: Dict,
+        model: Optional[str] = None,
+        user: Optional[str] = None,
+        **additional_params,
     ):
         self.optional_params = optional_params
-        self.model = model
+        if model is not None:
+            self.model = model
         self.user = user
         self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)

@@ -2885,9 +2891,11 @@ def get_standard_logging_object_payload(
         litellm_params = kwargs.get("litellm_params", {})
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        metadata = (
-            litellm_params.get("metadata", {}) or {}
-        )  # if litellm_params['metadata'] == None
+        metadata: dict = (
+            litellm_params.get("litellm_metadata")
+            or litellm_params.get("metadata", None)
+            or {}
+        )
         completion_start_time = kwargs.get("completion_start_time", end_time)
         call_type = kwargs.get("call_type")
         cache_hit = kwargs.get("cache_hit", False)
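Note the precedence in the new lookup: when both keys are present, litellm_metadata wins outright; the two dicts are not merged. A worked example using the values from the test at the bottom of this commit:

    litellm_params = {
        "litellm_metadata": {"user_api_key_alias": "special_api_key_alias"},
        "metadata": {"key1": "value1"},  # provider-visible metadata
    }
    metadata = (
        litellm_params.get("litellm_metadata")
        or litellm_params.get("metadata", None)
        or {}
    )
    # metadata == {"user_api_key_alias": "special_api_key_alias"}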


@@ -1140,6 +1140,7 @@ def completion(  # type: ignore # noqa: PLR0915
         litellm_trace_id=kwargs.get("litellm_trace_id"),
         hf_model_name=hf_model_name,
         custom_prompt_dict=custom_prompt_dict,
+        litellm_metadata=kwargs.get("litellm_metadata"),
     )
     logging.update_environment_variables(
         model=model,


@@ -138,7 +138,7 @@ def rerank(  # noqa: PLR0915
     litellm_logging_obj.update_environment_variables(
         model=model,
         user=user,
-        optional_params=optional_rerank_params,
+        optional_params=dict(optional_rerank_params),
         litellm_params={
             "litellm_call_id": litellm_call_id,
             "proxy_server_request": proxy_server_request,


@@ -1592,6 +1592,7 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
 
 all_litellm_params = [
     "metadata",
+    "litellm_metadata",
     "litellm_trace_id",
     "tags",
     "acompletion",


@@ -1203,11 +1203,18 @@ def client(original_function):  # noqa: PLR0915
     return wrapper
 
 
-def _is_async_request(kwargs: Optional[dict]) -> bool:
+def _is_async_request(
+    kwargs: Optional[dict],
+    is_pass_through: bool = False,
+) -> bool:
     """
     Returns True if the call type is an internal async request.
 
     eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
+
+    Args:
+        kwargs (dict): The kwargs passed to the litellm function
+        is_pass_through (bool): Whether the call is a pass-through call. By default all pass through calls are async.
     """
     if kwargs is None:
         return False

@@ -1221,6 +1228,7 @@ def _is_async_request(kwargs: Optional[dict]) -> bool:
         or kwargs.get("arerank", False) is True
         or kwargs.get("_arealtime", False) is True
         or kwargs.get("acreate_batch", False) is True
+        or is_pass_through is True
     ):
         return True
     return False
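Sketch of the helper's behavior with the new flag, following the logic shown above (illustrative kwargs):

    _is_async_request({"acompletion": True})       # True - internal async kwarg set
    _is_async_request({}, is_pass_through=True)    # True - pass-through calls are treated as async
    _is_async_request(None, is_pass_through=True)  # False - None kwargs short-circuits first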
@@ -1927,7 +1935,7 @@ def register_model(model_cost: Union[str, dict]):  # noqa: PLR0915
 
 
 def get_litellm_params(
-    api_key=None,
+    api_key: Optional[str] = None,
     force_timeout=600,
     azure=False,
     logger_fn=None,

@@ -1935,12 +1943,12 @@
     hugging_face=False,
     replicate=False,
     together_ai=False,
-    custom_llm_provider=None,
-    api_base=None,
+    custom_llm_provider: Optional[str] = None,
+    api_base: Optional[str] = None,
     litellm_call_id=None,
     model_alias_map=None,
     completion_call_id=None,
-    metadata=None,
+    metadata: Optional[dict] = None,
     model_info=None,
     proxy_server_request=None,
     acompletion=None,

@@ -1954,10 +1962,11 @@
     text_completion=None,
     azure_ad_token_provider=None,
     user_continue_message=None,
-    base_model=None,
-    litellm_trace_id=None,
+    base_model: Optional[str] = None,
+    litellm_trace_id: Optional[str] = None,
     hf_model_name: Optional[str] = None,
     custom_prompt_dict: Optional[dict] = None,
+    litellm_metadata: Optional[dict] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,

@@ -1989,8 +1998,8 @@
         "litellm_trace_id": litellm_trace_id,
         "hf_model_name": hf_model_name,
         "custom_prompt_dict": custom_prompt_dict,
+        "litellm_metadata": litellm_metadata,
     }
 
     return litellm_params
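With the new parameter, litellm_metadata rides along in the returned params dict. A quick usage sketch (hypothetical ids):

    params = get_litellm_params(
        custom_llm_provider="openai",
        litellm_call_id="call-123",
        litellm_trace_id="trace-456",
        litellm_metadata={"user_api_key_alias": "special_api_key_alias"},
    )
    assert params["litellm_metadata"] == {"user_api_key_alias": "special_api_key_alias"}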


@@ -12,6 +12,7 @@ load_dotenv()
 
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system-path
+import json
 import logging
 import time

@@ -191,12 +192,18 @@ async def test_async_create_batch(provider):
         batch_input_file_id is not None
     ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
 
+    extra_metadata_field = {
+        "user_api_key_alias": "special_api_key_alias",
+        "user_api_key_team_alias": "special_team_alias",
+    }
     create_batch_response = await litellm.acreate_batch(
         completion_window="24h",
         endpoint="/v1/chat/completions",
         input_file_id=batch_input_file_id,
         custom_llm_provider=provider,
         metadata={"key1": "value1", "key2": "value2"},
+        # litellm specific param - used for logging metadata on logging callback
+        litellm_metadata=extra_metadata_field,
     )
 
     print("response from litellm.create_batch=", create_batch_response)

@@ -215,6 +222,18 @@
     await asyncio.sleep(6)
     # Assert that the create batch event is logged on CustomLogger
     assert custom_logger.standard_logging_object is not None
+    print(
+        "standard_logging_object=",
+        json.dumps(custom_logger.standard_logging_object, indent=4, default=str),
+    )
+    assert (
+        custom_logger.standard_logging_object["metadata"]["user_api_key_alias"]
+        == extra_metadata_field["user_api_key_alias"]
+    )
+    assert (
+        custom_logger.standard_logging_object["metadata"]["user_api_key_team_alias"]
+        == extra_metadata_field["user_api_key_team_alias"]
+    )
 
     retrieved_batch = await litellm.aretrieve_batch(
         batch_id=create_batch_response.id, custom_llm_provider=provider