(feat) /batches - track user_api_key_alias, user_api_key_team_alias etc for /batch requests (#7401)

* run azure testing on ci/cd
* update docs on azure batches endpoints
* add input azure.jsonl
* refactor - use separate file for batches endpoints
* fixes for passing custom llm provider to /batch endpoints
* pass custom llm provider to files endpoints
* update azure batches doc
* add info for azure batches api
* update batches endpoints
* use simple helper for raising proxy exception
* update config.yml
* fix imports
* add type hints to get_litellm_params
* update get_litellm_params
* update get_litellm_params
* update get slp
* QOL - stop double logging a create batch operations on custom loggers
* re use slp from og event
* _create_standard_logging_object_for_completed_batch
* fix linting errors
* reduce num changes in PR
* update BATCH_STATUS_POLL_MAX_ATTEMPTS
parent 0627450808
commit e98f1d16fd

9 changed files with 72 additions and 29 deletions
@@ -277,17 +277,8 @@ def _create_standard_logging_object_for_completed_batch(
     """
     Create a standard logging object for a completed batch
     """
-    from litellm.litellm_core_utils.litellm_logging import (
-        get_standard_logging_object_payload,
-    )
-
-    standard_logging_object = get_standard_logging_object_payload(
-        kwargs=kwargs,
-        init_response_obj=None,
-        start_time=start_time,
-        end_time=end_time,
-        logging_obj=logging_obj,
-        status="success",
-    )
+    standard_logging_object = logging_obj.model_call_details.get(
+        "standard_logging_object", None
+    )
 
     if standard_logging_object is None:
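
This is the double-logging fix named in the commit message: instead of building a second standard logging payload when a batch completes, the handler reuses the one the original create-batch event already stored on the logging object. A minimal sketch of the lookup pattern, using a toy stand-in rather than litellm's real Logging class:

# Toy stand-in for litellm's Logging object (illustrative, not the real class).
class ToyLoggingObj:
    def __init__(self):
        # populated when the original create-batch event was logged
        self.model_call_details = {
            "standard_logging_object": {"id": "batch_abc", "status": "success"}
        }

logging_obj = ToyLoggingObj()
# reuse the payload from the original event instead of rebuilding (and re-logging) it
slo = logging_obj.model_call_details.get("standard_logging_object", None)
if slo is None:
    raise ValueError("no standard logging object found for completed batch")
print(slo["id"])  # batch_abc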
@@ -19,13 +19,14 @@ from typing import Any, Coroutine, Dict, Literal, Optional, Union
 import httpx
 
 import litellm
+from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
 from litellm.llms.azure.azure import AzureBatchesAPI
 from litellm.llms.openai.openai import OpenAIBatchesAPI
 from litellm.llms.vertex_ai.batches.handler import VertexAIBatchPrediction
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRequest
 from litellm.types.router import GenericLiteLLMParams
-from litellm.utils import client, supports_httpx_timeout
+from litellm.utils import client, get_litellm_params, supports_httpx_timeout
 
 from .batch_utils import batches_async_logging
 
@@ -114,9 +115,22 @@ def create_batch(
     try:
         optional_params = GenericLiteLLMParams(**kwargs)
         _is_async = kwargs.pop("acreate_batch", False) is True
+        litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
         ### TIMEOUT LOGIC ###
         timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
-        # set timeout for 10 minutes by default
+        litellm_params = get_litellm_params(
+            custom_llm_provider=custom_llm_provider,
+            litellm_call_id=kwargs.get("litellm_call_id", None),
+            litellm_trace_id=kwargs.get("litellm_trace_id"),
+            litellm_metadata=kwargs.get("litellm_metadata"),
+        )
+        litellm_logging_obj.update_environment_variables(
+            model=None,
+            user=None,
+            optional_params=optional_params.model_dump(),
+            litellm_params=litellm_params,
+            custom_llm_provider=custom_llm_provider,
+        )
 
         if (
             timeout is not None
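
Taken together, create_batch now builds litellm_params via get_litellm_params and pushes them into the logging object before the provider call, so litellm_metadata reaches logging callbacks. A hedged usage sketch, mirroring the test further down (the file id is a placeholder and OpenAI credentials are assumed to be configured):

import asyncio
import litellm

async def main():
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id="file-abc123",  # placeholder: a real uploaded batch input file id
        custom_llm_provider="openai",
        metadata={"key1": "value1"},  # forwarded to the provider
        # litellm-specific param - surfaced to logging callbacks, not sent to the provider
        litellm_metadata={"user_api_key_alias": "special_api_key_alias"},
    )
    print(batch.id)

# asyncio.run(main())  # requires OPENAI_API_KEY and a valid input file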
@@ -93,5 +93,5 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
     "optimize-prompt/",
 ]
 
-BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
-BATCH_STATUS_POLL_MAX_ATTEMPTS = 10
+BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600  # 1 hour
+BATCH_STATUS_POLL_MAX_ATTEMPTS = 24  # for 24 hours
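
The arithmetic behind the new defaults: 24 attempts spaced 3600 s apart cover 24 x 3600 s = 86,400 s, i.e. the full 24-hour completion window batches are created with, whereas the old 10 x 10 s only covered 100 s. A minimal sketch of the polling loop these constants imply (the status callable is a caller-supplied assumption, not litellm's internal API):

import asyncio
from typing import Awaitable, Callable

BATCH_STATUS_POLL_INTERVAL_SECONDS = 3600  # 1 hour
BATCH_STATUS_POLL_MAX_ATTEMPTS = 24  # for 24 hours

async def wait_for_batch(get_status: Callable[[], Awaitable[str]]) -> str:
    # poll until the batch reaches a terminal state, or give up after 24 attempts
    for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
        status = await get_status()
        if status in ("completed", "failed", "cancelled", "expired"):
            return status
        await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)
    raise TimeoutError("batch did not reach a terminal state within 24 hours")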
@@ -388,10 +388,16 @@ class Logging(LiteLLMLoggingBaseClass):
         return standard_callback_dynamic_params
 
     def update_environment_variables(
-        self, model, user, optional_params, litellm_params, **additional_params
+        self,
+        litellm_params: Dict,
+        optional_params: Dict,
+        model: Optional[str] = None,
+        user: Optional[str] = None,
+        **additional_params,
     ):
         self.optional_params = optional_params
-        self.model = model
+        if model is not None:
+            self.model = model
         self.user = user
         self.litellm_params = scrub_sensitive_keys_in_metadata(litellm_params)
         self.logger_fn = litellm_params.get("logger_fn", None)
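
The signature change exists because /batches calls log without a single model: model and user become optional, and self.model is only overwritten when a model is actually supplied. A toy sketch of why the `if model is not None` guard matters:

from typing import Dict, Optional

class ToyLogging:
    """Illustrative stand-in, not litellm's Logging class."""

    def __init__(self):
        self.model = "gpt-4o"  # set by an earlier, model-specific event

    def update_environment_variables(
        self,
        litellm_params: Dict,
        optional_params: Dict,
        model: Optional[str] = None,
        user: Optional[str] = None,
    ):
        self.optional_params = optional_params
        if model is not None:  # a batch-level update passes model=None
            self.model = model
        self.user = user
        self.litellm_params = litellm_params

obj = ToyLogging()
obj.update_environment_variables(litellm_params={}, optional_params={}, model=None)
print(obj.model)  # still "gpt-4o" - not clobbered by the batch-level update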
@@ -2885,9 +2891,11 @@ def get_standard_logging_object_payload(
         litellm_params = kwargs.get("litellm_params", {})
         proxy_server_request = litellm_params.get("proxy_server_request") or {}
         end_user_id = proxy_server_request.get("body", {}).get("user", None)
-        metadata = (
-            litellm_params.get("metadata", {}) or {}
-        )  # if litellm_params['metadata'] == None
+        metadata: dict = (
+            litellm_params.get("litellm_metadata")
+            or litellm_params.get("metadata", None)
+            or {}
+        )
         completion_start_time = kwargs.get("completion_start_time", end_time)
         call_type = kwargs.get("call_type")
         cache_hit = kwargs.get("cache_hit", False)
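
The payload builder now prefers litellm_metadata over the provider-bound metadata, falling back to an empty dict, since `or` skips None and empty values. A small self-contained sketch of the precedence:

def resolve_metadata(litellm_params: dict) -> dict:
    # mirrors the or-chain above: litellm_metadata wins, then metadata, then {}
    return (
        litellm_params.get("litellm_metadata")
        or litellm_params.get("metadata", None)
        or {}
    )

print(resolve_metadata({"litellm_metadata": {"a": 1}, "metadata": {"b": 2}}))  # {'a': 1}
print(resolve_metadata({"metadata": {"b": 2}}))  # {'b': 2}
print(resolve_metadata({}))  # {}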
@@ -1140,6 +1140,7 @@ def completion(  # type: ignore  # noqa: PLR0915
         litellm_trace_id=kwargs.get("litellm_trace_id"),
         hf_model_name=hf_model_name,
         custom_prompt_dict=custom_prompt_dict,
+        litellm_metadata=kwargs.get("litellm_metadata"),
     )
     logging.update_environment_variables(
         model=model,
@@ -138,7 +138,7 @@ def rerank(  # noqa: PLR0915
     litellm_logging_obj.update_environment_variables(
         model=model,
         user=user,
-        optional_params=optional_rerank_params,
+        optional_params=dict(optional_rerank_params),
         litellm_params={
             "litellm_call_id": litellm_call_id,
             "proxy_server_request": proxy_server_request,
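
The rerank change is a cast to satisfy the new `optional_params: Dict` annotation: assuming optional_rerank_params is a TypedDict-style mapping, dict(...) copies it into a plain dict. A tiny sketch under that assumption (the field names are hypothetical, not litellm's actual definition):

from typing import TypedDict

class OptionalRerankParamsSketch(TypedDict, total=False):
    # hypothetical fields for illustration only
    top_n: int
    return_documents: bool

params: OptionalRerankParamsSketch = {"top_n": 3}
plain: dict = dict(params)  # a plain dict copy, matching the Dict annotation
print(plain)  # {'top_n': 3}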
@@ -1592,6 +1592,7 @@ class StandardCallbackDynamicParams(TypedDict, total=False):
 
 all_litellm_params = [
     "metadata",
+    "litellm_metadata",
     "litellm_trace_id",
     "tags",
     "acompletion",
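
Registering the new key in all_litellm_params marks it as litellm-internal, so it is not forwarded to provider SDKs as a request parameter. An illustrative filter under that assumption (the helper name is made up, and the list is truncated):

all_litellm_params = ["metadata", "litellm_metadata", "litellm_trace_id", "tags", "acompletion"]

def split_kwargs(kwargs: dict) -> tuple:
    # hypothetical helper: separate litellm-internal params from provider params
    internal = {k: v for k, v in kwargs.items() if k in all_litellm_params}
    provider = {k: v for k, v in kwargs.items() if k not in all_litellm_params}
    return internal, provider

internal, provider = split_kwargs({"litellm_metadata": {"x": 1}, "temperature": 0.2})
print(internal)  # {'litellm_metadata': {'x': 1}}
print(provider)  # {'temperature': 0.2}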
@@ -1203,11 +1203,18 @@ def client(original_function):  # noqa: PLR0915
     return wrapper
 
 
-def _is_async_request(kwargs: Optional[dict]) -> bool:
+def _is_async_request(
+    kwargs: Optional[dict],
+    is_pass_through: bool = False,
+) -> bool:
     """
     Returns True if the call type is an internal async request.
 
     eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
 
+    Args:
+        kwargs (dict): The kwargs passed to the litellm function
+        is_pass_through (bool): Whether the call is a pass-through call. By default all pass through calls are async.
     """
     if kwargs is None:
         return False
@@ -1221,6 +1228,7 @@ def _is_async_request(kwargs: Optional[dict]) -> bool:
         or kwargs.get("arerank", False) is True
         or kwargs.get("_arealtime", False) is True
         or kwargs.get("acreate_batch", False) is True
+        or is_pass_through is True
     ):
         return True
     return False
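
Condensing the branch above: any recognized async flag in kwargs, or the new is_pass_through flag, marks the call async. A compact re-statement for illustration (not the exhaustive flag list litellm checks):

from typing import Optional

ASYNC_FLAGS = ("acompletion", "aimage_generation", "arerank", "_arealtime", "acreate_batch")

def is_async_request_sketch(kwargs: Optional[dict], is_pass_through: bool = False) -> bool:
    if kwargs is None:
        return False
    return any(kwargs.get(f, False) is True for f in ASYNC_FLAGS) or is_pass_through is True

print(is_async_request_sketch({"acreate_batch": True}))   # True
print(is_async_request_sketch({}, is_pass_through=True))  # True
print(is_async_request_sketch(None))                      # False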
@@ -1927,7 +1935,7 @@ def register_model(model_cost: Union[str, dict]):  # noqa: PLR0915
 
 
 def get_litellm_params(
-    api_key=None,
+    api_key: Optional[str] = None,
     force_timeout=600,
     azure=False,
     logger_fn=None,
@@ -1935,12 +1943,12 @@ def get_litellm_params(
     hugging_face=False,
     replicate=False,
     together_ai=False,
-    custom_llm_provider=None,
-    api_base=None,
+    custom_llm_provider: Optional[str] = None,
+    api_base: Optional[str] = None,
     litellm_call_id=None,
     model_alias_map=None,
     completion_call_id=None,
-    metadata=None,
+    metadata: Optional[dict] = None,
     model_info=None,
     proxy_server_request=None,
     acompletion=None,
@@ -1954,10 +1962,11 @@ def get_litellm_params(
     text_completion=None,
     azure_ad_token_provider=None,
     user_continue_message=None,
-    base_model=None,
-    litellm_trace_id=None,
+    base_model: Optional[str] = None,
+    litellm_trace_id: Optional[str] = None,
     hf_model_name: Optional[str] = None,
     custom_prompt_dict: Optional[dict] = None,
+    litellm_metadata: Optional[dict] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -1989,8 +1998,8 @@ def get_litellm_params(
         "litellm_trace_id": litellm_trace_id,
         "hf_model_name": hf_model_name,
         "custom_prompt_dict": custom_prompt_dict,
+        "litellm_metadata": litellm_metadata,
     }
 
     return litellm_params
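
With the new parameter in place, callers like create_batch get litellm_metadata back as a key of the litellm_params dict that logging callbacks receive. A quick usage sketch against the changed function, with all other keys left at their defaults:

from litellm.utils import get_litellm_params

params = get_litellm_params(
    custom_llm_provider="openai",
    litellm_metadata={"user_api_key_alias": "special_api_key_alias"},
)
print(params["litellm_metadata"])  # {'user_api_key_alias': 'special_api_key_alias'}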
@@ -12,6 +12,7 @@ load_dotenv()
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system-path
+
 import logging
 import time
 
@@ -191,12 +192,18 @@ async def test_async_create_batch(provider):
         batch_input_file_id is not None
     ), "Failed to create file, expected a non null file_id but got {batch_input_file_id}"
 
+    extra_metadata_field = {
+        "user_api_key_alias": "special_api_key_alias",
+        "user_api_key_team_alias": "special_team_alias",
+    }
     create_batch_response = await litellm.acreate_batch(
         completion_window="24h",
         endpoint="/v1/chat/completions",
         input_file_id=batch_input_file_id,
         custom_llm_provider=provider,
         metadata={"key1": "value1", "key2": "value2"},
+        # litellm specific param - used for logging metadata on logging callback
+        litellm_metadata=extra_metadata_field,
     )
 
     print("response from litellm.create_batch=", create_batch_response)
@@ -215,6 +222,18 @@ async def test_async_create_batch(provider):
     await asyncio.sleep(6)
     # Assert that the create batch event is logged on CustomLogger
     assert custom_logger.standard_logging_object is not None
+    print(
+        "standard_logging_object=",
+        json.dumps(custom_logger.standard_logging_object, indent=4, default=str),
+    )
+    assert (
+        custom_logger.standard_logging_object["metadata"]["user_api_key_alias"]
+        == extra_metadata_field["user_api_key_alias"]
+    )
+    assert (
+        custom_logger.standard_logging_object["metadata"]["user_api_key_team_alias"]
+        == extra_metadata_field["user_api_key_team_alias"]
+    )
 
     retrieved_batch = await litellm.aretrieve_batch(
         batch_id=create_batch_response.id, custom_llm_provider=provider