# What is this?
## Unit Tests for OpenAI Batches API
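#
# NOTE: these tests read provider credentials from the environment (loaded via
# dotenv below). Based on what this module references, that includes the Azure
# fine-tuning vars (AZURE_FT_API_KEY, AZURE_FT_API_BASE, AZURE_API_VERSION) and
# the GCS_* vars used to build Vertex AI service-account credentials; an OpenAI
# API key is assumed to be configured as well.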
import asyncio
import json
import os
import sys
import traceback
import tempfile
from dotenv import load_dotenv

load_dotenv()
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system-path

import logging
import time

import pytest
from typing import Optional
import litellm
from litellm import create_batch, create_file
from litellm._logging import verbose_logger

verbose_logger.setLevel(logging.DEBUG)

from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload
import random


def load_vertex_ai_credentials():
    # Define the path to the Vertex AI service account key file
    print("loading vertex ai credentials")
    os.environ["GCS_FLUSH_INTERVAL"] = "1"
    filepath = os.path.dirname(os.path.abspath(__file__))
    vertex_key_path = filepath + "/pathrise-convert-1606954137718.json"

    # Read the existing content of the file or create an empty dictionary
    try:
        with open(vertex_key_path, "r") as file:
            # Read the file content
            print("Read vertexai file path")
            content = file.read()

            # If the file is empty or not valid JSON, create an empty dictionary
            if not content or not content.strip():
                service_account_key_data = {}
            else:
                # Attempt to load the existing JSON content
                file.seek(0)
                service_account_key_data = json.load(file)
    except FileNotFoundError:
        # If the file doesn't exist, create an empty dictionary
        service_account_key_data = {}

    # Update the service_account_key_data with environment variables
    private_key_id = os.environ.get("GCS_PRIVATE_KEY_ID", "")
    private_key = os.environ.get("GCS_PRIVATE_KEY", "")
    private_key = private_key.replace("\\n", "\n")
    service_account_key_data["private_key_id"] = private_key_id
    service_account_key_data["private_key"] = private_key

    # Create a temporary file
    with tempfile.NamedTemporaryFile(mode="w+", delete=False) as temp_file:
        # Write the updated content to the temporary file
        json.dump(service_account_key_data, temp_file, indent=2)

    # Export the temporary file as GOOGLE_APPLICATION_CREDENTIALS
    os.environ["GCS_PATH_SERVICE_ACCOUNT"] = os.path.abspath(temp_file.name)
    os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
    print("created gcs path service account=", os.environ["GCS_PATH_SERVICE_ACCOUNT"])
@pytest.mark.parametrize("provider", ["openai"]) # , "azure"
|
|
@pytest.mark.asyncio
|
|
async def test_create_batch(provider):
|
|
"""
|
|
1. Create File for Batch completion
|
|
2. Create Batch Request
|
|
3. Retrieve the specific batch
|
|
"""
|
|
if provider == "azure":
|
|
# Don't have anymore Azure Quota
|
|
return
|
|
file_name = "openai_batch_completions.jsonl"
|
|
_current_dir = os.path.dirname(os.path.abspath(__file__))
|
|
file_path = os.path.join(_current_dir, file_name)
|
|
|
|
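
    # The fixture above is expected to hold OpenAI batch-format requests, one JSON
    # object per line. A minimal sketch of what a line typically looks like
    # (illustrative only; the actual fixture contents may differ):
    #   {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions",
    #    "body": {"model": "gpt-3.5-turbo", "messages": [{"role": "user", "content": "Hello"}]}}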

    file_obj = await litellm.acreate_file(
        file=open(file_path, "rb"),
        purpose="batch",
        custom_llm_provider=provider,
    )
    print("Response from creating file=", file_obj)

    batch_input_file_id = file_obj.id
    assert (
        batch_input_file_id is not None
    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"

    await asyncio.sleep(1)
    create_batch_response = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=batch_input_file_id,
        custom_llm_provider=provider,
        metadata={"key1": "value1", "key2": "value2"},
    )

    print("response from litellm.create_batch=", create_batch_response)
    await asyncio.sleep(6)

    assert (
        create_batch_response.id is not None
    ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
    assert (
        create_batch_response.endpoint == "/v1/chat/completions"
        or create_batch_response.endpoint == "/chat/completions"
    ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
    assert (
        create_batch_response.input_file_id == batch_input_file_id
    ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"

    retrieved_batch = await litellm.aretrieve_batch(
        batch_id=create_batch_response.id, custom_llm_provider=provider
    )
    print("retrieved batch=", retrieved_batch)
    # just assert that we retrieved a non-None batch

    assert retrieved_batch.id == create_batch_response.id

    # list all batches
    list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2)
    print("list_batches=", list_batches)

    file_content = await litellm.afile_content(
        file_id=batch_input_file_id, custom_llm_provider=provider
    )

    result = file_content.content

    result_file_name = "batch_job_results_furniture.jsonl"

    with open(result_file_name, "wb") as file:
        file.write(result)

    # Cancel Batch
    cancel_batch_response = await litellm.acancel_batch(
        batch_id=create_batch_response.id,
        custom_llm_provider=provider,
    )
    print("cancel_batch_response=", cancel_batch_response)


class TestCustomLogger(CustomLogger):
    def __init__(self):
        super().__init__()
        self.standard_logging_object: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(
        self, kwargs, response_obj, start_time, end_time
    ):
        print(
            "Success event logged with kwargs=",
            kwargs,
            "and response_obj=",
            response_obj,
        )
        self.standard_logging_object = kwargs["standard_logging_object"]
@pytest.mark.parametrize("provider", ["azure", "openai"]) # "azure"
|
|
@pytest.mark.asyncio()
|
|
@pytest.mark.flaky(retries=3, delay=1)
|
|
async def test_async_create_batch(provider):
|
|
"""
|
|
1. Create File for Batch completion
|
|
2. Create Batch Request
|
|
3. Retrieve the specific batch
|
|
"""
|
|
print("Testing async create batch")
|
|
|
|
custom_logger = TestCustomLogger()
|
|
litellm.callbacks = [custom_logger, "datadog"]
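    # TestCustomLogger stores the standard_logging_object from the async success
    # callback, so the metadata assertions further down can inspect what was logged.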

    file_name = "openai_batch_completions.jsonl"
    _current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(_current_dir, file_name)
    file_obj = await litellm.acreate_file(
        file=open(file_path, "rb"),
        purpose="batch",
        custom_llm_provider=provider,
    )
    print("Response from creating file=", file_obj)

    await asyncio.sleep(10)
    batch_input_file_id = file_obj.id
    assert (
        batch_input_file_id is not None
    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"

    extra_metadata_field = {
        "user_api_key_alias": "special_api_key_alias",
        "user_api_key_team_alias": "special_team_alias",
    }
    create_batch_response = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=batch_input_file_id,
        custom_llm_provider=provider,
        metadata={"key1": "value1", "key2": "value2"},
        # litellm specific param - used for logging metadata on logging callback
        litellm_metadata=extra_metadata_field,
    )

    print("response from litellm.create_batch=", create_batch_response)

    assert (
        create_batch_response.id is not None
    ), f"Failed to create batch, expected a non null batch_id but got {create_batch_response.id}"
    assert (
        create_batch_response.endpoint == "/v1/chat/completions"
        or create_batch_response.endpoint == "/chat/completions"
    ), f"Failed to create batch, expected endpoint to be /v1/chat/completions but got {create_batch_response.endpoint}"
    assert (
        create_batch_response.input_file_id == batch_input_file_id
    ), f"Failed to create batch, expected input_file_id to be {batch_input_file_id} but got {create_batch_response.input_file_id}"

    await asyncio.sleep(6)
    # Assert that the create batch event is logged on CustomLogger
    assert custom_logger.standard_logging_object is not None
    print(
        "standard_logging_object=",
        json.dumps(custom_logger.standard_logging_object, indent=4, default=str),
    )
    assert (
        custom_logger.standard_logging_object["metadata"]["user_api_key_alias"]
        == extra_metadata_field["user_api_key_alias"]
    )
    assert (
        custom_logger.standard_logging_object["metadata"]["user_api_key_team_alias"]
        == extra_metadata_field["user_api_key_team_alias"]
    )

    retrieved_batch = await litellm.aretrieve_batch(
        batch_id=create_batch_response.id, custom_llm_provider=provider
    )
    print("retrieved batch=", retrieved_batch)
    # just assert that we retrieved a non-None batch

    assert retrieved_batch.id == create_batch_response.id

    # list all batches
    list_batches = await litellm.alist_batches(custom_llm_provider=provider, limit=2)
    print("list_batches=", list_batches)

    # try to get file content for our original file

    file_content = await litellm.afile_content(
        file_id=batch_input_file_id, custom_llm_provider=provider
    )

    print("file content = ", file_content)

    # retrieve the file object
    file_obj = await litellm.afile_retrieve(
        file_id=batch_input_file_id, custom_llm_provider=provider
    )
    print("file obj = ", file_obj)
    assert file_obj.id == batch_input_file_id

    # delete file
    delete_file_response = await litellm.afile_delete(
        file_id=batch_input_file_id, custom_llm_provider=provider
    )

    print("delete file response = ", delete_file_response)

    assert delete_file_response.id == batch_input_file_id

    all_files_list = await litellm.afile_list(
        custom_llm_provider=provider,
    )

    print("all_files_list = ", all_files_list)

    result_file_name = "batch_job_results_furniture.jsonl"

    with open(result_file_name, "wb") as file:
        file.write(file_content.content)

    # Cancel Batch
    cancel_batch_response = await litellm.acancel_batch(
        batch_id=create_batch_response.id,
        custom_llm_provider=provider,
    )
    print("cancel_batch_response=", cancel_batch_response)

    # ~5% of runs (1 in 20): clean up Azure files / fine-tuning jobs to free up quota
    if random.randint(1, 20) == 1:
        print("Running random cleanup of Azure files and models...")
        cleanup_azure_files()
        cleanup_azure_ft_models()


def cleanup_azure_files():
    """
    Delete all files for Azure - helper for when we run out of Azure Files Quota
    """
    azure_files = litellm.file_list(
        custom_llm_provider="azure",
        api_key=os.getenv("AZURE_FT_API_KEY"),
        api_base=os.getenv("AZURE_FT_API_BASE"),
    )
    print("azure_files=", azure_files)
    for _file in azure_files:
        print("deleting file=", _file)
        delete_file_response = litellm.file_delete(
            file_id=_file.id,
            custom_llm_provider="azure",
            api_key=os.getenv("AZURE_FT_API_KEY"),
            api_base=os.getenv("AZURE_FT_API_BASE"),
        )
        print("delete_file_response=", delete_file_response)
        assert delete_file_response.id == _file.id


def cleanup_azure_ft_models():
    """
    Test CLEANUP: Delete all existing fine-tuning jobs for Azure
    """
    try:
        from openai import AzureOpenAI
        import requests

        client = AzureOpenAI(
            api_key=os.getenv("AZURE_FT_API_KEY"),
            azure_endpoint=os.getenv("AZURE_FT_API_BASE"),
            api_version=os.getenv("AZURE_API_VERSION"),
        )

        _list_ft_jobs = client.fine_tuning.jobs.list()
        print("_list_ft_jobs=", _list_ft_jobs)

        # Delete all fine-tuning jobs via direct REST DELETE calls to the Azure endpoint
        for job in _list_ft_jobs:
            try:
                endpoint = os.getenv("AZURE_FT_API_BASE").rstrip("/")
                url = f"{endpoint}/openai/fine_tuning/jobs/{job.id}?api-version=2024-10-21"
                print("url=", url)

                headers = {
                    "api-key": os.getenv("AZURE_FT_API_KEY"),
                    "Content-Type": "application/json",
                }

                response = requests.delete(url, headers=headers)
                print(f"Deleting job {job.id}: Status {response.status_code}")
                if response.status_code != 204:
                    print(f"Error deleting job {job.id}: {response.text}")

            except Exception as e:
                print(f"Error deleting job {job.id}: {str(e)}")
    except Exception as e:
        print(f"Error on cleanup_azure_ft_models: {str(e)}")


@pytest.mark.asyncio
async def test_avertex_batch_prediction():
    load_vertex_ai_credentials()
    litellm.set_verbose = True
    file_name = "vertex_batch_completions.jsonl"
    _current_dir = os.path.dirname(os.path.abspath(__file__))
    file_path = os.path.join(_current_dir, file_name)
    file_obj = await litellm.acreate_file(
        file=open(file_path, "rb"),
        purpose="batch",
        custom_llm_provider="vertex_ai",
    )
    print("Response from creating file=", file_obj)

    batch_input_file_id = file_obj.id
    assert (
        batch_input_file_id is not None
    ), f"Failed to create file, expected a non null file_id but got {batch_input_file_id}"

    create_batch_response = await litellm.acreate_batch(
        completion_window="24h",
        endpoint="/v1/chat/completions",
        input_file_id=batch_input_file_id,
        custom_llm_provider="vertex_ai",
        metadata={"key1": "value1", "key2": "value2"},
    )
    print("create_batch_response=", create_batch_response)