mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 02:34:29 +00:00
(feat) Add cost tracking for /batches requests OpenAI (#7384)
* add basic logging for create`batch` * add create_batch as a call type * add basic dd logging for batches * basic batch creation logging on DD * batch endpoints add cost calc * fix batches_async_logging * separate folder for batches testing * new job for batches tests * test batches logging * fix validation logic * add vertex_batch_completions.jsonl * test test_async_create_batch * test_async_create_batch * update tests * test_completion_with_no_model * remove dead code * update load_vertex_ai_credentials * test_avertex_batch_prediction * update get async httpx client * fix get_async_httpx_client * update test_avertex_batch_prediction * fix batches testing config.yaml * add google deps * fix vertex files handler
This commit is contained in:
parent
87f19d6f13
commit
05b0d2026f
13 changed files with 649 additions and 78 deletions
|
@ -626,6 +626,50 @@ jobs:
|
|||
paths:
|
||||
- llm_translation_coverage.xml
|
||||
- llm_translation_coverage
|
||||
batches_testing:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
auth:
|
||||
username: ${DOCKERHUB_USERNAME}
|
||||
password: ${DOCKERHUB_PASSWORD}
|
||||
working_directory: ~/project
|
||||
|
||||
steps:
|
||||
- checkout
|
||||
- run:
|
||||
name: Install Dependencies
|
||||
command: |
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install -r requirements.txt
|
||||
pip install "respx==0.21.1"
|
||||
pip install "pytest==7.3.1"
|
||||
pip install "pytest-retry==1.6.3"
|
||||
pip install "pytest-asyncio==0.21.1"
|
||||
pip install "pytest-cov==5.0.0"
|
||||
pip install "google-generativeai==0.3.2"
|
||||
pip install "google-cloud-aiplatform==1.43.0"
|
||||
# Run pytest and generate JUnit XML report
|
||||
- run:
|
||||
name: Run tests
|
||||
command: |
|
||||
pwd
|
||||
ls
|
||||
python -m pytest -vv tests/batches_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
|
||||
no_output_timeout: 120m
|
||||
- run:
|
||||
name: Rename the coverage files
|
||||
command: |
|
||||
mv coverage.xml batches_coverage.xml
|
||||
mv .coverage batches_coverage
|
||||
|
||||
# Store test results
|
||||
- store_test_results:
|
||||
path: test-results
|
||||
- persist_to_workspace:
|
||||
root: .
|
||||
paths:
|
||||
- batches_coverage.xml
|
||||
- batches_coverage
|
||||
pass_through_unit_testing:
|
||||
docker:
|
||||
- image: cimg/python:3.11
|
||||
|
@ -1417,7 +1461,7 @@ jobs:
|
|||
python -m venv venv
|
||||
. venv/bin/activate
|
||||
pip install coverage
|
||||
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage
|
||||
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage
|
||||
coverage xml
|
||||
- codecov/upload:
|
||||
file: ./coverage.xml
|
||||
|
@ -1714,6 +1758,12 @@ workflows:
|
|||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- batches_testing:
|
||||
filters:
|
||||
branches:
|
||||
only:
|
||||
- main
|
||||
- /litellm_.*/
|
||||
- pass_through_unit_testing:
|
||||
filters:
|
||||
branches:
|
||||
|
@ -1735,6 +1785,7 @@ workflows:
|
|||
- upload-coverage:
|
||||
requires:
|
||||
- llm_translation_testing
|
||||
- batches_testing
|
||||
- pass_through_unit_testing
|
||||
- image_gen_testing
|
||||
- logging_testing
|
||||
|
@ -1783,6 +1834,7 @@ workflows:
|
|||
- load_testing
|
||||
- test_bad_database_url
|
||||
- llm_translation_testing
|
||||
- batches_testing
|
||||
- pass_through_unit_testing
|
||||
- image_gen_testing
|
||||
- logging_testing
|
||||
|
|
302
litellm/batches/batch_utils.py
Normal file
302
litellm/batches/batch_utils.py
Normal file
|
@ -0,0 +1,302 @@
|
|||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import threading
|
||||
from typing import Any, List, Literal, Optional
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.constants import (
|
||||
BATCH_STATUS_POLL_INTERVAL_SECONDS,
|
||||
BATCH_STATUS_POLL_MAX_ATTEMPTS,
|
||||
)
|
||||
from litellm.files.main import afile_content
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.types.llms.openai import Batch
|
||||
from litellm.types.utils import StandardLoggingPayload, Usage
|
||||
|
||||
|
||||
async def batches_async_logging(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
logging_obj: Optional[LiteLLMLoggingObj] = None,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Async Job waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens
|
||||
|
||||
|
||||
Polls retrieve_batch until it returns a batch with status "completed" or "failed"
|
||||
"""
|
||||
from .main import aretrieve_batch
|
||||
|
||||
verbose_logger.debug(
|
||||
".....in _batches_async_logging... polling retrieve to get batch status"
|
||||
)
|
||||
if logging_obj is None:
|
||||
raise ValueError(
|
||||
"logging_obj is None cannot calculate cost / log batch creation event"
|
||||
)
|
||||
for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
|
||||
try:
|
||||
start_time = datetime.datetime.now()
|
||||
batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider)
|
||||
verbose_logger.debug(
|
||||
"in _batches_async_logging... batch status= %s", batch.status
|
||||
)
|
||||
|
||||
if batch.status == "completed":
|
||||
end_time = datetime.datetime.now()
|
||||
await _handle_completed_batch(
|
||||
batch=batch,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
logging_obj=logging_obj,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
**kwargs,
|
||||
)
|
||||
break
|
||||
elif batch.status == "failed":
|
||||
pass
|
||||
except Exception as e:
|
||||
verbose_logger.error("error in batches_async_logging", e)
|
||||
await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
async def _handle_completed_batch(
|
||||
batch: Batch,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Helper function to process a completed batch and handle logging"""
|
||||
# Get batch results
|
||||
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
|
||||
batch, custom_llm_provider
|
||||
)
|
||||
|
||||
# Calculate costs and usage
|
||||
batch_cost = await _batch_cost_calculator(
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
file_content_dictionary=file_content_dictionary,
|
||||
)
|
||||
batch_usage = _get_batch_job_total_usage_from_file_content(
|
||||
file_content_dictionary=file_content_dictionary,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
|
||||
# Handle logging
|
||||
await _log_completed_batch(
|
||||
logging_obj=logging_obj,
|
||||
batch_usage=batch_usage,
|
||||
batch_cost=batch_cost,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
async def _log_completed_batch(
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
batch_usage: Usage,
|
||||
batch_cost: float,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
**kwargs,
|
||||
) -> None:
|
||||
"""Helper function to handle all logging operations for a completed batch"""
|
||||
logging_obj.call_type = "batch_success"
|
||||
|
||||
standard_logging_object = _create_standard_logging_object_for_completed_batch(
|
||||
kwargs=kwargs,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=logging_obj,
|
||||
batch_usage_object=batch_usage,
|
||||
response_cost=batch_cost,
|
||||
)
|
||||
|
||||
logging_obj.model_call_details["standard_logging_object"] = standard_logging_object
|
||||
|
||||
# Launch async and sync logging handlers
|
||||
asyncio.create_task(
|
||||
logging_obj.async_success_handler(
|
||||
result=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
cache_hit=None,
|
||||
)
|
||||
)
|
||||
threading.Thread(
|
||||
target=logging_obj.success_handler,
|
||||
args=(None, start_time, end_time),
|
||||
).start()
|
||||
|
||||
|
||||
async def _batch_cost_calculator(
|
||||
file_content_dictionary: List[dict],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> float:
|
||||
"""
|
||||
Calculate the cost of a batch based on the output file id
|
||||
"""
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
raise ValueError("Vertex AI does not support file content retrieval")
|
||||
total_cost = _get_batch_job_cost_from_file_content(
|
||||
file_content_dictionary=file_content_dictionary,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
verbose_logger.debug("total_cost=%s", total_cost)
|
||||
return total_cost
|
||||
|
||||
|
||||
async def _get_batch_output_file_content_as_dictionary(
|
||||
batch: Batch,
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get the batch output file content as a list of dictionaries
|
||||
"""
|
||||
if custom_llm_provider == "vertex_ai":
|
||||
raise ValueError("Vertex AI does not support file content retrieval")
|
||||
|
||||
if batch.output_file_id is None:
|
||||
raise ValueError("Output file id is None cannot retrieve file content")
|
||||
|
||||
_file_content = await afile_content(
|
||||
file_id=batch.output_file_id,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
return _get_file_content_as_dictionary(_file_content.content)
|
||||
|
||||
|
||||
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
|
||||
"""
|
||||
Get the file content as a list of dictionaries from JSON Lines format
|
||||
"""
|
||||
try:
|
||||
_file_content_str = file_content.decode("utf-8")
|
||||
# Split by newlines and parse each line as a separate JSON object
|
||||
json_objects = []
|
||||
for line in _file_content_str.strip().split("\n"):
|
||||
if line: # Skip empty lines
|
||||
json_objects.append(json.loads(line))
|
||||
verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
|
||||
return json_objects
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def _get_batch_job_cost_from_file_content(
|
||||
file_content_dictionary: List[dict],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> float:
|
||||
"""
|
||||
Get the cost of a batch job from the file content
|
||||
"""
|
||||
try:
|
||||
total_cost: float = 0.0
|
||||
# parse the file content as json
|
||||
verbose_logger.debug(
|
||||
"file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
|
||||
)
|
||||
for _item in file_content_dictionary:
|
||||
if _batch_response_was_successful(_item):
|
||||
_response_body = _get_response_from_batch_job_output_file(_item)
|
||||
total_cost += litellm.completion_cost(
|
||||
completion_response=_response_body,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
verbose_logger.debug("total_cost=%s", total_cost)
|
||||
return total_cost
|
||||
except Exception as e:
|
||||
verbose_logger.error("error in _get_batch_job_cost_from_file_content", e)
|
||||
raise e
|
||||
|
||||
|
||||
def _get_batch_job_total_usage_from_file_content(
|
||||
file_content_dictionary: List[dict],
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
) -> Usage:
|
||||
"""
|
||||
Get the tokens of a batch job from the file content
|
||||
"""
|
||||
total_tokens: int = 0
|
||||
prompt_tokens: int = 0
|
||||
completion_tokens: int = 0
|
||||
for _item in file_content_dictionary:
|
||||
if _batch_response_was_successful(_item):
|
||||
_response_body = _get_response_from_batch_job_output_file(_item)
|
||||
usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
|
||||
total_tokens += usage.total_tokens
|
||||
prompt_tokens += usage.prompt_tokens
|
||||
completion_tokens += usage.completion_tokens
|
||||
return Usage(
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
)
|
||||
|
||||
|
||||
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
|
||||
"""
|
||||
Get the tokens of a batch job from the response body
|
||||
"""
|
||||
_usage_dict = response_body.get("usage", None) or {}
|
||||
usage: Usage = Usage(**_usage_dict)
|
||||
return usage
|
||||
|
||||
|
||||
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
|
||||
"""
|
||||
Get the response from the batch job output file
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
_response_body = _response.get("body", None) or {}
|
||||
return _response_body
|
||||
|
||||
|
||||
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
|
||||
"""
|
||||
Check if the batch job response status == 200
|
||||
"""
|
||||
_response: dict = batch_job_output_file.get("response", None) or {}
|
||||
return _response.get("status_code", None) == 200
|
||||
|
||||
|
||||
def _create_standard_logging_object_for_completed_batch(
|
||||
kwargs: dict,
|
||||
start_time: datetime.datetime,
|
||||
end_time: datetime.datetime,
|
||||
logging_obj: LiteLLMLoggingObj,
|
||||
batch_usage_object: Usage,
|
||||
response_cost: float,
|
||||
) -> StandardLoggingPayload:
|
||||
"""
|
||||
Create a standard logging object for a completed batch
|
||||
"""
|
||||
from litellm.litellm_core_utils.litellm_logging import (
|
||||
get_standard_logging_object_payload,
|
||||
)
|
||||
|
||||
standard_logging_object = get_standard_logging_object_payload(
|
||||
kwargs=kwargs,
|
||||
init_response_obj=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
logging_obj=logging_obj,
|
||||
status="success",
|
||||
)
|
||||
|
||||
if standard_logging_object is None:
|
||||
raise ValueError("unable to create standard logging object for completed batch")
|
||||
|
||||
# Add Completed Batch Job Usage and Response Cost
|
||||
standard_logging_object["call_type"] = "batch_success"
|
||||
standard_logging_object["response_cost"] = response_cost
|
||||
standard_logging_object["total_tokens"] = batch_usage_object.total_tokens
|
||||
standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens
|
||||
standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens
|
||||
return standard_logging_object
|
|
@ -27,6 +27,8 @@ from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRe
|
|||
from litellm.types.router import GenericLiteLLMParams
|
||||
from litellm.utils import client, supports_httpx_timeout
|
||||
|
||||
from .batch_utils import batches_async_logging
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
openai_batches_instance = OpenAIBatchesAPI()
|
||||
azure_batches_instance = AzureBatchesAPI()
|
||||
|
@ -71,10 +73,22 @@ async def acreate_batch(
|
|||
ctx = contextvars.copy_context()
|
||||
func_with_context = partial(ctx.run, func)
|
||||
init_response = await loop.run_in_executor(None, func_with_context)
|
||||
|
||||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
response = init_response
|
||||
|
||||
# Start async logging job
|
||||
if response is not None:
|
||||
asyncio.create_task(
|
||||
batches_async_logging(
|
||||
logging_obj=kwargs.get("litellm_logging_obj", None),
|
||||
batch_id=response.id,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
**kwargs,
|
||||
)
|
||||
)
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -238,7 +252,7 @@ def create_batch(
|
|||
|
||||
async def aretrieve_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
|
@ -279,7 +293,7 @@ async def aretrieve_batch(
|
|||
|
||||
def retrieve_batch(
|
||||
batch_id: str,
|
||||
custom_llm_provider: Literal["openai", "azure"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
metadata: Optional[Dict[str, str]] = None,
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
|
@ -552,7 +566,6 @@ def list_batches(
|
|||
return response
|
||||
except Exception as e:
|
||||
raise e
|
||||
pass
|
||||
|
||||
|
||||
def cancel_batch():
|
||||
|
|
|
@ -92,3 +92,6 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
|
|||
"generateQuery/",
|
||||
"optimize-prompt/",
|
||||
]
|
||||
|
||||
BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
|
||||
BATCH_STATUS_POLL_MAX_ATTEMPTS = 10
|
||||
|
|
|
@ -2,10 +2,12 @@ from typing import Any, Coroutine, Optional, Union
|
|||
|
||||
import httpx
|
||||
|
||||
from litellm import LlmProviders
|
||||
from litellm.integrations.gcs_bucket.gcs_bucket_base import (
|
||||
GCSBucketBase,
|
||||
GCSLoggingConfig,
|
||||
)
|
||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||
from litellm.types.llms.openai import CreateFileRequest, FileObject
|
||||
|
||||
from .transformation import VertexAIFilesTransformation
|
||||
|
@ -20,6 +22,12 @@ class VertexAIFilesHandler(GCSBucketBase):
|
|||
This implementation uploads files on GCS Buckets
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.async_httpx_client = get_async_httpx_client(
|
||||
llm_provider=LlmProviders.VERTEX_AI,
|
||||
)
|
||||
|
||||
pass
|
||||
|
||||
async def async_create_file(
|
||||
|
|
|
@ -826,6 +826,11 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
- It supports various optional parameters for customizing the completion behavior.
|
||||
- If 'mock_response' is provided, a mock completion response is returned for testing or debugging.
|
||||
"""
|
||||
### VALIDATE Request ###
|
||||
if model is None:
|
||||
raise ValueError("model param not passed in.")
|
||||
# validate messages
|
||||
messages = validate_chat_completion_messages(messages=messages)
|
||||
######### unpacking kwargs #####################
|
||||
args = locals()
|
||||
api_base = kwargs.get("api_base", None)
|
||||
|
@ -997,9 +1002,6 @@ def completion( # type: ignore # noqa: PLR0915
|
|||
"aws_region_name", None
|
||||
) # support region-based pricing for bedrock
|
||||
|
||||
### VALIDATE USER MESSAGES ###
|
||||
messages = validate_chat_completion_messages(messages=messages)
|
||||
|
||||
### TIMEOUT LOGIC ###
|
||||
timeout = timeout or kwargs.get("request_timeout", 600) or 600
|
||||
# set timeout for 10 minutes by default
|
||||
|
|
13
tests/batches_tests/adroit-crow-413218-bc47f303efc9.json
Normal file
13
tests/batches_tests/adroit-crow-413218-bc47f303efc9.json
Normal file
|
@ -0,0 +1,13 @@
|
|||
{
|
||||
"type": "service_account",
|
||||
"project_id": "adroit-crow-413218",
|
||||
"private_key_id": "",
|
||||
"private_key": "",
|
||||
"client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
|
||||
"client_id": "104886546564708740969",
|
||||
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
|
||||
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
|
||||
"universe_domain": "googleapis.com"
|
||||
}
|
3
tests/batches_tests/batch_job_results_furniture.jsonl
Normal file
3
tests/batches_tests/batch_job_results_furniture.jsonl
Normal file
|
@ -0,0 +1,3 @@
|
|||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
|
3
tests/batches_tests/openai_batch_completions.jsonl
Normal file
3
tests/batches_tests/openai_batch_completions.jsonl
Normal file
|
@ -0,0 +1,3 @@
|
|||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
|
171
tests/batches_tests/test_batches_logging_unit_tests.py
Normal file
171
tests/batches_tests/test_batches_logging_unit_tests.py
Normal file
|
@ -0,0 +1,171 @@
|
|||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system-path
|
||||
import logging
|
||||
import time
|
||||
|
||||
import pytest
|
||||
from typing import Optional
|
||||
import litellm
|
||||
from litellm import create_batch, create_file
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.batches.batch_utils import (
|
||||
_batch_cost_calculator,
|
||||
_get_file_content_as_dictionary,
|
||||
_get_batch_job_cost_from_file_content,
|
||||
_get_batch_job_total_usage_from_file_content,
|
||||
_get_batch_job_usage_from_response_body,
|
||||
_get_response_from_batch_job_output_file,
|
||||
_batch_response_was_successful,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_file_content():
|
||||
return b"""
|
||||
{"id": "batch_req_6769ca596b38819093d7ae9f522de924", "custom_id": "request-1", "response": {"status_code": 200, "request_id": "07bc45ab4e7e26ac23a0c949973327e7", "body": {"id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7", "object": "chat.completion", "created": 1734986202, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! How can I assist you today?", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null}
|
||||
{"id": "batch_req_6769ca597e588190920666612634e2b4", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "82e04f4c001fe2c127cbad199f5fd31b", "body": {"id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP", "object": "chat.completion", "created": 1734986203, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! What can I do for you today?", "refusal": null}, "logprobs": null, "finish_reason": "length"}], "usage": {"prompt_tokens": 22, "completion_tokens": 10, "total_tokens": 32, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null}
|
||||
"""
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_file_content_dict():
|
||||
return [
|
||||
{
|
||||
"id": "batch_req_6769ca596b38819093d7ae9f522de924",
|
||||
"custom_id": "request-1",
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": "07bc45ab4e7e26ac23a0c949973327e7",
|
||||
"body": {
|
||||
"id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7",
|
||||
"object": "chat.completion",
|
||||
"created": 1734986202,
|
||||
"model": "gpt-4o-mini-2024-07-18",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! How can I assist you today?",
|
||||
"refusal": None,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 20,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 30,
|
||||
"prompt_tokens_details": {
|
||||
"cached_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
},
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0,
|
||||
},
|
||||
},
|
||||
"system_fingerprint": "fp_0aa8d3e20b",
|
||||
},
|
||||
},
|
||||
"error": None,
|
||||
},
|
||||
{
|
||||
"id": "batch_req_6769ca597e588190920666612634e2b4",
|
||||
"custom_id": "request-2",
|
||||
"response": {
|
||||
"status_code": 200,
|
||||
"request_id": "82e04f4c001fe2c127cbad199f5fd31b",
|
||||
"body": {
|
||||
"id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP",
|
||||
"object": "chat.completion",
|
||||
"created": 1734986203,
|
||||
"model": "gpt-4o-mini-2024-07-18",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Hello! What can I do for you today?",
|
||||
"refusal": None,
|
||||
},
|
||||
"logprobs": None,
|
||||
"finish_reason": "length",
|
||||
}
|
||||
],
|
||||
"usage": {
|
||||
"prompt_tokens": 22,
|
||||
"completion_tokens": 10,
|
||||
"total_tokens": 32,
|
||||
"prompt_tokens_details": {
|
||||
"cached_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
},
|
||||
"completion_tokens_details": {
|
||||
"reasoning_tokens": 0,
|
||||
"audio_tokens": 0,
|
||||
"accepted_prediction_tokens": 0,
|
||||
"rejected_prediction_tokens": 0,
|
||||
},
|
||||
},
|
||||
"system_fingerprint": "fp_0aa8d3e20b",
|
||||
},
|
||||
},
|
||||
"error": None,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_get_file_content_as_dictionary(sample_file_content):
|
||||
result = _get_file_content_as_dictionary(sample_file_content)
|
||||
assert len(result) == 2
|
||||
assert result[0]["id"] == "batch_req_6769ca596b38819093d7ae9f522de924"
|
||||
assert result[0]["custom_id"] == "request-1"
|
||||
assert result[0]["response"]["status_code"] == 200
|
||||
assert result[0]["response"]["body"]["usage"]["total_tokens"] == 30
|
||||
|
||||
|
||||
def test_get_batch_job_total_usage_from_file_content(sample_file_content_dict):
|
||||
usage = _get_batch_job_total_usage_from_file_content(
|
||||
sample_file_content_dict, custom_llm_provider="openai"
|
||||
)
|
||||
assert usage.total_tokens == 62 # 30 + 32
|
||||
assert usage.prompt_tokens == 42 # 20 + 22
|
||||
assert usage.completion_tokens == 20 # 10 + 10
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_batch_cost_calculator(sample_file_content_dict):
|
||||
"""
|
||||
mock litellm.completion_cost to return 0.5
|
||||
|
||||
we know sample_file_content_dict has 2 successful responses
|
||||
|
||||
so we expect the cost to be 0.5 * 2 = 1.0
|
||||
"""
|
||||
with patch("litellm.completion_cost", return_value=0.5):
|
||||
cost = await _batch_cost_calculator(
|
||||
file_content_dictionary=sample_file_content_dict,
|
||||
custom_llm_provider="openai",
|
||||
)
|
||||
assert cost == 1.0 # 0.5 * 2 successful responses
|
||||
|
||||
|
||||
def test_get_response_from_batch_job_output_file(sample_file_content_dict):
|
||||
result = _get_response_from_batch_job_output_file(sample_file_content_dict[0])
|
||||
assert result["id"] == "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7"
|
||||
assert result["object"] == "chat.completion"
|
||||
assert result["usage"]["total_tokens"] == 30
|
|
@ -5,7 +5,7 @@ import json
|
|||
import os
|
||||
import sys
|
||||
import traceback
|
||||
|
||||
import tempfile
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
|
@ -20,7 +20,6 @@ from typing import Optional
|
|||
import litellm
|
||||
from litellm import create_batch, create_file
|
||||
from litellm._logging import verbose_logger
|
||||
from test_gcs_bucket import load_vertex_ai_credentials
|
||||
|
||||
verbose_logger.setLevel(logging.DEBUG)
|
||||
|
||||
|
@ -28,19 +27,47 @@ from litellm.integrations.custom_logger import CustomLogger
|
|||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class TestCustomLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.standard_logging_object: Optional[StandardLoggingPayload] = None
|
||||
def load_vertex_ai_credentials():
|
||||
# Define the path to the vertex_key.json file
|
||||
print("loading vertex ai credentials")
|
||||
os.environ["GCS_FLUSH_INTERVAL"] = "1"
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json"
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(
|
||||
"Success event logged with kwargs=",
|
||||
kwargs,
|
||||
"and response_obj=",
|
||||
response_obj,
|
||||
)
|
||||
self.standard_logging_object = kwargs["standard_logging_object"]
|
||||
# Read the existing content of the file or create an empty dictionary
|
||||
try:
|
||||
with open(vertex_key_path, "r") as file:
|
||||
# Read the file content
|
||||
print("Read vertexai file path")
|
||||
content = file.read()
|
||||
|
||||
# If the file is empty or not valid JSON, create an empty dictionary
|
||||
if not content or not content.strip():
|
||||
service_account_key_data = {}
|
||||
else:
|
||||
# Attempt to load the existing JSON content
|
||||
file.seek(0)
|
||||
service_account_key_data = json.load(file)
|
||||
except FileNotFoundError:
|
||||
# If the file doesn't exist, create an empty dictionary
|
||||
service_account_key_data = {}
|
||||
|
||||
# Update the service_account_key_data with environment variables
|
||||
private_key_id = os.environ.get("GCS_PRIVATE_KEY_ID", "")
|
||||
private_key = os.environ.get("GCS_PRIVATE_KEY", "")
|
||||
private_key = private_key.replace("\\n", "\n")
|
||||
service_account_key_data["private_key_id"] = private_key_id
|
||||
service_account_key_data["private_key"] = private_key
|
||||
|
||||
# Create a temporary file
|
||||
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
|
||||
# Write the updated content to the temporary files
|
||||
json.dump(service_account_key_data, temp_file, indent=2)
|
||||
|
||||
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
|
||||
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = os.path.abspath(temp_file.name)
|
||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
|
||||
print("created gcs path service account=", os.environ["GCS_PATH_SERVICE_ACCOUNT"])
|
||||
|
||||
|
||||
@pytest.mark.parametrize("provider", ["openai"]) # , "azure"
|
||||
|
@ -128,6 +155,21 @@ async def test_create_batch(provider):
|
|||
pass
|
||||
|
||||
|
||||
class TestCustomLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.standard_logging_object: Optional[StandardLoggingPayload] = None
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
print(
|
||||
"Success event logged with kwargs=",
|
||||
kwargs,
|
||||
"and response_obj=",
|
||||
response_obj,
|
||||
)
|
||||
self.standard_logging_object = kwargs["standard_logging_object"]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("provider", ["openai"]) # "azure"
|
||||
@pytest.mark.asyncio()
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
|
@ -142,6 +184,9 @@ async def test_async_create_batch(provider):
|
|||
# Don't have anymore Azure Quota
|
||||
return
|
||||
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.callbacks = [custom_logger, "datadog"]
|
||||
|
||||
file_name = "openai_batch_completions.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
||||
|
@ -179,7 +224,9 @@ async def test_async_create_batch(provider):
|
|||
create_batch_response.input_file_id == batch_input_file_id
|
||||
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
|
||||
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.sleep(6)
|
||||
# Assert that the create batch event is logged on CustomLogger
|
||||
assert custom_logger.standard_logging_object is not None
|
||||
|
||||
retrieved_batch = await litellm.aretrieve_batch(
|
||||
batch_id=create_batch_response.id, custom_llm_provider=provider
|
||||
|
@ -223,9 +270,10 @@ async def test_async_create_batch(provider):
|
|||
|
||||
print("all_files_list = ", all_files_list)
|
||||
|
||||
# # write this file content to a file
|
||||
# with open("file_content.json", "w") as f:
|
||||
# json.dump(file_content, f)
|
||||
result_file_name = "batch_job_results_furniture.jsonl"
|
||||
|
||||
with open(result_file_name, "wb") as file:
|
||||
file.write(file_content.content)
|
||||
|
||||
|
||||
def test_retrieve_batch():
|
||||
|
@ -241,8 +289,9 @@ def test_list_batch():
|
|||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_batch_prediction():
|
||||
async def test_avertex_batch_prediction():
|
||||
load_vertex_ai_credentials()
|
||||
litellm.set_verbose = True
|
||||
file_name = "vertex_batch_completions.jsonl"
|
||||
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
file_path = os.path.join(_current_dir, file_name)
|
2
tests/batches_tests/vertex_batch_completions.jsonl
Normal file
2
tests/batches_tests/vertex_batch_completions.jsonl
Normal file
|
@ -0,0 +1,2 @@
|
|||
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
||||
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
|
|
@ -23,7 +23,7 @@ model_val = None
|
|||
|
||||
def test_completion_with_no_model():
|
||||
# test on empty
|
||||
with pytest.raises(ValueError):
|
||||
with pytest.raises(TypeError):
|
||||
response = completion(messages=messages)
|
||||
|
||||
|
||||
|
@ -36,39 +36,6 @@ def test_completion_with_empty_model():
|
|||
pass
|
||||
|
||||
|
||||
# def test_completion_catch_nlp_exception():
|
||||
# TEMP commented out NLP cloud API is unstable
|
||||
# try:
|
||||
# response = completion(model="dolphin", messages=messages, functions=[
|
||||
# {
|
||||
# "name": "get_current_weather",
|
||||
# "description": "Get the current weather in a given location",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "location": {
|
||||
# "type": "string",
|
||||
# "description": "The city and state, e.g. San Francisco, CA"
|
||||
# },
|
||||
# "unit": {
|
||||
# "type": "string",
|
||||
# "enum": ["celsius", "fahrenheit"]
|
||||
# }
|
||||
# },
|
||||
# "required": ["location"]
|
||||
# }
|
||||
# }
|
||||
# ])
|
||||
|
||||
# except Exception as e:
|
||||
# if "Function calling is not supported by nlp_cloud" in str(e):
|
||||
# pass
|
||||
# else:
|
||||
# pytest.fail(f'An error occurred {e}')
|
||||
|
||||
# test_completion_catch_nlp_exception()
|
||||
|
||||
|
||||
def test_completion_invalid_param_cohere():
|
||||
try:
|
||||
litellm.set_verbose = True
|
||||
|
@ -94,9 +61,6 @@ def test_completion_function_call_cohere():
|
|||
pass
|
||||
|
||||
|
||||
# test_completion_function_call_cohere()
|
||||
|
||||
|
||||
def test_completion_function_call_openai():
|
||||
try:
|
||||
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
|
||||
|
@ -140,17 +104,3 @@ def test_completion_with_no_provider():
|
|||
except Exception as e:
|
||||
print(f"error occurred: {e}")
|
||||
pass
|
||||
|
||||
|
||||
# test_completion_with_no_provider()
|
||||
# # bad key
|
||||
# temp_key = os.environ.get("OPENAI_API_KEY")
|
||||
# os.environ["OPENAI_API_KEY"] = "bad-key"
|
||||
# # test on openai completion call
|
||||
# try:
|
||||
# response = completion(model="gpt-3.5-turbo", messages=messages)
|
||||
# print(f"response: {response}")
|
||||
# except Exception:
|
||||
# print(f"error occurred: {traceback.format_exc()}")
|
||||
# pass
|
||||
# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue