mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(feat) /batches
- track user_api_key_alias
, user_api_key_team_alias
etc for /batch requests (#7401)
* run azure testing on ci/cd * update docs on azure batches endpoints * add input azure.jsonl * refactor - use separate file for batches endpoints * fixes for passing custom llm provider to /batch endpoints * pass custom llm provider to files endpoints * update azure batches doc * add info for azure batches api * update batches endpoints * use simple helper for raising proxy exception * update config.yml * fix imports * add type hints to get_litellm_params * update get_litellm_params * update get_litellm_params * update get slp * QOL - stop double logging a create batch operations on custom loggers * re use slp from og event * _create_standard_logging_object_for_completed_batch * fix linting errors * reduce num changes in PR * update BATCH_STATUS_POLL_MAX_ATTEMPTS
This commit is contained in:
parent
47e12802df
commit
08a4c72692
9 changed files with 72 additions and 29 deletions
|
@ -277,17 +277,8 @@ def _create_standard_logging_object_for_completed_batch(
|
|||
"""
|
||||
Create a standard logging object for a completed batch
|
||||
"""
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
|
||||
standard_logging_object = get_standard_logging_object_payload(
|
||||
kwargs=kwargs,
|
||||
init_response_obj=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=logging_obj,
|
||||
status="success",
|
||||
standard_logging_object = logging_obj.model_call_details.get(
|
||||
"standard_logging_object", None
|
||||
)
|
||||
|
||||
if standard_logging_object is None:
|
||||
|
|
|
@ -19,13 +19,14 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.llms.azure.azure import AzureBatchesAPI
|
||||
from litellm.llms.openai.openai import OpenAIBatchesAPI
|
||||
from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
|
||||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import client, supports_httpx_timeout
|
||||
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
|
||||
|
||||
from .batch_utils import batches_async_logging
|
||||
|
||||
|
@ -114,9 +115,22 @@ def create_batch(
|
|||
try:
|
||||
optional_params = GenericLiteLLMParams(**kwargs)
|
||||
_is_async = kwargs.pop("acreate_batch", False) is True
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
litellm_params = get_litellm_params(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
litellm_call_id=kwargs.get("litellm_call_id", None),
|
||||
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||
litellm_metadata=kwargs.get("litellm_metadata"),
|
||||
)
|
||||
litellm_logging_obj.update_environment_variables(
|
||||
model=None,
|
||||
user=None,
|
||||
optional_params=optional_params.model_dump(),
|
||||
litellm_params=litellm_params,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
if (
|
||||
timeout is not None
|
||||
|
|
|
@ -93,5 +93,5 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
|
|||
"optimize-prompt/",
|
||||
]
|
||||
|
||||
BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
|
||||
BATCH_STATUS_POLL_MAX_ATTEMPTS = 10
|
||||
BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600 # 1 hour
|
||||
BATCH_STATUS_POLL_MAX_ATTEMPTS = 24 # for 24 hours
|
||||
|
|
|
@ -388,10 +388,16 @@ class Logging(LiteLLMLoggingBaseClass):
|
|||
return standard_callback_dynamic_params
|
||||
|
||||
def update_environment_variables(
|
||||
self, model, user, optional_params, litellm_params, **additional_params
|
||||
self,
|
||||
litellm_params: Dict,
|
||||
optional_params: Dict,
|
||||
model: Optional[str] = None,
|
||||
user: Optional[str] = None,
|
||||
**additional_params,
|
||||
):
|
||||
self.optional_params = optional_params
|
||||
self.model = model
|
||||
if model is not None:
|
||||
self.model = model
|
||||
self.user = user
|
||||
self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
|
||||
self.logger_fn = litellm_params.get("logger_fn", None)
|
||||
|
@ -2885,9 +2891,11 @@ def get_standard_logging_object_payload(
|
|||
litellm_params = kwargs.get("litellm_params", {})
|
||||
proxy_server_request = litellm_params.get("proxy_server_request") or {}
|
||||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
) # if litellm_params['metadata'] == None
|
||||
metadata: dict = (
|
||||
litellm_params.get("litellm_metadata")
|
||||
or litellm_params.get("metadata", None)
|
||||
or {}
|
||||
)
|
||||
completion_start_time = kwargs.get("completion_start_time", end_time)
|
||||
call_type = kwargs.get("call_type")
|
||||
cache_hit = kwargs.get("cache_hit", False)
|
||||
|
|
|
@ -1140,6 +1140,7 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
litellm_trace_id=kwargs.get("litellm_trace_id"),
|
||||
hf_model_name=hf_model_name,
|
||||
custom_prompt_dict=custom_prompt_dict,
|
||||
litellm_metadata=kwargs.get("litellm_metadata"),
|
||||
)
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
|
|
@ -138,7 +138,7 @@ def rerank( # noqa: PLR0915
|
|||
litellm_logging_obj.update_environment_variables(
|
||||
model=model,
|
||||
user=user,
|
||||
optional_params=optional_rerank_params,
|
||||
optional_params=dict(optional_rerank_params),
|
||||
litellm_params={
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"proxy_server_request": proxy_server_request,
|
||||
|
|
|
@ -1592,6 +1592,7 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
|
|||
|
||||
all_litellm_params = [
|
||||
"metadata",
|
||||
"litellm_metadata",
|
||||
"litellm_trace_id",
|
||||
"tags",
|
||||
"acompletion",
|
||||
|
|
|
@ -1203,11 +1203,18 @@ def client(original_function): # noqa: PLR0915
|
|||
return wrapper
|
||||
|
||||
|
||||
def _is_async_request(kwargs: Optional[dict]) -> bool:
|
||||
def _is_async_request(
|
||||
kwargs: Optional[dict],
|
||||
is_pass_through: bool = False,
|
||||
) -> bool:
|
||||
"""
|
||||
Returns True if the call type is an internal async request.
|
||||
|
||||
eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
|
||||
|
||||
Args:
|
||||
kwargs (dict): The kwargs passed to the litellm function
|
||||
is_pass_through (bool): Whether the call is a pass-through call. By default all pass through calls are async.
|
||||
"""
|
||||
if kwargs is None:
|
||||
return False
|
||||
|
@ -1221,6 +1228,7 @@ def _is_async_request(kwargs: Optional[dict]) -> bool:
|
|||
or kwargs.get("arerank", False) is True
|
||||
or kwargs.get("_arealtime", False) is True
|
||||
or kwargs.get("acreate_batch", False) is True
|
||||
or is_pass_through is True
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
@ -1927,7 +1935,7 @@ def register_model(model_cost: Union[str, dict]): # noqa: PLR0915
|
|||
|
||||
|
||||
def get_litellm_params(
|
||||
api_key=None,
|
||||
api_key: Optional[str] = None,
|
||||
force_timeout=600,
|
||||
azure=False,
|
||||
logger_fn=None,
|
||||
|
@ -1935,12 +1943,12 @@ def get_litellm_params(
|
|||
hugging_face=False,
|
||||
replicate=False,
|
||||
together_ai=False,
|
||||
custom_llm_provider=None,
|
||||
api_base=None,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
litellm_call_id=None,
|
||||
model_alias_map=None,
|
||||
completion_call_id=None,
|
||||
metadata=None,
|
||||
metadata: Optional[dict] = None,
|
||||
model_info=None,
|
||||
proxy_server_request=None,
|
||||
acompletion=None,
|
||||
|
@ -1954,10 +1962,11 @@ def get_litellm_params(
|
|||
text_completion=None,
|
||||
azure_ad_token_provider=None,
|
||||
user_continue_message=None,
|
||||
base_model=None,
|
||||
litellm_trace_id=None,
|
||||
base_model: Optional[str] = None,
|
||||
litellm_trace_id: Optional[str] = None,
|
||||
hf_model_name: Optional[str] = None,
|
||||
custom_prompt_dict: Optional[dict] = None,
|
||||
litellm_metadata: Optional[dict] = None,
|
||||
):
|
||||
litellm_params = {
|
||||
"acompletion": acompletion,
|
||||
|
@ -1989,8 +1998,8 @@ def get_litellm_params(
|
|||
"litellm_trace_id": litellm_trace_id,
|
||||
"hf_model_name": hf_model_name,
|
||||
"custom_prompt_dict": custom_prompt_dict,
|
||||
"litellm_metadata": litellm_metadata,
|
||||
}
|
||||
|
||||
return litellm_params
|
||||
|
||||
|
||||
|
|
|
@ -12,6 +12,7 @@ load_dotenv()
|
|||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
|
||||
import logging
|
||||
import time
|
||||
|
||||
|
@ -191,12 +192,18 @@ async def test_async_create_batch(provider):
|
|||
batch_input_file_id is not None
|
||||
), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
|
||||
|
||||
extra_metadata_field = {
|
||||
"user_api_key_alias": "special_api_key_alias",
|
||||
"user_api_key_team_alias": "special_team_alias",
|
||||
}
|
||||
create_batch_response = await litellm.acreate_batch(
|
||||
completion_window="24h",
|
||||
endpoint="/v1/chat/completions",
|
||||
input_file_id=batch_input_file_id,
|
||||
custom_llm_provider=provider,
|
||||
metadata={"key1": "value1", "key2": "value2"},
|
||||
# litellm specific param - used for logging metadata on logging callback
|
||||
litellm_metadata=extra_metadata_field,
|
||||
)
|
||||
|
||||
print("response from litellm.create_batch=", create_batch_response)
|
||||
|
@ -215,6 +222,18 @@ async def test_async_create_batch(provider):
|
|||
await asyncio.sleep(6)
|
||||
# Assert that the create batch event is logged on CustomLogger
|
||||
assert custom_logger.standard_logging_object is not None
|
||||
print(
|
||||
"standard_logging_object=",
|
||||
json.dumps(custom_logger.standard_logging_object, indent=4, default=str),
|
||||
)
|
||||
assert (
|
||||
custom_logger.standard_logging_object["metadata"]["user_api_key_alias"]
|
||||
== extra_metadata_field["user_api_key_alias"]
|
||||
)
|
||||
assert (
|
||||
custom_logger.standard_logging_object["metadata"]["user_api_key_team_alias"]
|
||||
== extra_metadata_field["user_api_key_team_alias"]
|
||||
)
|
||||
|
||||
retrieved_batch = await litellm.aretrieve_batch(
|
||||
batch_id=create_batch_response.id, custom_llm_provider=provider
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue