Added test to check if acompletion is using the same parameters as CompletionRequest attributes. Added functools to client decorator to expose acompletion parameters from outside.

This commit is contained in:
Mateo Cámara 2024-01-09 12:06:49 +01:00
parent 48b2f69c93
commit bb06c51ede
2 changed files with 26 additions and 1 deletions

View file

@ -10,7 +10,7 @@ sys.path.insert(
) # Adds the parent directory to the system path ) # Adds the parent directory to the system path
import pytest import pytest
import litellm import litellm
from litellm import embedding, completion, completion_cost, Timeout from litellm import embedding, completion, completion_cost, Timeout, acompletion
from litellm import RateLimitError from litellm import RateLimitError
# litellm.num_retries = 3 # litellm.num_retries = 3
@ -859,6 +859,28 @@ def test_completion_azure_key_completion_arg():
# test_completion_azure_key_completion_arg() # test_completion_azure_key_completion_arg()
def test_acompletion_params():
import inspect
from litellm.types.completion import CompletionRequest
acompletion_params_odict = inspect.signature(acompletion).parameters
acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
# remove kwargs
acompletion_params.pop("kwargs", None)
keys_acompletion = set(acompletion_params.keys())
keys_completion = set(completion_params.keys())
# Assert that the parameters are the same
if keys_acompletion != keys_completion:
pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
# test_acompletion_params()
async def test_re_use_azure_async_client(): async def test_re_use_azure_async_client():
try: try:
print("azure gpt-3.5 ASYNC with clie nttest\n\n") print("azure gpt-3.5 ASYNC with clie nttest\n\n")

View file

@ -14,6 +14,7 @@ import subprocess, os
import litellm, openai import litellm, openai
import itertools import itertools
import random, uuid, requests import random, uuid, requests
from functools import wraps
import datetime, time import datetime, time
import tiktoken import tiktoken
import uuid import uuid
@ -1934,6 +1935,7 @@ def client(original_function):
# [Non-Blocking Error] # [Non-Blocking Error]
pass pass
@wraps(original_function)
def wrapper(*args, **kwargs): def wrapper(*args, **kwargs):
start_time = datetime.datetime.now() start_time = datetime.datetime.now()
result = None result = None
@ -2128,6 +2130,7 @@ def client(original_function):
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}" e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e raise e
@wraps(original_function)
async def wrapper_async(*args, **kwargs): async def wrapper_async(*args, **kwargs):
start_time = datetime.datetime.now() start_time = datetime.datetime.now()
result = None result = None