(feat) /batches - track user_api_key_alias, user_api_key_team_alias etc for /batch requests (#7401)

* run azure testing on ci/cd

* update docs on azure batches endpoints

* add input azure.jsonl

* refactor - use separate file for batches endpoints

* fixes for passing custom llm provider to /batch endpoints

* pass custom llm provider to files endpoints

* update azure batches doc

* add info for azure batches api

* update batches endpoints

* use simple helper for raising proxy exception

* update config.yml

* fix imports

* add type hints to get_litellm_params

* update get_litellm_params

* update get_litellm_params

* update get slp

* QOL - stop double logging create batch operations on custom loggers

* reuse slp from og event

* _create_standard_logging_object_for_completed_batch

* fix linting errors

* reduce num changes in PR

* update BATCH_STATUS_POLL_MAX_ATTEMPTS
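
For context on the flow these commits touch, below is a minimal sketch of driving the proxy's /files and /batches endpoints with the OpenAI SDK. The base URL, API key, input filename, and the extra_body provider hint are illustrative assumptions, not taken from this diff; the Azure batches docs updated in this PR describe the supported routing.

# Sketch only: assumes a LiteLLM proxy on http://localhost:4000 with an
# Azure files/batches deployment configured. The custom_llm_provider hint
# passed via extra_body is an assumption for illustration, not confirmed here.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:4000", api_key="sk-1234")

# Upload the batch input file (e.g. an azure.jsonl like the one added in this PR).
batch_file = client.files.create(
    file=open("azure.jsonl", "rb"),
    purpose="batch",
    extra_body={"custom_llm_provider": "azure"},  # assumption, see docs
)

# Create the batch. user_api_key_alias / user_api_key_team_alias are attached
# server-side from the virtual key, which is what this PR starts tracking.
batch = client.batches.create(
    input_file_id=batch_file.id,
    endpoint="/v1/chat/completions",
    completion_window="24h",
    extra_body={"custom_llm_provider": "azure"},  # assumption, see docs
)
print(batch.id, batch.status)
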
Author: Ishaan Jaff (committed by GitHub)
Date: 2024-12-24 17:44:28 -08:00
Parent: 47e12802df
Commit: 08a4c72692
9 changed files with 72 additions and 29 deletions


@@ -1203,11 +1203,18 @@ def client(original_function):  # noqa: PLR0915
     return wrapper


-def _is_async_request(kwargs: Optional[dict]) -> bool:
+def _is_async_request(
+    kwargs: Optional[dict],
+    is_pass_through: bool = False,
+) -> bool:
     """
     Returns True if the call type is an internal async request.

     eg. litellm.acompletion, litellm.aimage_generation, litellm.acreate_batch, litellm._arealtime
+
+    Args:
+        kwargs (dict): The kwargs passed to the litellm function
+        is_pass_through (bool): Whether the call is a pass-through call. By default all pass through calls are async.
     """
     if kwargs is None:
         return False
@@ -1221,6 +1228,7 @@ def _is_async_request(kwargs: Optional[dict]) -> bool:
         or kwargs.get("arerank", False) is True
         or kwargs.get("_arealtime", False) is True
         or kwargs.get("acreate_batch", False) is True
+        or is_pass_through is True
     ):
         return True
     return False
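
To show what the new is_pass_through flag changes, here is an abridged, self-contained restatement of the helper based only on the hunks above (the full list of a* flags checked in the real function is longer):

# Abridged sketch of _is_async_request after this change; only names visible
# in the diff above are used, and the remaining flag checks are elided.
from typing import Optional

def _is_async_request(
    kwargs: Optional[dict],
    is_pass_through: bool = False,
) -> bool:
    if kwargs is None:
        return False
    if (
        kwargs.get("acreate_batch", False) is True
        or kwargs.get("_arealtime", False) is True
        or kwargs.get("arerank", False) is True
        or is_pass_through is True  # pass-through calls are treated as async
    ):
        return True
    return False

# A pass-through /batches call is classified as async even without an a* flag.
assert _is_async_request({"acreate_batch": True}) is True
assert _is_async_request({}, is_pass_through=True) is True
assert _is_async_request({}) is False
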
@@ -1927,7 +1935,7 @@ def register_model(model_cost: Union[str, dict]):  # noqa: PLR0915


 def get_litellm_params(
-    api_key=None,
+    api_key: Optional[str] = None,
     force_timeout=600,
     azure=False,
     logger_fn=None,
@@ -1935,12 +1943,12 @@ def get_litellm_params(
     hugging_face=False,
     replicate=False,
     together_ai=False,
-    custom_llm_provider=None,
-    api_base=None,
+    custom_llm_provider: Optional[str] = None,
+    api_base: Optional[str] = None,
     litellm_call_id=None,
     model_alias_map=None,
     completion_call_id=None,
-    metadata=None,
+    metadata: Optional[dict] = None,
     model_info=None,
     proxy_server_request=None,
     acompletion=None,
@@ -1954,10 +1962,11 @@ def get_litellm_params(
     text_completion=None,
     azure_ad_token_provider=None,
     user_continue_message=None,
-    base_model=None,
-    litellm_trace_id=None,
+    base_model: Optional[str] = None,
+    litellm_trace_id: Optional[str] = None,
     hf_model_name: Optional[str] = None,
     custom_prompt_dict: Optional[dict] = None,
+    litellm_metadata: Optional[dict] = None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -1989,8 +1998,8 @@ def get_litellm_params(
         "litellm_trace_id": litellm_trace_id,
         "hf_model_name": hf_model_name,
         "custom_prompt_dict": custom_prompt_dict,
+        "litellm_metadata": litellm_metadata,
     }
     return litellm_params
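
As a rough usage note for the typed signature and the new litellm_metadata field: the sketch below calls the helper with only keyword names that appear in the hunks above and reads back the keys the dict is shown to carry; the import path and the dict keys other than "litellm_metadata" are assumptions.

# Sketch: thread provider + proxy metadata through get_litellm_params so the
# batch/file endpoints and custom loggers can read them downstream.
from litellm.utils import get_litellm_params  # assumed import path

params = get_litellm_params(
    custom_llm_provider="azure",
    api_base="https://my-azure-resource.openai.azure.com",  # placeholder
    litellm_metadata={
        "user_api_key_alias": "my-key-alias",        # tracked per this PR
        "user_api_key_team_alias": "my-team-alias",  # tracked per this PR
    },
)

# "litellm_metadata" is confirmed by the hunk above; the other keys are
# assumed to mirror the parameter names.
print(params["litellm_metadata"])
print(params.get("custom_llm_provider"))
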