Merge pull request #1200 from MateoCamara/explicit-args-acomplete

feat: added explicit args to acomplete
Ishaan Jaff 2024-01-11 10:39:05 +05:30 committed by GitHub
commit 2433d6c613
3 changed files with 92 additions and 8 deletions

@@ -130,7 +130,39 @@ class Completions:
@client
async def acompletion(*args, **kwargs):
async def acompletion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: Optional[List] = None,
function_call: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: Optional[bool] = None,
stop=None,
max_tokens: Optional[float] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[dict] = None,
user: Optional[str] = None,
# openai v1.0+ new params
response_format: Optional[dict] = None,
seed: Optional[int] = None,
tools: Optional[List] = None,
tool_choice: Optional[str] = None,
logprobs: Optional[bool] = None,
top_logprobs: Optional[int] = None,
deployment_id=None,
# set api_base, api_version, api_key
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
# Optional liteLLM function params
**kwargs,
):
"""
Asynchronously executes a litellm.completion() call for any of litellm's supported LLMs (e.g. gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -170,10 +202,36 @@ async def acompletion(*args, **kwargs):
- If `stream` is True, the function returns an async generator that yields completion lines.
"""
loop = asyncio.get_event_loop()
model = args[0] if len(args) > 0 else kwargs["model"]
### PASS ARGS TO COMPLETION ###
kwargs["acompletion"] = True
custom_llm_provider = None
# Adjusted to use explicit arguments instead of *args and **kwargs
completion_kwargs = {
"model": model,
"messages": messages,
"functions": functions,
"function_call": function_call,
"timeout": timeout,
"temperature": temperature,
"top_p": top_p,
"n": n,
"stream": stream,
"stop": stop,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"user": user,
"response_format": response_format,
"seed": seed,
"tools": tools,
"tool_choice": tool_choice,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"deployment_id": deployment_id,
"base_url": base_url,
"api_version": api_version,
"api_key": api_key,
"model_list": model_list,
"acompletion": True # assuming this is a required parameter
}
try:
# Use a partial function to pass your keyword arguments
func = partial(completion, *args, **kwargs)
@@ -182,9 +240,7 @@ async def acompletion(*args, **kwargs):
ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(
model=model, api_base=kwargs.get("api_base", None)
)
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("base_url", None))
if (
custom_llm_provider == "openai"
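Note that the hunk above still shows func = partial(completion, *args, **kwargs) even though *args is gone from the new explicit signature; presumably the merged code binds completion_kwargs (plus any remaining **kwargs) instead. Below is a minimal, self-contained sketch of the offload pattern acompletion relies on: bind the arguments with partial, copy the caller's context, and run the blocking call in the default executor. completion_sync is a hypothetical stand-in for the blocking litellm.completion call, not litellm's API.

import asyncio
import contextvars
from functools import partial

def completion_sync(model: str, messages: list, **kwargs):
    # Hypothetical stand-in for the blocking litellm.completion call
    return {"model": model, "content": f"echo: {messages[-1]['content']}"}

async def acompletion_sketch(model: str, messages: list, **kwargs):
    loop = asyncio.get_event_loop()
    # Bind the explicit parameters, mirroring the completion_kwargs dict above
    func = partial(completion_sync, model=model, messages=messages, **kwargs)
    # copy_context() so contextvars set by the caller survive the executor hop
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    # Run the blocking call on the default thread pool without blocking the event loop
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(acompletion_sketch("demo-model", [{"role": "user", "content": "hi"}])))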
@@ -3200,9 +3256,11 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
created = chunks[0]["created"]
model = chunks[0]["model"]
system_fingerprint = chunks[0].get("system_fingerprint", None)
if isinstance(
chunks[0]["choices"][0], litellm.utils.TextChoices
): # route to the text completion logic
return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
role = chunks[0]["choices"][0]["delta"]["role"]
finish_reason = chunks[-1]["choices"][0]["finish_reason"]
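The stream_chunk_builder change routes text-completion chunks to a dedicated builder; for chat chunks it keeps the existing convention shown above: role from the first chunk's delta, finish_reason from the last chunk, and the content concatenated from the deltas in between. An illustration with hypothetical chunk dicts modeled on OpenAI-style streaming deltas, not litellm's actual objects:

chunks = [
    {"created": 1704967145, "model": "gpt-3.5-turbo",
     "choices": [{"delta": {"role": "assistant", "content": "Hel"}, "finish_reason": None}]},
    {"created": 1704967145, "model": "gpt-3.5-turbo",
     "choices": [{"delta": {"content": "lo"}, "finish_reason": "stop"}]},
]

role = chunks[0]["choices"][0]["delta"]["role"]            # role comes from the first chunk
finish_reason = chunks[-1]["choices"][0]["finish_reason"]  # finish_reason from the last
content = "".join(c["choices"][0]["delta"].get("content", "") for c in chunks)
print(role, repr(content), finish_reason)  # assistant 'Hello' stop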

@@ -0,0 +1,23 @@
import pytest
from litellm import acompletion
def test_acompletion_params():
import inspect
from litellm.types.completion import CompletionRequest
acompletion_params_odict = inspect.signature(acompletion).parameters
acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
# remove kwargs
acompletion_params.pop("kwargs", None)
keys_acompletion = set(acompletion_params.keys())
keys_completion = set(completion_params.keys())
# Assert that the parameters are the same
if keys_acompletion != keys_completion:
pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
# test_acompletion_params()
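The point of the explicit signature, and of this parity test, is that callers and tooling can now see the real parameters instead of *args/**kwargs. A hypothetical usage sketch; it assumes OPENAI_API_KEY is set in the environment and network access is available:

import asyncio
import litellm

async def main():
    # Explicit keyword arguments now match the declared signature
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello in five words."}],
        temperature=0.2,
        max_tokens=32,
    )
    print(response["choices"][0]["message"]["content"])

asyncio.run(main())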

@@ -14,6 +14,7 @@ import subprocess, os
import litellm, openai
import itertools
import random, uuid, requests
from functools import wraps
import datetime, time
import tiktoken
import uuid
@@ -1972,6 +1973,7 @@ def client(original_function):
# [Non-Blocking Error]
pass
@wraps(original_function)
def wrapper(*args, **kwargs):
start_time = datetime.datetime.now()
result = None
@@ -2166,6 +2168,7 @@ def client(original_function):
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
raise e
@wraps(original_function)
async def wrapper_async(*args, **kwargs):
start_time = datetime.datetime.now()
result = None
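The two @wraps additions are what make the new signature test meaningful: functools.wraps sets __wrapped__ on the wrapper, and inspect.signature follows that attribute, so the decorated acompletion reports its explicit parameters rather than the wrapper's (*args, **kwargs). A small self-contained demonstration of the mechanism; client here is a simplified stand-in for litellm's decorator:

import inspect
from functools import wraps

def client(original_function):
    @wraps(original_function)  # copies __name__/__doc__ and sets __wrapped__
    def wrapper(*args, **kwargs):
        return original_function(*args, **kwargs)
    return wrapper

@client
def acompletion_demo(model: str, temperature: float = 0.0):
    ...

# Without @wraps this would print (*args, **kwargs); with it, the real signature:
print(inspect.signature(acompletion_demo))  # (model: str, temperature: float = 0.0)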