Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 11:14:04 +00:00)
(feat) Add cost tracking for /batches requests OpenAI (#7384)
* add basic logging for create `batch`
* add create_batch as a call type
* add basic dd logging for batches
* basic batch creation logging on DD
* batch endpoints add cost calc
* fix batches_async_logging
* separate folder for batches testing
* new job for batches tests
* test batches logging
* fix validation logic
* add vertex_batch_completions.jsonl
* test test_async_create_batch
* test_async_create_batch
* update tests
* test_completion_with_no_model
* remove dead code
* update load_vertex_ai_credentials
* test_avertex_batch_prediction
* update get async httpx client
* fix get_async_httpx_client
* update test_avertex_batch_prediction
* fix batches testing config.yaml
* add google deps
* fix vertex files handler
parent 87f19d6f13
commit 05b0d2026f
13 changed files with 649 additions and 78 deletions
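As a rough usage sketch of what this change does (illustrative only, not part of the diff; the call shapes follow litellm's OpenAI-compatible files/batches API): once a batch is created, acreate_batch now schedules a background batches_async_logging task that polls the batch and, when it completes, prices the output file and reports cost and token usage to the configured logging callbacks.

# Illustrative sketch - assumes an OpenAI key is configured and the JSONL file exists.
import asyncio
import litellm

async def main():
    # Upload the JSONL of chat-completion requests (purpose="batch").
    file_obj = await litellm.acreate_file(
        file=open("openai_batch_completions.jsonl", "rb"),
        purpose="batch",
        custom_llm_provider="openai",
    )

    # Creating the batch now also kicks off batches_async_logging(), which polls
    # aretrieve_batch() and logs cost + usage once the batch status is "completed".
    batch = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=file_obj.id,
        custom_llm_provider="openai",
    )
    print("created batch:", batch.id)

asyncio.run(main())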
@@ -626,6 +626,50 @@ jobs:
       paths:
         - llm_translation_coverage.xml
         - llm_translation_coverage
+  batches_testing:
+    docker:
+      - image: cimg/python:3.11
+        auth:
+          username: ${DOCKERHUB_USERNAME}
+          password: ${DOCKERHUB_PASSWORD}
+    working_directory: ~/project
+
+    steps:
+      - checkout
+      - run:
+          name: Install Dependencies
+          command: |
+            python -m pip install --upgrade pip
+            python -m pip install -r requirements.txt
+            pip install "respx==0.21.1"
+            pip install "pytest==7.3.1"
+            pip install "pytest-retry==1.6.3"
+            pip install "pytest-asyncio==0.21.1"
+            pip install "pytest-cov==5.0.0"
+            pip install "google-generativeai==0.3.2"
+            pip install "google-cloud-aiplatform==1.43.0"
+      # Run pytest and generate JUnit XML report
+      - run:
+          name: Run tests
+          command: |
+            pwd
+            ls
+            python -m pytest -vv tests/batches_tests --cov=litellm --cov-report=xml -x -s -v --junitxml=test-results/junit.xml --durations=5
+          no_output_timeout: 120m
+      - run:
+          name: Rename the coverage files
+          command: |
+            mv coverage.xml batches_coverage.xml
+            mv .coverage batches_coverage
+
+      # Store test results
+      - store_test_results:
+          path: test-results
+      - persist_to_workspace:
+          root: .
+          paths:
+            - batches_coverage.xml
+            - batches_coverage
   pass_through_unit_testing:
     docker:
       - image: cimg/python:3.11

@@ -1417,7 +1461,7 @@ jobs:
           python -m venv venv
           . venv/bin/activate
           pip install coverage
-          coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage
+          coverage combine llm_translation_coverage logging_coverage litellm_router_coverage local_testing_coverage litellm_assistants_api_coverage auth_ui_unit_tests_coverage langfuse_coverage caching_coverage litellm_proxy_unit_tests_coverage image_gen_coverage pass_through_unit_tests_coverage batches_coverage
          coverage xml
      - codecov/upload:
          file: ./coverage.xml

@@ -1714,6 +1758,12 @@ workflows:
            only:
              - main
              - /litellm_.*/
+      - batches_testing:
+          filters:
+            branches:
+              only:
+                - main
+                - /litellm_.*/
      - pass_through_unit_testing:
          filters:
            branches:

@@ -1735,6 +1785,7 @@ workflows:
      - upload-coverage:
          requires:
            - llm_translation_testing
+           - batches_testing
            - pass_through_unit_testing
            - image_gen_testing
            - logging_testing

@@ -1783,6 +1834,7 @@ workflows:
          - load_testing
          - test_bad_database_url
          - llm_translation_testing
+          - batches_testing
          - pass_through_unit_testing
          - image_gen_testing
          - logging_testing
|
302
litellm/batches/batch_utils.py
Normal file
302
litellm/batches/batch_utils.py
Normal file
|
@@ -0,0 +1,302 @@
import asyncio
import datetime
import json
import threading
from typing import Any, List, Literal, Optional

import litellm
from litellm._logging import verbose_logger
from litellm.constants import (
    BATCH_STATUS_POLL_INTERVAL_SECONDS,
    BATCH_STATUS_POLL_MAX_ATTEMPTS,
)
from litellm.files.main import afile_content
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.types.llms.openai import Batch
from litellm.types.utils import StandardLoggingPayload, Usage


async def batches_async_logging(
    batch_id: str,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
    logging_obj: Optional[LiteLLMLoggingObj] = None,
    **kwargs,
):
    """
    Async job: waits for the batch to complete and then logs the completed batch usage - cost, total tokens, prompt tokens, completion tokens.

    Polls retrieve_batch until it returns a batch with status "completed" or "failed".
    """
    from .main import aretrieve_batch

    verbose_logger.debug(
        ".....in _batches_async_logging... polling retrieve to get batch status"
    )
    if logging_obj is None:
        raise ValueError(
            "logging_obj is None cannot calculate cost / log batch creation event"
        )
    for _ in range(BATCH_STATUS_POLL_MAX_ATTEMPTS):
        try:
            start_time = datetime.datetime.now()
            batch: Batch = await aretrieve_batch(batch_id, custom_llm_provider)
            verbose_logger.debug(
                "in _batches_async_logging... batch status= %s", batch.status
            )

            if batch.status == "completed":
                end_time = datetime.datetime.now()
                await _handle_completed_batch(
                    batch=batch,
                    custom_llm_provider=custom_llm_provider,
                    logging_obj=logging_obj,
                    start_time=start_time,
                    end_time=end_time,
                    **kwargs,
                )
                break
            elif batch.status == "failed":
                pass
        except Exception as e:
            verbose_logger.error("error in batches_async_logging", e)
        await asyncio.sleep(BATCH_STATUS_POLL_INTERVAL_SECONDS)


async def _handle_completed_batch(
    batch: Batch,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"],
    logging_obj: LiteLLMLoggingObj,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
    **kwargs,
) -> None:
    """Helper function to process a completed batch and handle logging"""
    # Get batch results
    file_content_dictionary = await _get_batch_output_file_content_as_dictionary(
        batch, custom_llm_provider
    )

    # Calculate costs and usage
    batch_cost = await _batch_cost_calculator(
        custom_llm_provider=custom_llm_provider,
        file_content_dictionary=file_content_dictionary,
    )
    batch_usage = _get_batch_job_total_usage_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
    )

    # Handle logging
    await _log_completed_batch(
        logging_obj=logging_obj,
        batch_usage=batch_usage,
        batch_cost=batch_cost,
        start_time=start_time,
        end_time=end_time,
        **kwargs,
    )


async def _log_completed_batch(
    logging_obj: LiteLLMLoggingObj,
    batch_usage: Usage,
    batch_cost: float,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
    **kwargs,
) -> None:
    """Helper function to handle all logging operations for a completed batch"""
    logging_obj.call_type = "batch_success"

    standard_logging_object = _create_standard_logging_object_for_completed_batch(
        kwargs=kwargs,
        start_time=start_time,
        end_time=end_time,
        logging_obj=logging_obj,
        batch_usage_object=batch_usage,
        response_cost=batch_cost,
    )

    logging_obj.model_call_details["standard_logging_object"] = standard_logging_object

    # Launch async and sync logging handlers
    asyncio.create_task(
        logging_obj.async_success_handler(
            result=None,
            start_time=start_time,
            end_time=end_time,
            cache_hit=None,
        )
    )
    threading.Thread(
        target=logging_obj.success_handler,
        args=(None, start_time, end_time),
    ).start()


async def _batch_cost_calculator(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> float:
    """
    Calculate the cost of a batch based on the output file id
    """
    if custom_llm_provider == "vertex_ai":
        raise ValueError("Vertex AI does not support file content retrieval")
    total_cost = _get_batch_job_cost_from_file_content(
        file_content_dictionary=file_content_dictionary,
        custom_llm_provider=custom_llm_provider,
    )
    verbose_logger.debug("total_cost=%s", total_cost)
    return total_cost


async def _get_batch_output_file_content_as_dictionary(
    batch: Batch,
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> List[dict]:
    """
    Get the batch output file content as a list of dictionaries
    """
    if custom_llm_provider == "vertex_ai":
        raise ValueError("Vertex AI does not support file content retrieval")

    if batch.output_file_id is None:
        raise ValueError("Output file id is None cannot retrieve file content")

    _file_content = await afile_content(
        file_id=batch.output_file_id,
        custom_llm_provider=custom_llm_provider,
    )
    return _get_file_content_as_dictionary(_file_content.content)


def _get_file_content_as_dictionary(file_content: bytes) -> List[dict]:
    """
    Get the file content as a list of dictionaries from JSON Lines format
    """
    try:
        _file_content_str = file_content.decode("utf-8")
        # Split by newlines and parse each line as a separate JSON object
        json_objects = []
        for line in _file_content_str.strip().split("\n"):
            if line:  # Skip empty lines
                json_objects.append(json.loads(line))
        verbose_logger.debug("json_objects=%s", json.dumps(json_objects, indent=4))
        return json_objects
    except Exception as e:
        raise e


def _get_batch_job_cost_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> float:
    """
    Get the cost of a batch job from the file content
    """
    try:
        total_cost: float = 0.0
        # parse the file content as json
        verbose_logger.debug(
            "file_content_dictionary=%s", json.dumps(file_content_dictionary, indent=4)
        )
        for _item in file_content_dictionary:
            if _batch_response_was_successful(_item):
                _response_body = _get_response_from_batch_job_output_file(_item)
                total_cost += litellm.completion_cost(
                    completion_response=_response_body,
                    custom_llm_provider=custom_llm_provider,
                )
                verbose_logger.debug("total_cost=%s", total_cost)
        return total_cost
    except Exception as e:
        verbose_logger.error("error in _get_batch_job_cost_from_file_content", e)
        raise e


def _get_batch_job_total_usage_from_file_content(
    file_content_dictionary: List[dict],
    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
) -> Usage:
    """
    Get the tokens of a batch job from the file content
    """
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    for _item in file_content_dictionary:
        if _batch_response_was_successful(_item):
            _response_body = _get_response_from_batch_job_output_file(_item)
            usage: Usage = _get_batch_job_usage_from_response_body(_response_body)
            total_tokens += usage.total_tokens
            prompt_tokens += usage.prompt_tokens
            completion_tokens += usage.completion_tokens
    return Usage(
        total_tokens=total_tokens,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
    )


def _get_batch_job_usage_from_response_body(response_body: dict) -> Usage:
    """
    Get the tokens of a batch job from the response body
    """
    _usage_dict = response_body.get("usage", None) or {}
    usage: Usage = Usage(**_usage_dict)
    return usage


def _get_response_from_batch_job_output_file(batch_job_output_file: dict) -> Any:
    """
    Get the response from the batch job output file
    """
    _response: dict = batch_job_output_file.get("response", None) or {}
    _response_body = _response.get("body", None) or {}
    return _response_body


def _batch_response_was_successful(batch_job_output_file: dict) -> bool:
    """
    Check if the batch job response status == 200
    """
    _response: dict = batch_job_output_file.get("response", None) or {}
    return _response.get("status_code", None) == 200


def _create_standard_logging_object_for_completed_batch(
    kwargs: dict,
    start_time: datetime.datetime,
    end_time: datetime.datetime,
    logging_obj: LiteLLMLoggingObj,
    batch_usage_object: Usage,
    response_cost: float,
) -> StandardLoggingPayload:
    """
    Create a standard logging object for a completed batch
    """
    from litellm.litellm_core_utils.litellm_logging import (
        get_standard_logging_object_payload,
    )

    standard_logging_object = get_standard_logging_object_payload(
        kwargs=kwargs,
        init_response_obj=None,
        start_time=start_time,
        end_time=end_time,
        logging_obj=logging_obj,
        status="success",
    )

    if standard_logging_object is None:
        raise ValueError("unable to create standard logging object for completed batch")

    # Add Completed Batch Job Usage and Response Cost
    standard_logging_object["call_type"] = "batch_success"
    standard_logging_object["response_cost"] = response_cost
    standard_logging_object["total_tokens"] = batch_usage_object.total_tokens
    standard_logging_object["prompt_tokens"] = batch_usage_object.prompt_tokens
    standard_logging_object["completion_tokens"] = batch_usage_object.completion_tokens

    return standard_logging_object
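For orientation, a small sketch of the record shape these helpers consume (mirroring the sample fixture added in tests/batches_tests/test_batches_logging_unit_tests.py; the ids and numbers below are made up): each line of the batch output file is a JSON object whose response.body is a normal chat-completion response, and only lines with status_code == 200 contribute to the cost and usage totals.

# Illustrative record - not taken from a real batch run.
from litellm.batches.batch_utils import (
    _batch_response_was_successful,
    _get_batch_job_total_usage_from_file_content,
)

record = {
    "id": "batch_req_example",
    "custom_id": "request-1",
    "response": {
        "status_code": 200,  # only 200s are priced / counted
        "body": {
            "model": "gpt-4o-mini-2024-07-18",
            "object": "chat.completion",
            "choices": [],
            "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30},
        },
    },
    "error": None,
}

assert _batch_response_was_successful(record)
usage = _get_batch_job_total_usage_from_file_content(
    [record], custom_llm_provider="openai"
)
print(usage.total_tokens)  # 30 - summed across every successful line of a real output file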
@@ -27,6 +27,8 @@ from litellm.types.llms.openai import Batch, CreateBatchRequest, RetrieveBatchRe
 from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import client, supports_httpx_timeout
+
+from .batch_utils import batches_async_logging

 ####### ENVIRONMENT VARIABLES ###################
 openai_batches_instance = OpenAIBatchesAPI()
 azure_batches_instance = AzureBatchesAPI()

@@ -71,10 +73,22 @@ async def acreate_batch(
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
         init_response = await loop.run_in_executor(None, func_with_context)

         if asyncio.iscoroutine(init_response):
             response = await init_response
         else:
-            response = init_response  # type: ignore
+            response = init_response

+        # Start async logging job
+        if response is not None:
+            asyncio.create_task(
+                batches_async_logging(
+                    logging_obj=kwargs.get("litellm_logging_obj", None),
+                    batch_id=response.id,
+                    custom_llm_provider=custom_llm_provider,
+                    **kwargs,
+                )
+            )
+
         return response
     except Exception as e:

@@ -238,7 +252,7 @@ def create_batch(

 async def aretrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,

@@ -279,7 +293,7 @@ async def aretrieve_batch(

 def retrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,

@@ -552,7 +566,6 @@ def list_batches(
         return response
     except Exception as e:
         raise e
-        pass


 def cancel_batch():
@@ -92,3 +92,6 @@ BEDROCK_AGENT_RUNTIME_PASS_THROUGH_ROUTES = [
     "generateQuery/",
     "optimize-prompt/",
 ]
+
+BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
+BATCH_STATUS_POLL_MAX_ATTEMPTS = 10
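A quick note on these defaults: batches_async_logging sleeps BATCH_STATUS_POLL_INTERVAL_SECONDS between attempts and gives up after BATCH_STATUS_POLL_MAX_ATTEMPTS, so the worst-case wait before a still-running batch stops being tracked is roughly:

BATCH_STATUS_POLL_INTERVAL_SECONDS = 10
BATCH_STATUS_POLL_MAX_ATTEMPTS = 10

# ~100 seconds; a batch that is still running after this window is simply not cost-logged.
max_wait = BATCH_STATUS_POLL_INTERVAL_SECONDS * BATCH_STATUS_POLL_MAX_ATTEMPTS
print(max_wait)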
@@ -2,10 +2,12 @@ from typing import Any, Coroutine, Optional, Union

 import httpx

+from litellm import LlmProviders
 from litellm.integrations.gcs_bucket.gcs_bucket_base import (
     GCSBucketBase,
     GCSLoggingConfig,
 )
+from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.types.llms.openai import CreateFileRequest, FileObject

 from .transformation import VertexAIFilesTransformation

@@ -20,6 +22,12 @@ class VertexAIFilesHandler(GCSBucketBase):
     This implementation uploads files on GCS Buckets
     """

+    def __init__(self):
+        super().__init__()
+        self.async_httpx_client = get_async_httpx_client(
+            llm_provider=LlmProviders.VERTEX_AI,
+        )
+
     pass

     async def async_create_file(
@@ -826,6 +826,11 @@ def completion(  # type: ignore # noqa: PLR0915
     - It supports various optional parameters for customizing the completion behavior.
     - If 'mock_response' is provided, a mock completion response is returned for testing or debugging.
     """
+    ### VALIDATE Request ###
+    if model is None:
+        raise ValueError("model param not passed in.")
+    # validate messages
+    messages = validate_chat_completion_messages(messages=messages)
     ######### unpacking kwargs #####################
     args = locals()
     api_base = kwargs.get("api_base", None)

@@ -997,9 +1002,6 @@ def completion(  # type: ignore # noqa: PLR0915
             "aws_region_name", None
         )  # support region-based pricing for bedrock

-    ### VALIDATE USER MESSAGES ###
-    messages = validate_chat_completion_messages(messages=messages)
-
     ### TIMEOUT LOGIC ###
     timeout = timeout or kwargs.get("request_timeout", 600) or 600
     # set timeout for 10 minutes by default
tests/batches_tests/adroit-crow-413218-bc47f303efc9.json (new file, 13 lines)
@@ -0,0 +1,13 @@
{
  "type": "service_account",
  "project_id": "adroit-crow-413218",
  "private_key_id": "",
  "private_key": "",
  "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
  "client_id": "104886546564708740969",
  "auth_uri": "https://accounts.google.com/o/oauth2/auth",
  "token_uri": "https://oauth2.googleapis.com/token",
  "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
  "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
  "universe_domain": "googleapis.com"
}
tests/batches_tests/batch_job_results_furniture.jsonl (new file, 3 lines)

{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}

tests/batches_tests/openai_batch_completions.jsonl (new file, 3 lines)

{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
tests/batches_tests/test_batches_logging_unit_tests.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import asyncio
import json
import os
import sys
import traceback
from unittest.mock import AsyncMock, MagicMock, patch
from dotenv import load_dotenv

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path
import logging
import time

import pytest
from typing import Optional
import litellm
from litellm import create_batch, create_file
from litellm._logging import verbose_logger
from litellm.batches.batch_utils import (
    _batch_cost_calculator,
    _get_file_content_as_dictionary,
    _get_batch_job_cost_from_file_content,
    _get_batch_job_total_usage_from_file_content,
    _get_batch_job_usage_from_response_body,
    _get_response_from_batch_job_output_file,
    _batch_response_was_successful,
)


@pytest.fixture
def sample_file_content():
    return b"""
{"id": "batch_req_6769ca596b38819093d7ae9f522de924", "custom_id": "request-1", "response": {"status_code": 200, "request_id": "07bc45ab4e7e26ac23a0c949973327e7", "body": {"id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7", "object": "chat.completion", "created": 1734986202, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! How can I assist you today?", "refusal": null}, "logprobs": null, "finish_reason": "stop"}], "usage": {"prompt_tokens": 20, "completion_tokens": 10, "total_tokens": 30, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null}
{"id": "batch_req_6769ca597e588190920666612634e2b4", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "82e04f4c001fe2c127cbad199f5fd31b", "body": {"id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP", "object": "chat.completion", "created": 1734986203, "model": "gpt-4o-mini-2024-07-18", "choices": [{"index": 0, "message": {"role": "assistant", "content": "Hello! What can I do for you today?", "refusal": null}, "logprobs": null, "finish_reason": "length"}], "usage": {"prompt_tokens": 22, "completion_tokens": 10, "total_tokens": 32, "prompt_tokens_details": {"cached_tokens": 0, "audio_tokens": 0}, "completion_tokens_details": {"reasoning_tokens": 0, "audio_tokens": 0, "accepted_prediction_tokens": 0, "rejected_prediction_tokens": 0}}, "system_fingerprint": "fp_0aa8d3e20b"}}, "error": null}
"""


@pytest.fixture
def sample_file_content_dict():
    return [
        {
            "id": "batch_req_6769ca596b38819093d7ae9f522de924",
            "custom_id": "request-1",
            "response": {
                "status_code": 200,
                "request_id": "07bc45ab4e7e26ac23a0c949973327e7",
                "body": {
                    "id": "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7",
                    "object": "chat.completion",
                    "created": 1734986202,
                    "model": "gpt-4o-mini-2024-07-18",
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": "Hello! How can I assist you today?",
                                "refusal": None,
                            },
                            "logprobs": None,
                            "finish_reason": "stop",
                        }
                    ],
                    "usage": {
                        "prompt_tokens": 20,
                        "completion_tokens": 10,
                        "total_tokens": 30,
                        "prompt_tokens_details": {
                            "cached_tokens": 0,
                            "audio_tokens": 0,
                        },
                        "completion_tokens_details": {
                            "reasoning_tokens": 0,
                            "audio_tokens": 0,
                            "accepted_prediction_tokens": 0,
                            "rejected_prediction_tokens": 0,
                        },
                    },
                    "system_fingerprint": "fp_0aa8d3e20b",
                },
            },
            "error": None,
        },
        {
            "id": "batch_req_6769ca597e588190920666612634e2b4",
            "custom_id": "request-2",
            "response": {
                "status_code": 200,
                "request_id": "82e04f4c001fe2c127cbad199f5fd31b",
                "body": {
                    "id": "chatcmpl-AhjSNgVB4Oa4Hq0NruTRsBaEbRWUP",
                    "object": "chat.completion",
                    "created": 1734986203,
                    "model": "gpt-4o-mini-2024-07-18",
                    "choices": [
                        {
                            "index": 0,
                            "message": {
                                "role": "assistant",
                                "content": "Hello! What can I do for you today?",
                                "refusal": None,
                            },
                            "logprobs": None,
                            "finish_reason": "length",
                        }
                    ],
                    "usage": {
                        "prompt_tokens": 22,
                        "completion_tokens": 10,
                        "total_tokens": 32,
                        "prompt_tokens_details": {
                            "cached_tokens": 0,
                            "audio_tokens": 0,
                        },
                        "completion_tokens_details": {
                            "reasoning_tokens": 0,
                            "audio_tokens": 0,
                            "accepted_prediction_tokens": 0,
                            "rejected_prediction_tokens": 0,
                        },
                    },
                    "system_fingerprint": "fp_0aa8d3e20b",
                },
            },
            "error": None,
        },
    ]


def test_get_file_content_as_dictionary(sample_file_content):
    result = _get_file_content_as_dictionary(sample_file_content)
    assert len(result) == 2
    assert result[0]["id"] == "batch_req_6769ca596b38819093d7ae9f522de924"
    assert result[0]["custom_id"] == "request-1"
    assert result[0]["response"]["status_code"] == 200
    assert result[0]["response"]["body"]["usage"]["total_tokens"] == 30


def test_get_batch_job_total_usage_from_file_content(sample_file_content_dict):
    usage = _get_batch_job_total_usage_from_file_content(
        sample_file_content_dict, custom_llm_provider="openai"
    )
    assert usage.total_tokens == 62  # 30 + 32
    assert usage.prompt_tokens == 42  # 20 + 22
    assert usage.completion_tokens == 20  # 10 + 10


@pytest.mark.asyncio
async def test_batch_cost_calculator(sample_file_content_dict):
    """
    mock litellm.completion_cost to return 0.5

    we know sample_file_content_dict has 2 successful responses

    so we expect the cost to be 0.5 * 2 = 1.0
    """
    with patch("litellm.completion_cost", return_value=0.5):
        cost = await _batch_cost_calculator(
            file_content_dictionary=sample_file_content_dict,
            custom_llm_provider="openai",
        )
        assert cost == 1.0  # 0.5 * 2 successful responses


def test_get_response_from_batch_job_output_file(sample_file_content_dict):
    result = _get_response_from_batch_job_output_file(sample_file_content_dict[0])
    assert result["id"] == "chatcmpl-AhjSMl7oZ79yIPHLRYgmgXSixTJr7"
    assert result["object"] == "chat.completion"
    assert result["usage"]["total_tokens"] == 30
@@ -5,7 +5,7 @@ import json
 import os
 import sys
 import traceback
+import tempfile
 from dotenv import load_dotenv

 load_dotenv()

@@ -20,7 +20,6 @@ from typing import Optional
 import litellm
 from litellm import create_batch, create_file
 from litellm._logging import verbose_logger
-from test_gcs_bucket import load_vertex_ai_credentials

 verbose_logger.setLevel(logging.DEBUG)


@@ -28,19 +27,47 @@ from litellm.integrations.custom_logger import CustomLogger
 from litellm.types.utils import StandardLoggingPayload


-class TestCustomLogger(CustomLogger):
-    def __init__(self):
-        super().__init__()
-        self.standard_logging_object: Optional[StandardLoggingPayload] = None
-
-    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
-        print(
-            "Success event logged with kwargs=",
-            kwargs,
-            "and response_obj=",
-            response_obj,
-        )
-        self.standard_logging_object = kwargs["standard_logging_object"]
+def load_vertex_ai_credentials():
+    # Define the path to the vertex_key.json file
+    print("loading vertex ai credentials")
+    os.environ["GCS_FLUSH_INTERVAL"] = "1"
+    filepath = os.path.dirname(os.path.abspath(__file__))
+    vertex_key_path = filepath + "/adroit-crow-413218-bc47f303efc9.json"
+
+    # Read the existing content of the file or create an empty dictionary
+    try:
+        with open(vertex_key_path, "r") as file:
+            # Read the file content
+            print("Read vertexai file path")
+            content = file.read()
+
+            # If the file is empty or not valid JSON, create an empty dictionary
+            if not content or not content.strip():
+                service_account_key_data = {}
+            else:
+                # Attempt to load the existing JSON content
+                file.seek(0)
+                service_account_key_data = json.load(file)
+    except FileNotFoundError:
+        # If the file doesn't exist, create an empty dictionary
+        service_account_key_data = {}
+
+    # Update the service_account_key_data with environment variables
+    private_key_id = os.environ.get("GCS_PRIVATE_KEY_ID", "")
+    private_key = os.environ.get("GCS_PRIVATE_KEY", "")
+    private_key = private_key.replace("\\n", "\n")
+    service_account_key_data["private_key_id"] = private_key_id
+    service_account_key_data["private_key"] = private_key
+
+    # Create a temporary file
+    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
+        # Write the updated content to the temporary files
+        json.dump(service_account_key_data, temp_file, indent=2)
+
+    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
+    os.environ["GCS_PATH_SERVICE_ACCOUNT"] = os.path.abspath(temp_file.name)
+    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
+    print("created gcs path service account=", os.environ["GCS_PATH_SERVICE_ACCOUNT"])


 @pytest.mark.parametrize("provider", ["openai"])  # , "azure"

@@ -128,6 +155,21 @@ async def test_create_batch(provider):
     pass


+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.standard_logging_object: Optional[StandardLoggingPayload] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        print(
+            "Success event logged with kwargs=",
+            kwargs,
+            "and response_obj=",
+            response_obj,
+        )
+        self.standard_logging_object = kwargs["standard_logging_object"]
+
+
 @pytest.mark.parametrize("provider", ["openai"])  # "azure"
 @pytest.mark.asyncio()
 @pytest.mark.flaky(retries=3, delay=1)

@@ -142,6 +184,9 @@ async def test_async_create_batch(provider):
         # Don't have anymore Azure Quota
         return

+    custom_logger = TestCustomLogger()
+    litellm.callbacks = [custom_logger, "datadog"]
+
     file_name = "openai_batch_completions.jsonl"
     _current_dir = os.path.dirname(os.path.abspath(__file__))
     file_path = os.path.join(_current_dir, file_name)

@@ -179,7 +224,9 @@ async def test_async_create_batch(provider):
         create_batch_response.input_file_id == batch_input_file_id
     ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"

-    await asyncio.sleep(1)
+    await asyncio.sleep(6)
+    # Assert that the create batch event is logged on CustomLogger
+    assert custom_logger.standard_logging_object is not None

     retrieved_batch = await litellm.aretrieve_batch(
         batch_id=create_batch_response.id, custom_llm_provider=provider

@@ -223,9 +270,10 @@ async def test_async_create_batch(provider):

     print("all_files_list = ", all_files_list)

-    # # write this file content to a file
-    # with open("file_content.json", "w") as f:
-    #     json.dump(file_content, f)
+    result_file_name = "batch_job_results_furniture.jsonl"
+
+    with open(result_file_name, "wb") as file:
+        file.write(file_content.content)


 def test_retrieve_batch():

@@ -241,8 +289,9 @@ def test_list_batch():


 @pytest.mark.asyncio
-async def test_vertex_batch_prediction():
+async def test_avertex_batch_prediction():
     load_vertex_ai_credentials()
+    litellm.set_verbose = True
     file_name = "vertex_batch_completions.jsonl"
     _current_dir = os.path.dirname(os.path.abspath(__file__))
     file_path = os.path.join(_current_dir, file_name)
tests/batches_tests/vertex_batch_completions.jsonl (new file, 2 lines)

{"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
{"custom_id": "request-2", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gemini-1.5-flash-001", "messages": [{"role": "system", "content": "You are an unhelpful assistant."},{"role": "user", "content": "Hello world!"}],"max_tokens": 10}}
@@ -23,7 +23,7 @@ model_val = None

 def test_completion_with_no_model():
     # test on empty
-    with pytest.raises(ValueError):
+    with pytest.raises(TypeError):
         response = completion(messages=messages)


@@ -36,39 +36,6 @@ def test_completion_with_empty_model():
     pass


-# def test_completion_catch_nlp_exception():
-#     TEMP commented out NLP cloud API is unstable
-#     try:
-#         response = completion(model="dolphin", messages=messages, functions=[
-#             {
-#             "name": "get_current_weather",
-#             "description": "Get the current weather in a given location",
-#             "parameters": {
-#                 "type": "object",
-#                 "properties": {
-#                 "location": {
-#                     "type": "string",
-#                     "description": "The city and state, e.g. San Francisco, CA"
-#                 },
-#                 "unit": {
-#                     "type": "string",
-#                     "enum": ["celsius", "fahrenheit"]
-#                 }
-#                 },
-#                 "required": ["location"]
-#             }
-#             }
-#         ])
-
-#     except Exception as e:
-#         if "Function calling is not supported by nlp_cloud" in str(e):
-#             pass
-#         else:
-#             pytest.fail(f'An error occurred {e}')
-
-# test_completion_catch_nlp_exception()
-
-
 def test_completion_invalid_param_cohere():
     try:
         litellm.set_verbose = True

@@ -94,9 +61,6 @@ def test_completion_function_call_cohere():
     pass


-# test_completion_function_call_cohere()
-
-
 def test_completion_function_call_openai():
     try:
         messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

@@ -140,17 +104,3 @@ def test_completion_with_no_provider():
     except Exception as e:
         print(f"error occurred: {e}")
         pass
-
-
-# test_completion_with_no_provider()
-# # bad key
-# temp_key = os.environ.get("OPENAI_API_KEY")
-# os.environ["OPENAI_API_KEY"] = "bad-key"
-# # test on openai completion call
-# try:
-#     response = completion(model="gpt-3.5-turbo", messages=messages)
-#     print(f"response: {response}")
-# except Exception:
-#     print(f"error occurred: {traceback.format_exc()}")
-#     pass
-# os.environ["OPENAI_API_KEY"] = str(temp_key)  # this passes linting#5