diff --git a/.circleci/config.yml b/.circleci/config.yml index cb887077de..2312989a21 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -626,6 +626,50 @@ jobs: paths: - llm_translation_coverage.xml - llm_translation_coverage + batches_testing: + docker: + - image: cimg/python:3.11 + auth: + username: ${DOCKERHUB_USERNAME} + password: ${DOCKERHUB_PASSWORD} + working_directory: ~/project + + steps: + - checkout + - run: + name: Install Dependencies + command: | + python -m pip install --upgrade pip + python -m pip install -r requirements.txt + pip install "respx==0.21.1" + pip install "pytest==7.3.1" + pip install "pytest-retry==1.6.3" + pip install "pytest-asyncio==0.21.1" + pip install "pytest-cov==5.0.0" + pip install "google-generativeai==0.3.2" + pip install "google-cloud-aiplatform==1.43.0" + # Run pytest and generate JUnit XML report + - run: + name: Run tests + command: | + pwd + ls + python -m pytest -vv tests/batches_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5 + no_output_timeout: 120m + - run: + name: Rename the coverage files + command: | + mv coverage.xml batches_coverage.xml + mv .coverage batches_coverage + + # Store test results + - store_test_results: + path: test-results + - persist_to_workspace: + root: . + paths: + - batches_coverage.xml + - batches_coverage pass_through_unit_testing: docker: - image: cimg/python:3.11 @@ -1417,7 +1461,7 @@ jobs: python -m venv venv . venv/bin/activate pip install coverage - coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage + coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage coverage xml - codecov/upload: file: ./coverage.xml @@ -1714,6 +1758,12 @@ workflows: only: - main - /litellm_.*/ + - batches_testing: + filters: + branches: + only: + - main + - /litellm_.*/ - pass_through_unit_testing: filters: branches: @@ -1735,6 +1785,7 @@ workflows: - upload-coverage: requires: - llm_translation_testing + - batches_testing - pass_through_unit_testing - image_gen_testing - logging_testing @@ -1783,6 +1834,7 @@ workflows: - load_testing - test_bad_database_url - llm_translation_testing + - batches_testing - pass_through_unit_testing - image_gen_testing - logging_testing diff --git a/litellm/batches/batch_utils.py b/litellm/batches/batch_utils.py new file mode 100644 index 0000000000..0f68193695 --- /dev/null +++ b/litellm/batches/batch_utils.py @@ -0,0 +1,302 @@ +import asyncio +import datetime +import json +import threading +from typing import Any, List, Literal, Optional + +import litellm +from litellm._logging import verbose_logger +from litellm.constants import ( + BATCH_STATUS_POLL_INTERVAL_SECONDS, + BATCH_STATUS_POLL_MAX_ATTEMPTS, +) +from litellm.files.main import afile_content +from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj +from litellm.types.llms.openai import Batch +from litellm.types.utils import StandardLoggingPayload, Usage + + +async def batches_async_logging( + batch_id: str, + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", + 
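    # This coroutine is launched as a background task from acreate_batch (see the
    # litellm/batches/main.py hunk later in this diff): it re-polls aretrieve_batch
    # until the batch reports "completed" (or the attempt budget runs out) and only
    # then computes and logs cost/usage. With the defaults added in
    # litellm/constants.py (BATCH_STATUS_POLL_INTERVAL_SECONDS = 10,
    # BATCH_STATUS_POLL_MAX_ATTEMPTS = 10) that is roughly 100 seconds of polling
    # at most before the job stops without logging anything.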
logging_obj: Optional[LiteLLMLoggingObj] = None, + **kwargs, +): + """ + Async Job waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens + + + Polls retrieve_batch until it returns a batch with status "completed" or "failed" + """ + from .main import aretrieve_batch + + verbose_logger.debug( + ".....in _batches_async_logging... polling retrieve to get batch status" + ) + if logging_obj is None: + raise ValueError( + "logging_obj is None cannot calculate cost / log batch creation event" + ) + for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS): + try: + start_time = datetime.datetime.now() + batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider) + verbose_logger.debug( + "in _batches_async_logging... batch status= %s", batch.status + ) + + if batch.status == "completed": + end_time = datetime.datetime.now() + await _handle_completed_batch( + batch=batch, + custom_llm_provider=custom_llm_provider, + logging_obj=logging_obj, + start_time=start_time, + end_time=end_time, + **kwargs, + ) + break + elif batch.status == "failed": + pass + except Exception as e: + verbose_logger.error("error in batches_async_logging", e) + await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS) + + +async def _handle_completed_batch( + batch: Batch, + custom_llm_provider: Literal["openai", "azure", "vertex_ai"], + logging_obj: LiteLLMLoggingObj, + start_time: datetime.datetime, + end_time: datetime.datetime, + **kwargs, +) -> None: + """Helper function to process a completed batch and handle logging""" + # Get batch results + file_content_dictionary = await _get_batch_output_file_content_as_dictionary( + batch, custom_llm_provider + ) + + # Calculate costs and usage + batch_cost = await _batch_cost_calculator( + custom_llm_provider=custom_llm_provider, + file_content_dictionary=file_content_dictionary, + ) + batch_usage = _get_batch_job_total_usage_from_file_content( + file_content_dictionary=file_content_dictionary, + custom_llm_provider=custom_llm_provider, + ) + + # Handle logging + await _log_completed_batch( + logging_obj=logging_obj, + batch_usage=batch_usage, + batch_cost=batch_cost, + start_time=start_time, + end_time=end_time, + **kwargs, + ) + + +async def _log_completed_batch( + logging_obj: LiteLLMLoggingObj, + batch_usage: Usage, + batch_cost: float, + start_time: datetime.datetime, + end_time: datetime.datetime, + **kwargs, +) -> None: + """Helper function to handle all logging operations for a completed batch""" + logging_obj.call_type = "batch_success" + + standard_logging_object = _create_standard_logging_object_for_completed_batch( + kwargs=kwargs, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + batch_usage_object=batch_usage, + response_cost=batch_cost, + ) + + logging_obj.model_call_details["standard_logging_object"] = standard_logging_object + + # Launch async and sync logging handlers + asyncio.create_task( + logging_obj.async_success_handler( + result=None, + start_time=start_time, + end_time=end_time, + cache_hit=None, + ) + ) + threading.Thread( + target=logging_obj.success_handler, + args=(None, start_time, end_time), + ).start() + + +async def _batch_cost_calculator( + file_content_dictionary: List[dict], + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", +) -> float: + """ + Calculate the cost of a batch based on the output file id + """ + if custom_llm_provider == "vertex_ai": + raise ValueError("Vertex AI does not support file content retrieval") + 
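    # Cost here is the sum of litellm.completion_cost(...) over every line of the
    # batch output file whose response carries status_code == 200; failed lines
    # contribute nothing. For example, the unit-test fixture added later in this
    # diff has two successful rows, so with completion_cost mocked to 0.5 the
    # total is 0.5 * 2 = 1.0, which is what test_batch_cost_calculator asserts.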
total_cost = _get_batch_job_cost_from_file_content( + file_content_dictionary=file_content_dictionary, + custom_llm_provider=custom_llm_provider, + ) + verbose_logger.debug("total_cost=%s", total_cost) + return total_cost + + +async def _get_batch_output_file_content_as_dictionary( + batch: Batch, + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", +) -> List[dict]: + """ + Get the batch output file content as a list of dictionaries + """ + if custom_llm_provider == "vertex_ai": + raise ValueError("Vertex AI does not support file content retrieval") + + if batch.output_file_id is None: + raise ValueError("Output file id is None cannot retrieve file content") + + _file_content = await afile_content( + file_id=batch.output_file_id, + custom_llm_provider=custom_llm_provider, + ) + return _get_file_content_as_dictionary(_file_content.content) + + +def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]: + """ + Get the file content as a list of dictionaries from JSON Lines format + """ + try: + _file_content_str = file_content.decode("utf-8") + # Split by newlines and parse each line as a separate JSON object + json_objects = [] + for line in _file_content_str.strip().split("\n"): + if line: # Skip empty lines + json_objects.append(json.loads(line)) + verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4)) + return json_objects + except Exception as e: + raise e + + +def _get_batch_job_cost_from_file_content( + file_content_dictionary: List[dict], + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", +) -> float: + """ + Get the cost of a batch job from the file content + """ + try: + total_cost: float = 0.0 + # parse the file content as json + verbose_logger.debug( + "file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4) + ) + for _item in file_content_dictionary: + if _batch_response_was_successful(_item): + _response_body = _get_response_from_batch_job_output_file(_item) + total_cost += litellm.completion_cost( + completion_response=_response_body, + custom_llm_provider=custom_llm_provider, + ) + verbose_logger.debug("total_cost=%s", total_cost) + return total_cost + except Exception as e: + verbose_logger.error("error in _get_batch_job_cost_from_file_content", e) + raise e + + +def _get_batch_job_total_usage_from_file_content( + file_content_dictionary: List[dict], + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", +) -> Usage: + """ + Get the tokens of a batch job from the file content + """ + total_tokens: int = 0 + prompt_tokens: int = 0 + completion_tokens: int = 0 + for _item in file_content_dictionary: + if _batch_response_was_successful(_item): + _response_body = _get_response_from_batch_job_output_file(_item) + usage: Usage = _get_batch_job_usage_from_response_body(_response_body) + total_tokens += usage.total_tokens + prompt_tokens += usage.prompt_tokens + completion_tokens += usage.completion_tokens + return Usage( + total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + ) + + +def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage: + """ + Get the tokens of a batch job from the response body + """ + _usage_dict = response_body.get("usage", None) or {} + usage: Usage = Usage(**_usage_dict) + return usage + + +def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any: + """ + Get the response from the batch job output file + """ + _response: dict = 
batch_job_output_file.get("response", None) or {} + _response_body = _response.get("body", None) or {} + return _response_body + + +def _batch_response_was_successful(batch_job_output_file: dict) -> bool: + """ + Check if the batch job response status == 200 + """ + _response: dict = batch_job_output_file.get("response", None) or {} + return _response.get("status_code", None) == 200 + + +def _create_standard_logging_object_for_completed_batch( + kwargs: dict, + start_time: datetime.datetime, + end_time: datetime.datetime, + logging_obj: LiteLLMLoggingObj, + batch_usage_object: Usage, + response_cost: float, +) -> StandardLoggingPayload: + """ + Create a standard logging object for a completed batch + """ + from litellm.litellm_core_utils.litellm_logging import ( + get_standard_logging_object_payload, + ) + + standard_logging_object = get_standard_logging_object_payload( + kwargs=kwargs, + init_response_obj=None, + start_time=start_time, + end_time=end_time, + logging_obj=logging_obj, + status="success", + ) + + if standard_logging_object is None: + raise ValueError("unable to create standard logging object for completed batch") + + # Add Completed Batch Job Usage and Response Cost + standard_logging_object["call_type"] = "batch_success" + standard_logging_object["response_cost"] = response_cost + standard_logging_object["total_tokens"] = batch_usage_object.total_tokens + standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens + standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens + return standard_logging_object diff --git a/litellm/batches/main.py b/litellm/batches/main.py index 0d9ba55587..6e708b0585 100644 --- a/litellm/batches/main.py +++ b/litellm/batches/main.py @@ -27,6 +27,8 @@ from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRe from litellm.types.router import GenericLiteLLMParams from litellm.utils import client, supports_httpx_timeout +from .batch_utils import batches_async_logging + ####### ENVIRONMENT VARIABLES ################### openai_batches_instance = OpenAIBatchesAPI() azure_batches_instance = AzureBatchesAPI() @@ -71,10 +73,22 @@ async def acreate_batch( ctx = contextvars.copy_context() func_with_context = partial(ctx.run, func) init_response = await loop.run_in_executor(None, func_with_context) + if asyncio.iscoroutine(init_response): response = await init_response else: - response = init_response # type: ignore + response = init_response + + # Start async logging job + if response is not None: + asyncio.create_task( + batches_async_logging( + logging_obj=kwargs.get("litellm_logging_obj", None), + batch_id=response.id, + custom_llm_provider=custom_llm_provider, + **kwargs, + ) + ) return response except Exception as e: @@ -238,7 +252,7 @@ def create_batch( async def aretrieve_batch( batch_id: str, - custom_llm_provider: Literal["openai", "azure"] = "openai", + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", metadata: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, @@ -279,7 +293,7 @@ async def aretrieve_batch( def retrieve_batch( batch_id: str, - custom_llm_provider: Literal["openai", "azure"] = "openai", + custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai", metadata: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None, @@ -552,7 +566,6 @@ def list_batches( return response except Exception as e: 
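            # Note on the acreate_batch change above: batches_async_logging is
            # scheduled with asyncio.create_task, i.e. fire-and-forget, so
            # acreate_batch still returns as soon as the batch is created and the
            # cost/usage logging happens later, once the background poll sees the
            # batch reach "completed". The logging object is taken from
            # kwargs["litellm_logging_obj"]; if it is missing, batches_async_logging
            # raises a ValueError instead of logging.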
raise e - pass def cancel_batch(): diff --git a/litellm/constants.py b/litellm/constants.py index 0cff9ab5ab..9fddc38e53 100644 --- a/litellm/constants.py +++ b/litellm/constants.py @@ -92,3 +92,6 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [ "generateQuery/", "optimize-prompt/", ] + +BATCH_STATUS_POLL_INTERVAL_SECONDS = 10 +BATCH_STATUS_POLL_MAX_ATTEMPTS = 10 diff --git a/litellm/llms/vertex_ai/files/handler.py b/litellm/llms/vertex_ai/files/handler.py index dca557a494..4bae106045 100644 --- a/litellm/llms/vertex_ai/files/handler.py +++ b/litellm/llms/vertex_ai/files/handler.py @@ -2,10 +2,12 @@ from typing import Any, Coroutine, Optional, Union import httpx +from litellm import LlmProviders from litellm.integrations.gcs_bucket.gcs_bucket_base import ( GCSBucketBase, GCSLoggingConfig, ) +from litellm.llms.custom_httpx.http_handler import get_async_httpx_client from litellm.types.llms.openai import CreateFileRequest, FileObject from .transformation import VertexAIFilesTransformation @@ -20,6 +22,12 @@ class VertexAIFilesHandler(GCSBucketBase): This implementation uploads files on GCS Buckets """ + def __init__(self): + super().__init__() + self.async_httpx_client = get_async_httpx_client( + llm_provider=LlmProviders.VERTEX_AI, + ) + pass async def async_create_file( diff --git a/litellm/main.py b/litellm/main.py index fb60116877..77f4af9f2b 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -826,6 +826,11 @@ def completion( # type: ignore # noqa: PLR0915 - It supports various optional parameters for customizing the completion behavior. - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. """ + ### VALIDATE Request ### + if model is None: + raise ValueError("model param not passed in.") + # validate messages + messages = validate_chat_completion_messages(messages=messages) ######### unpacking kwargs ##################### args = locals() api_base = kwargs.get("api_base", None) @@ -997,9 +1002,6 @@ def completion( # type: ignore # noqa: PLR0915 "aws_region_name", None ) # support region-based pricing for bedrock - ### VALIDATE USER MESSAGES ### - messages = validate_chat_completion_messages(messages=messages) - ### TIMEOUT LOGIC ### timeout = timeout or kwargs.get("request_timeout", 600) or 600 # set timeout for 10 minutes by default diff --git a/tests/batches_tests/adroit-crow-413218-bc47f303efc9.json b/tests/batches_tests/adroit-crow-413218-bc47f303efc9.json new file mode 100644 index 0000000000..e2fd8512b1 --- /dev/null +++ b/tests/batches_tests/adroit-crow-413218-bc47f303efc9.json @@ -0,0 +1,13 @@ +{ + "type": "service_account", + "project_id": "adroit-crow-413218", + "private_key_id": "", + "private_key": "", + "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com", + "client_id": "104886546564708740969", + "auth_uri": "https://accounts.google.com/o/oauth2/auth", + "token_uri": "https://oauth2.googleapis.com/token", + "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com", + "universe_domain": "googleapis.com" +} diff --git a/tests/batches_tests/batch_job_results_furniture.jsonl b/tests/batches_tests/batch_job_results_furniture.jsonl new file mode 100644 index 0000000000..f026d1438a --- /dev/null +++ b/tests/batches_tests/batch_job_results_furniture.jsonl @@ -0,0 +1,3 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": 
{"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} + diff --git a/tests/batches_tests/openai_batch_completions.jsonl b/tests/batches_tests/openai_batch_completions.jsonl new file mode 100644 index 0000000000..f026d1438a --- /dev/null +++ b/tests/batches_tests/openai_batch_completions.jsonl @@ -0,0 +1,3 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} + diff --git a/tests/batches_tests/test_batches_logging_unit_tests.py b/tests/batches_tests/test_batches_logging_unit_tests.py new file mode 100644 index 0000000000..1ec4087794 --- /dev/null +++ b/tests/batches_tests/test_batches_logging_unit_tests.py @@ -0,0 +1,171 @@ +import asyncio +import json +import os +import sys +import traceback +from unittest.mock import AsyncMock, MagicMock, patch +from dotenv import load_dotenv + +load_dotenv() +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system-path +import logging +import time + +import pytest +from typing import Optional +import litellm +from litellm import create_batch, create_file +from litellm._logging import verbose_logger +from litellm.batches.batch_utils import ( + _batch_cost_calculator, + _get_file_content_as_dictionary, + _get_batch_job_cost_from_file_content, + _get_batch_job_total_usage_from_file_content, + _get_batch_job_usage_from_response_body, + _get_response_from_batch_job_output_file, + _batch_response_was_successful, +) + + +@pytest.fixture +def sample_file_content(): + return b""" +{"id": "batch_req_6769ca596b38819093d7ae9f522de924", "custom_id": "request-1", "response": {"status_code": 200, "request_id": "07bc45ab4e7e26ac23a0c949973327e7", "body": {"id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7", "object": "chat.completion", "created": 1734986202, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! How can I assist you today?", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null} +{"id": "batch_req_6769ca597e588190920666612634e2b4", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "82e04f4c001fe2c127cbad199f5fd31b", "body": {"id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP", "object": "chat.completion", "created": 1734986203, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! 
What can I do for you today?", "refusal": null}, "logprobs": null, "finish_reason": "length"}], "usage": {"prompt_tokens": 22, "completion_tokens": 10, "total_tokens": 32, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null} +""" + + +@pytest.fixture +def sample_file_content_dict(): + return [ + { + "id": "batch_req_6769ca596b38819093d7ae9f522de924", + "custom_id": "request-1", + "response": { + "status_code": 200, + "request_id": "07bc45ab4e7e26ac23a0c949973327e7", + "body": { + "id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7", + "object": "chat.completion", + "created": 1734986202, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?", + "refusal": None, + }, + "logprobs": None, + "finish_reason": "stop", + } + ], + "usage": { + "prompt_tokens": 20, + "completion_tokens": 10, + "total_tokens": 30, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + "system_fingerprint": "fp_0aa8d3e20b", + }, + }, + "error": None, + }, + { + "id": "batch_req_6769ca597e588190920666612634e2b4", + "custom_id": "request-2", + "response": { + "status_code": 200, + "request_id": "82e04f4c001fe2c127cbad199f5fd31b", + "body": { + "id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP", + "object": "chat.completion", + "created": 1734986203, + "model": "gpt-4o-mini-2024-07-18", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! 
What can I do for you today?", + "refusal": None, + }, + "logprobs": None, + "finish_reason": "length", + } + ], + "usage": { + "prompt_tokens": 22, + "completion_tokens": 10, + "total_tokens": 32, + "prompt_tokens_details": { + "cached_tokens": 0, + "audio_tokens": 0, + }, + "completion_tokens_details": { + "reasoning_tokens": 0, + "audio_tokens": 0, + "accepted_prediction_tokens": 0, + "rejected_prediction_tokens": 0, + }, + }, + "system_fingerprint": "fp_0aa8d3e20b", + }, + }, + "error": None, + }, + ] + + +def test_get_file_content_as_dictionary(sample_file_content): + result = _get_file_content_as_dictionary(sample_file_content) + assert len(result) == 2 + assert result[0]["id"] == "batch_req_6769ca596b38819093d7ae9f522de924" + assert result[0]["custom_id"] == "request-1" + assert result[0]["response"]["status_code"] == 200 + assert result[0]["response"]["body"]["usage"]["total_tokens"] == 30 + + +def test_get_batch_job_total_usage_from_file_content(sample_file_content_dict): + usage = _get_batch_job_total_usage_from_file_content( + sample_file_content_dict, custom_llm_provider="openai" + ) + assert usage.total_tokens == 62 # 30 + 32 + assert usage.prompt_tokens == 42 # 20 + 22 + assert usage.completion_tokens == 20 # 10 + 10 + + +@pytest.mark.asyncio +async def test_batch_cost_calculator(sample_file_content_dict): + """ + mock litellm.completion_cost to return 0.5 + + we know sample_file_content_dict has 2 successful responses + + so we expect the cost to be 0.5 * 2 = 1.0 + """ + with patch("litellm.completion_cost", return_value=0.5): + cost = await _batch_cost_calculator( + file_content_dictionary=sample_file_content_dict, + custom_llm_provider="openai", + ) + assert cost == 1.0 # 0.5 * 2 successful responses + + +def test_get_response_from_batch_job_output_file(sample_file_content_dict): + result = _get_response_from_batch_job_output_file(sample_file_content_dict[0]) + assert result["id"] == "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7" + assert result["object"] == "chat.completion" + assert result["usage"]["total_tokens"] == 30 diff --git a/tests/local_testing/test_openai_batches_and_files.py b/tests/batches_tests/test_openai_batches_and_files.py similarity index 77% rename from tests/local_testing/test_openai_batches_and_files.py rename to tests/batches_tests/test_openai_batches_and_files.py index 9c8ab79269..f97e38de51 100644 --- a/tests/local_testing/test_openai_batches_and_files.py +++ b/tests/batches_tests/test_openai_batches_and_files.py @@ -5,7 +5,7 @@ import json import os import sys import traceback - +import tempfile from dotenv import load_dotenv load_dotenv() @@ -20,7 +20,6 @@ from typing import Optional import litellm from litellm import create_batch, create_file from litellm._logging import verbose_logger -from test_gcs_bucket import load_vertex_ai_credentials verbose_logger.setLevel(logging.DEBUG) @@ -28,19 +27,47 @@ from litellm.integrations.custom_logger import CustomLogger from litellm.types.utils import StandardLoggingPayload -class TestCustomLogger(CustomLogger): - def __init__(self): - super().__init__() - self.standard_logging_object: Optional[StandardLoggingPayload] = None +def load_vertex_ai_credentials(): + # Define the path to the vertex_key.json file + print("loading vertex ai credentials") + os.environ["GCS_FLUSH_INTERVAL"] = "1" + filepath = os.path.dirname(os.path.abspath(__file__)) + vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json" - async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): - print( - 
"Success event logged with kwargs=", - kwargs, - "and response_obj=", - response_obj, - ) - self.standard_logging_object = kwargs["standard_logging_object"] + # Read the existing content of the file or create an empty dictionary + try: + with open(vertex_key_path, "r") as file: + # Read the file content + print("Read vertexai file path") + content = file.read() + + # If the file is empty or not valid JSON, create an empty dictionary + if not content or not content.strip(): + service_account_key_data = {} + else: + # Attempt to load the existing JSON content + file.seek(0) + service_account_key_data = json.load(file) + except FileNotFoundError: + # If the file doesn't exist, create an empty dictionary + service_account_key_data = {} + + # Update the service_account_key_data with environment variables + private_key_id = os.environ.get("GCS_PRIVATE_KEY_ID", "") + private_key = os.environ.get("GCS_PRIVATE_KEY", "") + private_key = private_key.replace("\\n", "\n") + service_account_key_data["private_key_id"] = private_key_id + service_account_key_data["private_key"] = private_key + + # Create a temporary file + with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file: + # Write the updated content to the temporary files + json.dump(service_account_key_data, temp_file, indent=2) + + # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS + os.environ["GCS_PATH_SERVICE_ACCOUNT"] = os.path.abspath(temp_file.name) + os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) + print("created gcs path service account=", os.environ["GCS_PATH_SERVICE_ACCOUNT"]) @pytest.mark.parametrize("provider", ["openai"]) # , "azure" @@ -128,6 +155,21 @@ async def test_create_batch(provider): pass +class TestCustomLogger(CustomLogger): + def __init__(self): + super().__init__() + self.standard_logging_object: Optional[StandardLoggingPayload] = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + print( + "Success event logged with kwargs=", + kwargs, + "and response_obj=", + response_obj, + ) + self.standard_logging_object = kwargs["standard_logging_object"] + + @pytest.mark.parametrize("provider", ["openai"]) # "azure" @pytest.mark.asyncio() @pytest.mark.flaky(retries=3, delay=1) @@ -142,6 +184,9 @@ async def test_async_create_batch(provider): # Don't have anymore Azure Quota return + custom_logger = TestCustomLogger() + litellm.callbacks = [custom_logger, "datadog"] + file_name = "openai_batch_completions.jsonl" _current_dir = os.path.dirname(os.path.abspath(__file__)) file_path = os.path.join(_current_dir, file_name) @@ -179,7 +224,9 @@ async def test_async_create_batch(provider): create_batch_response.input_file_id == batch_input_file_id ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}" - await asyncio.sleep(1) + await asyncio.sleep(6) + # Assert that the create batch event is logged on CustomLogger + assert custom_logger.standard_logging_object is not None retrieved_batch = await litellm.aretrieve_batch( batch_id=create_batch_response.id, custom_llm_provider=provider @@ -223,9 +270,10 @@ async def test_async_create_batch(provider): print("all_files_list = ", all_files_list) - # # write this file content to a file - # with open("file_content.json", "w") as f: - # json.dump(file_content, f) + result_file_name = "batch_job_results_furniture.jsonl" + + with open(result_file_name, "wb") as file: + file.write(file_content.content) def test_retrieve_batch(): 
@@ -241,8 +289,9 @@ def test_list_batch(): @pytest.mark.asyncio -async def test_vertex_batch_prediction(): +async def test_avertex_batch_prediction(): load_vertex_ai_credentials() + litellm.set_verbose = True file_name = "vertex_batch_completions.jsonl" _current_dir = os.path.dirname(os.path.abspath(__file__)) file_path = os.path.join(_current_dir, file_name) diff --git a/tests/batches_tests/vertex_batch_completions.jsonl b/tests/batches_tests/vertex_batch_completions.jsonl new file mode 100644 index 0000000000..ec899f8fc4 --- /dev/null +++ b/tests/batches_tests/vertex_batch_completions.jsonl @@ -0,0 +1,2 @@ +{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} +{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}} \ No newline at end of file diff --git a/tests/local_testing/test_bad_params.py b/tests/local_testing/test_bad_params.py index c18d462432..ef3b4596ec 100644 --- a/tests/local_testing/test_bad_params.py +++ b/tests/local_testing/test_bad_params.py @@ -23,7 +23,7 @@ model_val = None def test_completion_with_no_model(): # test on empty - with pytest.raises(ValueError): + with pytest.raises(TypeError): response = completion(messages=messages) @@ -36,39 +36,6 @@ def test_completion_with_empty_model(): pass -# def test_completion_catch_nlp_exception(): -# TEMP commented out NLP cloud API is unstable -# try: -# response = completion(model="dolphin", messages=messages, functions=[ -# { -# "name": "get_current_weather", -# "description": "Get the current weather in a given location", -# "parameters": { -# "type": "object", -# "properties": { -# "location": { -# "type": "string", -# "description": "The city and state, e.g. San Francisco, CA" -# }, -# "unit": { -# "type": "string", -# "enum": ["celsius", "fahrenheit"] -# } -# }, -# "required": ["location"] -# } -# } -# ]) - -# except Exception as e: -# if "Function calling is not supported by nlp_cloud" in str(e): -# pass -# else: -# pytest.fail(f'An error occurred {e}') - -# test_completion_catch_nlp_exception() - - def test_completion_invalid_param_cohere(): try: litellm.set_verbose = True @@ -94,9 +61,6 @@ def test_completion_function_call_cohere(): pass -# test_completion_function_call_cohere() - - def test_completion_function_call_openai(): try: messages = [{"role": "user", "content": "What is the weather like in Boston?"}] @@ -140,17 +104,3 @@ def test_completion_with_no_provider(): except Exception as e: print(f"error occurred: {e}") pass - - -# test_completion_with_no_provider() -# # bad key -# temp_key = os.environ.get("OPENAI_API_KEY") -# os.environ["OPENAI_API_KEY"] = "bad-key" -# # test on openai completion call -# try: -# response = completion(model="gpt-3.5-turbo", messages=messages) -# print(f"response: {response}") -# except Exception: -# print(f"error occurred: {traceback.format_exc()}") -# pass -# os.environ["OPENAI_API_KEY"] = str(temp_key) # this passes linting#5
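For reference, a minimal sketch of how the pieces introduced in this diff compose for an already-completed OpenAI batch. It is illustrative only: "batch_abc123" is a placeholder ID, and the underscore-prefixed helpers are internal functions from litellm/batches/batch_utils.py.

import asyncio

import litellm
from litellm.batches.batch_utils import (
    _get_batch_job_total_usage_from_file_content,
    _get_file_content_as_dictionary,
)
from litellm.files.main import afile_content


async def summarize_batch(batch_id: str = "batch_abc123") -> None:
    # Fetch the batch and bail out unless it has finished and produced output.
    batch = await litellm.aretrieve_batch(
        batch_id=batch_id, custom_llm_provider="openai"
    )
    if batch.status != "completed" or batch.output_file_id is None:
        return

    # Download the JSONL output file and parse it into a list of dicts.
    output = await afile_content(
        file_id=batch.output_file_id, custom_llm_provider="openai"
    )
    rows = _get_file_content_as_dictionary(output.content)

    # Aggregate token usage across all successful rows.
    usage = _get_batch_job_total_usage_from_file_content(
        rows, custom_llm_provider="openai"
    )
    print(usage.prompt_tokens, usage.completion_tokens, usage.total_tokens)


# asyncio.run(summarize_batch())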