(feat) Add cost tracking for /batches requests OpenAI (#7384)

* add basic logging for create`batch`

* add create_batch as a call type

* add basic dd logging for batches

* basic batch creation logging on DD

* batch endpoints add cost calc

* fix batches_async_logging

* separate folder for batches testing

* new job for batches tests

* test batches logging

* fix validation logic

* add vertex_batch_completions.jsonl

* test test_async_create_batch

* test_async_create_batch

* update tests

* test_completion_with_no_model

* remove dead code

* update load_vertex_ai_credentials

* test_avertex_batch_prediction

* update get async httpx client

* fix get_async_httpx_client

* update test_avertex_batch_prediction

* fix batches testing config.yaml

* add google deps

* fix vertex files handler
This commit is contained in:
Ishaan Jaff 2024-12-23 17:47:26 -08:00 committed by GitHub
parent 87f19d6f13
commit 05b0d2026f
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
13 changed files with 649 additions and 78 deletions

View file

@ -626,6 +626,50 @@ jobs:
paths:
- llm_translation_coverage.xml
- llm_translation_coverage
batches_testing:
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project
steps:
- checkout
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
python -m pip install -r requirements.txt
pip install "respx==0.21.1"
pip install "pytest==7.3.1"
pip install "pytest-retry==1.6.3"
pip install "pytest-asyncio==0.21.1"
pip install "pytest-cov==5.0.0"
pip install "google-generativeai==0.3.2"
pip install "google-cloud-aiplatform==1.43.0"
# Run pytest and generate JUnit XML report
- run:
name: Run tests
command: |
pwd
ls
python -m pytest -vv tests/batches_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
no_output_timeout: 120m
- run:
name: Rename the coverage files
command: |
mv coverage.xml batches_coverage.xml
mv .coverage batches_coverage
# Store test results
- store_test_results:
path: test-results
- persist_to_workspace:
root: .
paths:
- batches_coverage.xml
- batches_coverage
pass_through_unit_testing:
docker:
- image: cimg/python:3.11
@ -1417,7 +1461,7 @@ jobs:
python -m venv venv
. venv/bin/activate
pip install coverage
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage
coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage
coverage xml
- codecov/upload:
file: ./coverage.xml
@ -1714,6 +1758,12 @@ workflows:
only:
- main
- /litellm_.*/
- batches_testing:
filters:
branches:
only:
- main
- /litellm_.*/
- pass_through_unit_testing:
filters:
branches:
@ -1735,6 +1785,7 @@ workflows:
- upload-coverage:
requires:
- llm_translation_testing
- batches_testing
- pass_through_unit_testing
- image_gen_testing
- logging_testing
@ -1783,6 +1834,7 @@ workflows:
- load_testing
- test_bad_database_url
- llm_translation_testing
- batches_testing
- pass_through_unit_testing
- image_gen_testing
- logging_testing

View file

@ -0,0 +1,302 @@
import asyncio
import datetime
import json
import threading
from typing import Any, List, Literal, Optional
import litellm
from litellm._logging import verbose_logger
from litellm.constants import (
BATCH_STATUS_POLL_INTERVAL_SECONDS,
BATCH_STATUS_POLL_MAX_ATTEMPTS,
)
from litellm.files.main import afile_content
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import Batch
from litellm.types.utils import StandardLoggingPayload, Usage
async def batches_async_logging(
batch_id: str,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
logging_obj: Optional[LiteLLMLoggingObj] = None,
**kwargs,
):
"""
Async Job waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens
Polls retrieve_batch until it returns a batch with status "completed" or "failed"
"""
from .main import aretrieve_batch
verbose_logger.debug(
".....in _batches_async_logging... polling retrieve to get batch status"
)
if logging_obj is None:
raise ValueError(
"logging_obj is None cannot calculate cost / log batch creation event"
)
for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
try:
start_time = datetime.datetime.now()
batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider)
verbose_logger.debug(
"in _batches_async_logging... batch status= %s", batch.status
)
if batch.status == "completed":
end_time = datetime.datetime.now()
await _handle_completed_batch(
batch=batch,
custom_llm_provider=custom_llm_provider,
logging_obj=logging_obj,
start_time=start_time,
end_time=end_time,
**kwargs,
)
break
elif batch.status == "failed":
pass
except Exception as e:
verbose_logger.error("error in batches_async_logging", e)
await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)
async def _handle_completed_batch(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
logging_obj: LiteLLMLoggingObj,
start_time: datetime.datetime,
end_time: datetime.datetime,
**kwargs,
) -> None:
"""Helper function to process a completed batch and handle logging"""
# Get batch results
file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
batch, custom_llm_provider
)
# Calculate costs and usage
batch_cost = await _batch_cost_calculator(
custom_llm_provider=custom_llm_provider,
file_content_dictionary=file_content_dictionary,
)
batch_usage = _get_batch_job_total_usage_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
)
# Handle logging
await _log_completed_batch(
logging_obj=logging_obj,
batch_usage=batch_usage,
batch_cost=batch_cost,
start_time=start_time,
end_time=end_time,
**kwargs,
)
async def _log_completed_batch(
logging_obj: LiteLLMLoggingObj,
batch_usage: Usage,
batch_cost: float,
start_time: datetime.datetime,
end_time: datetime.datetime,
**kwargs,
) -> None:
"""Helper function to handle all logging operations for a completed batch"""
logging_obj.call_type = "batch_success"
standard_logging_object = _create_standard_logging_object_for_completed_batch(
kwargs=kwargs,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
batch_usage_object=batch_usage,
response_cost=batch_cost,
)
logging_obj.model_call_details["standard_logging_object"] = standard_logging_object
# Launch async and sync logging handlers
asyncio.create_task(
logging_obj.async_success_handler(
result=None,
start_time=start_time,
end_time=end_time,
cache_hit=None,
)
)
threading.Thread(
target=logging_obj.success_handler,
args=(None, start_time, end_time),
).start()
async def _batch_cost_calculator(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> float:
"""
Calculate the cost of a batch based on the output file id
"""
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
total_cost = _get_batch_job_cost_from_file_content(
file_content_dictionary=file_content_dictionary,
custom_llm_provider=custom_llm_provider,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
async def _get_batch_output_file_content_as_dictionary(
batch: Batch,
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> List[dict]:
"""
Get the batch output file content as a list of dictionaries
"""
if custom_llm_provider == "vertex_ai":
raise ValueError("Vertex AI does not support file content retrieval")
if batch.output_file_id is None:
raise ValueError("Output file id is None cannot retrieve file content")
_file_content = await afile_content(
file_id=batch.output_file_id,
custom_llm_provider=custom_llm_provider,
)
return _get_file_content_as_dictionary(_file_content.content)
def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
"""
Get the file content as a list of dictionaries from JSON Lines format
"""
try:
_file_content_str = file_content.decode("utf-8")
# Split by newlines and parse each line as a separate JSON object
json_objects = []
for line in _file_content_str.strip().split("\n"):
if line: # Skip empty lines
json_objects.append(json.loads(line))
verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
return json_objects
except Exception as e:
raise e
def _get_batch_job_cost_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> float:
"""
Get the cost of a batch job from the file content
"""
try:
total_cost: float = 0.0
# parse the file content as json
verbose_logger.debug(
"file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
)
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
total_cost += litellm.completion_cost(
completion_response=_response_body,
custom_llm_provider=custom_llm_provider,
)
verbose_logger.debug("total_cost=%s", total_cost)
return total_cost
except Exception as e:
verbose_logger.error("error in _get_batch_job_cost_from_file_content", e)
raise e
def _get_batch_job_total_usage_from_file_content(
file_content_dictionary: List[dict],
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> Usage:
"""
Get the tokens of a batch job from the file content
"""
total_tokens: int = 0
prompt_tokens: int = 0
completion_tokens: int = 0
for _item in file_content_dictionary:
if _batch_response_was_successful(_item):
_response_body = _get_response_from_batch_job_output_file(_item)
usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
total_tokens += usage.total_tokens
prompt_tokens += usage.prompt_tokens
completion_tokens += usage.completion_tokens
return Usage(
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
)
def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
"""
Get the tokens of a batch job from the response body
"""
_usage_dict = response_body.get("usage", None) or {}
usage: Usage = Usage(**_usage_dict)
return usage
def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
"""
Get the response from the batch job output file
"""
_response: dict = batch_job_output_file.get("response", None) or {}
_response_body = _response.get("body", None) or {}
return _response_body
def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
"""
Check if the batch job response status == 200
"""
_response: dict = batch_job_output_file.get("response", None) or {}
return _response.get("status_code", None) == 200
def _create_standard_logging_object_for_completed_batch(
kwargs: dict,
start_time: datetime.datetime,
end_time: datetime.datetime,
logging_obj: LiteLLMLoggingObj,
batch_usage_object: Usage,
response_cost: float,
) -> StandardLoggingPayload:
"""
Create a standard logging object for a completed batch
"""
from litellm.litellm_core_utils.litellm_logging import (
get_standard_logging_object_payload,
)
standard_logging_object = get_standard_logging_object_payload(
kwargs=kwargs,
init_response_obj=None,
start_time=start_time,
end_time=end_time,
logging_obj=logging_obj,
status="success",
)
if standard_logging_object is None:
raise ValueError("unable to create standard logging object for completed batch")
# Add Completed Batch Job Usage and Response Cost
standard_logging_object["call_type"] = "batch_success"
standard_logging_object["response_cost"] = response_cost
standard_logging_object["total_tokens"] = batch_usage_object.total_tokens
standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens
standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens
return standard_logging_object

View file

@ -27,6 +27,8 @@ from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRe
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import client, supports_httpx_timeout
from .batch_utils import batches_async_logging
####### ENVIRONMENT VARIABLES ###################
openai_batches_instance = OpenAIBatchesAPI()
azure_batches_instance = AzureBatchesAPI()
@ -71,10 +73,22 @@ async def acreate_batch(
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
init_response = await loop.run_in_executor(None, func_with_context)
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
response = init_response
# Start async logging job
if response is not None:
asyncio.create_task(
batches_async_logging(
logging_obj=kwargs.get("litellm_logging_obj", None),
batch_id=response.id,
custom_llm_provider=custom_llm_provider,
**kwargs,
)
)
return response
except Exception as e:
@ -238,7 +252,7 @@ def create_batch(
async def aretrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@ -279,7 +293,7 @@ async def aretrieve_batch(
def retrieve_batch(
batch_id: str,
custom_llm_provider: Literal["openai", "azure"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
metadata: Optional[Dict[str, str]] = None,
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
@ -552,7 +566,6 @@ def list_batches(
return response
except Exception as e:
raise e
pass
def cancel_batch():

View file

@ -92,3 +92,6 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
"generateQuery/",
"optimize-prompt/",
]
BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
BATCH_STATUS_POLL_MAX_ATTEMPTS = 10

View file

@ -2,10 +2,12 @@ from typing import Any, Coroutine, Optional, Union
import httpx
from litellm import LlmProviders
from litellm.integrations.gcs_bucket.gcs_bucket_base import (
GCSBucketBase,
GCSLoggingConfig,
)
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
from litellm.types.llms.openai import CreateFileRequest, FileObject
from .transformation import VertexAIFilesTransformation
@ -20,6 +22,12 @@ class VertexAIFilesHandler(GCSBucketBase):
This implementation uploads files on GCS Buckets
"""
def __init__(self):
super().__init__()
self.async_httpx_client = get_async_httpx_client(
llm_provider=LlmProviders.VERTEX_AI,
)
pass
async def async_create_file(

View file

@ -826,6 +826,11 @@ def completion( # type: ignore # noqa: PLR0915
- It supports various optional parameters for customizing the completion behavior.
- If 'mock_response' is provided, a mock completion response is returned for testing or debugging.
"""
### VALIDATE Request ###
if model is None:
raise ValueError("model param not passed in.")
# validate messages
messages = validate_chat_completion_messages(messages=messages)
######### unpacking kwargs #####################
args = locals()
api_base = kwargs.get("api_base", None)
@ -997,9 +1002,6 @@ def completion( # type: ignore # noqa: PLR0915
"aws_region_name", None
) # support region-based pricing for bedrock
### VALIDATE USER MESSAGES ###
messages = validate_chat_completion_messages(messages=messages)
### TIMEOUT LOGIC ###
timeout = timeout or kwargs.get("request_timeout", 600) or 600
# set timeout for 10 minutes by default

View file

@ -0,0 +1,13 @@
{
"type": "service_account",
"project_id": "adroit-crow-413218",
"private_key_id": "",
"private_key": "",
"client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
"client_id": "104886546564708740969",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
"universe_domain": "googleapis.com"
}

View file

@ -0,0 +1,3 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}

View file

@ -0,0 +1,3 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}

View file

@ -0,0 +1,171 @@
import asyncio
import json
import os
import sys
import traceback
from unittest.mock import AsyncMock, MagicMock, patch
from dotenv import load_dotenv
load_dotenv()
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system-path
import logging
import time
import pytest
from typing import Optional
import litellm
from litellm import create_batch, create_file
from litellm._logging import verbose_logger
from litellm.batches.batch_utils import (
_batch_cost_calculator,
_get_file_content_as_dictionary,
_get_batch_job_cost_from_file_content,
_get_batch_job_total_usage_from_file_content,
_get_batch_job_usage_from_response_body,
_get_response_from_batch_job_output_file,
_batch_response_was_successful,
)
@pytest.fixture
def sample_file_content():
return b"""
{"id": "batch_req_6769ca596b38819093d7ae9f522de924", "custom_id": "request-1", "response": {"status_code": 200, "request_id": "07bc45ab4e7e26ac23a0c949973327e7", "body": {"id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7", "object": "chat.completion", "created": 1734986202, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! How can I assist you today?", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null}
{"id": "batch_req_6769ca597e588190920666612634e2b4", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "82e04f4c001fe2c127cbad199f5fd31b", "body": {"id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP", "object": "chat.completion", "created": 1734986203, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! What can I do for you today?", "refusal": null}, "logprobs": null, "finish_reason": "length"}], "usage": {"prompt_tokens": 22, "completion_tokens": 10, "total_tokens": 32, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null}
"""
@pytest.fixture
def sample_file_content_dict():
return [
{
"id": "batch_req_6769ca596b38819093d7ae9f522de924",
"custom_id": "request-1",
"response": {
"status_code": 200,
"request_id": "07bc45ab4e7e26ac23a0c949973327e7",
"body": {
"id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7",
"object": "chat.completion",
"created": 1734986202,
"model": "gpt-4o-mini-2024-07-18",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! How can I assist you today?",
"refusal": None,
},
"logprobs": None,
"finish_reason": "stop",
}
],
"usage": {
"prompt_tokens": 20,
"completion_tokens": 10,
"total_tokens": 30,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0,
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0,
},
},
"system_fingerprint": "fp_0aa8d3e20b",
},
},
"error": None,
},
{
"id": "batch_req_6769ca597e588190920666612634e2b4",
"custom_id": "request-2",
"response": {
"status_code": 200,
"request_id": "82e04f4c001fe2c127cbad199f5fd31b",
"body": {
"id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP",
"object": "chat.completion",
"created": 1734986203,
"model": "gpt-4o-mini-2024-07-18",
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Hello! What can I do for you today?",
"refusal": None,
},
"logprobs": None,
"finish_reason": "length",
}
],
"usage": {
"prompt_tokens": 22,
"completion_tokens": 10,
"total_tokens": 32,
"prompt_tokens_details": {
"cached_tokens": 0,
"audio_tokens": 0,
},
"completion_tokens_details": {
"reasoning_tokens": 0,
"audio_tokens": 0,
"accepted_prediction_tokens": 0,
"rejected_prediction_tokens": 0,
},
},
"system_fingerprint": "fp_0aa8d3e20b",
},
},
"error": None,
},
]
def test_get_file_content_as_dictionary(sample_file_content):
result = _get_file_content_as_dictionary(sample_file_content)
assert len(result) == 2
assert result[0]["id"] == "batch_req_6769ca596b38819093d7ae9f522de924"
assert result[0]["custom_id"] == "request-1"
assert result[0]["response"]["status_code"] == 200
assert result[0]["response"]["body"]["usage"]["total_tokens"] == 30
def test_get_batch_job_total_usage_from_file_content(sample_file_content_dict):
usage = _get_batch_job_total_usage_from_file_content(
sample_file_content_dict, custom_llm_provider="openai"
)
assert usage.total_tokens == 62 # 30 + 32
assert usage.prompt_tokens == 42 # 20 + 22
assert usage.completion_tokens == 20 # 10 + 10
@pytest.mark.asyncio
async def test_batch_cost_calculator(sample_file_content_dict):
"""
mock litellm.completion_cost to return 0.5
we know sample_file_content_dict has 2 successful responses
so we expect the cost to be 0.5 * 2 = 1.0
"""
with patch("litellm.completion_cost", return_value=0.5):
cost = await _batch_cost_calculator(
file_content_dictionary=sample_file_content_dict,
custom_llm_provider="openai",
)
assert cost == 1.0 # 0.5 * 2 successful responses
def test_get_response_from_batch_job_output_file(sample_file_content_dict):
result = _get_response_from_batch_job_output_file(sample_file_content_dict[0])
assert result["id"] == "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7"
assert result["object"] == "chat.completion"
assert result["usage"]["total_tokens"] == 30

