(feat) Add cost tracking for OpenAI /batches requests (#7384)

* add basic logging for `create_batch`

* add create_batch as a call type

* add basic dd logging for batches

* basic batch creation logging on DD

* batch endpoints add cost calc

* fix batches_async_logging

* separate folder for batches testing

* new job for batches tests

* test batches logging

* fix validation logic

* add vertex_batch_completions.jsonl

* test test_async_create_batch

* test_async_create_batch

* update tests

* test_completion_with_no_model

* remove dead code

* update load_vertex_ai_credentials

* test_avertex_batch_prediction

* update get async httpx client

* fix get_async_httpx_client

* update test_avertex_batch_prediction

* fix batches testing config.yaml

* add google deps

* fix vertex files handler
Ishaan Jaff 2024-12-23 17:47:26 -08:00 committed by GitHub
parent 9d66976162
commit 00544b97c8
13 changed files with 649 additions and 78 deletions
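
For context, the user-facing flow this commit instruments looks roughly like the sketch below. It is a minimal, hedged example: the input file ID is a placeholder, and only acreate_batch/aretrieve_batch and the new background logging behavior come from the diff that follows.

    import asyncio
    import litellm

    async def main():
        # Create a batch; with this commit, acreate_batch also schedules a
        # background task that logs the batch's cost once it completes.
        batch = await litellm.acreate_batch(
            completion_window="24h",
            endpoint="/v1/chat/completions",
            input_file_id="file-abc123",  # placeholder file ID
            custom_llm_provider="openai",
        )

        # Retrieve it later; "vertex_ai" is now also an accepted provider.
        retrieved = await litellm.aretrieve_batch(
            batch_id=batch.id,
            custom_llm_provider="openai",
        )
        print(retrieved.status)

    asyncio.run(main())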


@@ -27,6 +27,8 @@ from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRe
 from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import client, supports_httpx_timeout
+from .batch_utils import batches_async_logging
+
 
 ####### ENVIRONMENT VARIABLES ###################
 openai_batches_instance = OpenAIBatchesAPI()
 azure_batches_instance = AzureBatchesAPI()
@@ -71,10 +73,22 @@ async def acreate_batch(
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
         init_response = await loop.run_in_executor(None, func_with_context)
         if asyncio.iscoroutine(init_response):
             response = await init_response
         else:
-            response = init_response  # type: ignore
+            response = init_response
+
+        # Start async logging job
+        if response is not None:
+            asyncio.create_task(
+                batches_async_logging(
+                    logging_obj=kwargs.get("litellm_logging_obj", None),
+                    batch_id=response.id,
+                    custom_llm_provider=custom_llm_provider,
+                    **kwargs,
+                )
+            )
+
 
         return response
     except Exception as e:
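
The body of batch_utils.batches_async_logging is not part of this hunk. Conceptually it has to poll the batch until it reaches a terminal state and then hand the result to the logging pipeline for cost calculation. The sketch below illustrates only that shape; the polling interval, the set of terminal statuses, and the final logging step are all assumptions, not the commit's actual implementation.

    import asyncio
    import litellm

    async def batches_async_logging(
        logging_obj=None,
        batch_id: str = "",
        custom_llm_provider: str = "openai",
        **kwargs,
    ):
        # Poll the batch until it reaches a terminal state (assumed set).
        while True:
            batch = await litellm.aretrieve_batch(
                batch_id=batch_id, custom_llm_provider=custom_llm_provider
            )
            if batch.status in ("completed", "failed", "cancelled", "expired"):
                break
            await asyncio.sleep(30)  # assumed polling interval

        if batch.status == "completed" and logging_obj is not None:
            # The real helper would derive usage/cost from the batch output
            # and emit it through logging_obj; this stand-in just reports.
            print(f"batch {batch_id} completed; cost logging would run here")
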
@@ -238,7 +252,7 @@ def create_batch(
 
 async def aretrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
@@ -279,7 +293,7 @@ async def aretrieve_batch(
 
 def retrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
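
With "vertex_ai" added to the accepted Literal values, retrieval can now be routed to Vertex AI batch prediction jobs. A minimal usage sketch (the batch ID is a placeholder, not a value from this commit):

    import litellm

    batch = litellm.retrieve_batch(
        batch_id="batch-id-from-create",  # placeholder
        custom_llm_provider="vertex_ai",
    )
    print(batch.status)
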
@@ -552,7 +566,6 @@ def list_batches(
         return response
     except Exception as e:
         raise e
-    pass
 
 
 def cancel_batch():