Litellm dev 03 04 2025 p3 (#8997)

* fix(core_helpers.py): handle litellm_metadata instead of 'metadata'

* feat(batches/): ensure batches logs are written to db

makes batches response dict compatible

* fix(cost_calculator.py): handle batch response being a dictionary

* fix(batches/main.py): modify retrieve endpoints to use @client decorator

enables logging to work on retrieve call

* fix(batches/main.py): fix retrieve batch response type to be 'dict' compatible

* fix(spend_tracking_utils.py): send unique uuid for retrieve batch call type

create batch and retrieve batch share the same id

* fix(spend_tracking_utils.py): prevent duplicate retrieve batch calls from being double counted

* refactor(batches/): refactor cost tracking for batches - do it on retrieve, and within the established litellm_logging pipeline

ensures cost is always logged to db

* fix: fix linting errors

* fix: fix linting error
This commit is contained in:
Krish Dholakia 2025-03-04 21:58:03 -08:00 committed by GitHub
parent f2a9d67e05
commit b43b8dc21c
17 changed files with 314 additions and 219 deletions

View file

@ -31,10 +31,9 @@ from litellm.types.llms.openai import (
RetrieveBatchRequest,
)
from litellm.types.router import GenericLiteLLMParams
from litellm.types.utils import LiteLLMBatch
from litellm.utils import client, get_litellm_params, supports_httpx_timeout
from .batch_utils import batches_async_logging
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
azure_batches_instance = AzureBatchesAPI()
@ -85,17 +84,6 @@ async def acreate_batch(
else:
response = init_response
# Start async logging job
if response is not None:
asyncio.create_task(
batches_async_logging(
logging_obj=kwargs.get("litellm_logging_obj", None),
batch_id=response.id,
custom_llm_provider=custom_llm_provider,
**kwargs,
)
)
return response
except Exception as e:
raise e
@ -111,7 +99,7 @@ def create_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
"""
Creates and executes a batch from an uploaded file of request
@ -119,21 +107,26 @@ def create_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_call_id = kwargs.get("litellm_call_id", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
_is_async = kwargs.pop("acreate_batch", False) is True
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
litellm_call_id=kwargs.get("litellm_call_id", None),
litellm_trace_id=kwargs.get("litellm_trace_id"),
litellm_metadata=kwargs.get("litellm_metadata"),
)
litellm_logging_obj.update_environment_variables(
model=None,
user=None,
optional_params=optional_params.model_dump(),
litellm_params=litellm_params,
litellm_params={
"litellm_call_id": litellm_call_id,
"proxy_server_request": proxy_server_request,
"model_info": model_info,
"metadata": metadata,
"preset_cache_key": None,
"stream_response": {},
**optional_params.model_dump(exclude_unset=True),
},
custom_llm_provider=custom_llm_provider,
)
@ -261,7 +254,7 @@ def create_batch(
response=httpx.Response(
status_code=400,
content="Unsupported provider",
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
request=httpx.Request(method="create_batch", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
@ -269,6 +262,7 @@ def create_batch(
raise e
@client
async def aretrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@ -276,7 +270,7 @@ async def aretrieve_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
) -> LiteLLMBatch:
"""
Async: Retrieves a batch.
@ -310,6 +304,7 @@ async def aretrieve_batch(
raise e
@client
def retrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
@ -317,7 +312,7 @@ def retrieve_batch(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Union[Batch, Coroutine[Any, Any, Batch]]:
) -> Union[LiteLLMBatch, Coroutine[Any, Any, LiteLLMBatch]]:
"""
Retrieves a batch.
@ -325,9 +320,23 @@ def retrieve_batch(
"""
try:
optional_params = GenericLiteLLMParams(**kwargs)
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj", None)
### TIMEOUT LOGIC ###
timeout = optional_params.timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default
litellm_params = get_litellm_params(
custom_llm_provider=custom_llm_provider,
litellm_call_id=kwargs.get("litellm_call_id", None),
litellm_trace_id=kwargs.get("litellm_trace_id"),
litellm_metadata=kwargs.get("litellm_metadata"),
)
litellm_logging_obj.update_environment_variables(
model=None,
user=None,
optional_params=optional_params.model_dump(),
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if (
timeout is not None