View file

@ -5,7 +5,7 @@ import json
import os
import sys
import traceback
import tempfile
from dotenv import load_dotenv
load_dotenv()
@ -20,7 +20,6 @@ from typing import Optional
import litellm
from litellm import create_batch, create_file
from litellm._logging import verbose_logger
from test_gcs_bucket import load_vertex_ai_credentials
verbose_logger.setLevel(logging.DEBUG)
@ -28,19 +27,47 @@ from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload
class TestCustomLogger(CustomLogger):
def __init__(self):
super().__init__()
self.standard_logging_object: Optional[StandardLoggingPayload] = None
def load_vertex_ai_credentials():
# Define the path to the vertex_key.json file
print("loading vertex ai credentials")
os.environ["GCS_FLUSH_INTERVAL"] = "1"
filepath = os.path.dirname(os.path.abspath(__file__))
vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json"
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(
"Success event logged with kwargs=",
kwargs,
"and response_obj=",
response_obj,
)
self.standard_logging_object = kwargs["standard_logging_object"]
# Read the existing content of the file or create an empty dictionary
try:
with open(vertex_key_path, "r") as file:
# Read the file content
print("Read vertexai file path")
content = file.read()
# If the file is empty or not valid JSON, create an empty dictionary
if not content or not content.strip():
service_account_key_data = {}
else:
# Attempt to load the existing JSON content
file.seek(0)
service_account_key_data = json.load(file)
except FileNotFoundError:
# If the file doesn't exist, create an empty dictionary
service_account_key_data = {}
# Update the service_account_key_data with environment variables
private_key_id = os.environ.get("GCS_PRIVATE_KEY_ID", "")
private_key = os.environ.get("GCS_PRIVATE_KEY", "")
private_key = private_key.replace("\\n", "\n")
service_account_key_data["private_key_id"] = private_key_id
service_account_key_data["private_key"] = private_key
# Create a temporary file
with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
# Write the updated content to the temporary files
json.dump(service_account_key_data, temp_file, indent=2)
# Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
os.environ["GCS_PATH_SERVICE_ACCOUNT"] = os.path.abspath(temp_file.name)
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
print("created gcs path service account=", os.environ["GCS_PATH_SERVICE_ACCOUNT"])
@pytest.mark.parametrize("provider", ["openai"]) # , "azure"
@ -128,6 +155,21 @@ async def test_create_batch(provider):
pass
class TestCustomLogger(CustomLogger):
def __init__(self):
super().__init__()
self.standard_logging_object: Optional[StandardLoggingPayload] = None
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
print(
"Success event logged with kwargs=",
kwargs,
"and response_obj=",
response_obj,
)
self.standard_logging_object = kwargs["standard_logging_object"]
@pytest.mark.parametrize("provider", ["openai"]) # "azure"
@pytest.mark.asyncio()
@pytest.mark.flaky(retries=3, delay=1)
@ -142,6 +184,9 @@ async def test_async_create_batch(provider):
# Don't have anymore Azure Quota
return
custom_logger = TestCustomLogger()
litellm.callbacks = [custom_logger, "datadog"]
file_name = "openai_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)
@ -179,7 +224,9 @@ async def test_async_create_batch(provider):
create_batch_response.input_file_id == batch_input_file_id
), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"
await asyncio.sleep(1)
await asyncio.sleep(6)
# Assert that the create batch event is logged on CustomLogger
assert custom_logger.standard_logging_object is not None
retrieved_batch = await litellm.aretrieve_batch(
batch_id=create_batch_response.id, custom_llm_provider=provider
@ -223,9 +270,10 @@ async def test_async_create_batch(provider):
print("all_files_list = ", all_files_list)
# # write this file content to a file
# with open("file_content.json", "w") as f:
# json.dump(file_content, f)
result_file_name = "batch_job_results_furniture.jsonl"
with open(result_file_name, "wb") as file:
file.write(file_content.content)
def test_retrieve_batch():
@ -241,8 +289,9 @@ def test_list_batch():
@pytest.mark.asyncio
async def test_vertex_batch_prediction():
async def test_avertex_batch_prediction():
load_vertex_ai_credentials()
litellm.set_verbose = True
file_name = "vertex_batch_completions.jsonl"
_current_dir = os.path.dirname(os.path.abspath(__file__))
file_path = os.path.join(_current_dir, file_name)

View file

@ -0,0 +1,2 @@
{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}

View file

@ -23,7 +23,7 @@ model_val = None
def test_completion_with_no_model():
# test on empty
with pytest.raises(ValueError):
with pytest.raises(TypeError):
response = completion(messages=messages)
@ -36,39 +36,6 @@ def test_completion_with_empty_model():
pass
# def test_completion_catch_nlp_exception():
# TEMP commented out NLP cloud API is unstable
# try:
# response = completion(model="dolphin", messages=messages, functions=[
# {
# "name": "get_current_weather",
# "description": "Get the current weather in a given location",
# "parameters": {
# "type": "object",
# "properties": {
# "location": {
# "type": "string",
# "description": "The city and state, e.g. San Francisco, CA"
# },
# "unit": {
# "type": "string",
# "enum": ["celsius", "fahrenheit"]
# }
# },
# "required": ["location"]
# }
# }
# ])
# except Exception as e:
# if "Function calling is not supported by nlp_cloud" in str(e):
# pass
# else:
# pytest.fail(f'An error occurred {e}')
# test_completion_catch_nlp_exception()
def test_completion_invalid_param_cohere():
try:
litellm.set_verbose = True
@ -94,9 +61,6 @@ def test_completion_function_call_cohere():
pass
# test_completion_function_call_cohere()
def test_completion_function_call_openai():
try:
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
@ -140,17 +104,3 @@ def test_completion_with_no_provider():
except Exception as e:
print(f"error occurred: {e}")
pass
# test_completion_with_no_provider()
# # bad key
# temp_key = os.environ.get("OPENAI_API_KEY")
# os.environ["OPENAI_API_KEY"] = "bad-key"
# # test on openai completion call
# try:
# response = completion(model="gpt-3.5-turbo", messages=messages)
# print(f"response: {response}")
# except Exception:
# print(f"error occurred: {traceback.format_exc()}")
# pass
# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5