Merge pull request #1200 from MateoCamara/explicit-args-acomplete

feat: added explicit args to acomplete
Ishaan Jaff 2024-01-11 10:39:05 +05:30 committed by GitHub
commit 2433d6c613
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 92 additions and 8 deletions


@@ -130,7 +130,39 @@ class Completions:
 @client
-async def acompletion(*args, **kwargs):
+async def acompletion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: Optional[List] = None,
+    function_call: Optional[str] = None,
+    timeout: Optional[Union[float, int]] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream: Optional[bool] = None,
+    stop=None,
+    max_tokens: Optional[float] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[dict] = None,
+    user: Optional[str] = None,
+    # openai v1.0+ new params
+    response_format: Optional[dict] = None,
+    seed: Optional[int] = None,
+    tools: Optional[List] = None,
+    tool_choice: Optional[str] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    deployment_id=None,
+    # set api_base, api_version, api_key
+    base_url: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[list] = None,  # pass in a list of api_base, keys, etc.
+    # Optional liteLLM function params
+    **kwargs,
+):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
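With explicit parameters in the signature, callers get editor autocompletion and early typo detection instead of everything disappearing into *args/**kwargs. A minimal usage sketch (model name and message content are illustrative):

import asyncio
import litellm

async def main():
    # These keyword arguments are now declared explicitly on acompletion,
    # so IDEs and type checkers can validate them.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",  # illustrative model name
        messages=[{"role": "user", "content": "Hello!"}],
        temperature=0.2,
        max_tokens=64,
    )
    print(response)

asyncio.run(main())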
@@ -170,10 +202,36 @@ async def acompletion(*args, **kwargs):
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     loop = asyncio.get_event_loop()
-    model = args[0] if len(args) > 0 else kwargs["model"]
-    ### PASS ARGS TO COMPLETION ###
-    kwargs["acompletion"] = True
-    custom_llm_provider = None
+    # Adjusted to use explicit arguments instead of *args and **kwargs
+    completion_kwargs = {
+        "model": model,
+        "messages": messages,
+        "functions": functions,
+        "function_call": function_call,
+        "timeout": timeout,
+        "temperature": temperature,
+        "top_p": top_p,
+        "n": n,
+        "stream": stream,
+        "stop": stop,
+        "max_tokens": max_tokens,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "logit_bias": logit_bias,
+        "user": user,
+        "response_format": response_format,
+        "seed": seed,
+        "tools": tools,
+        "tool_choice": tool_choice,
+        "logprobs": logprobs,
+        "top_logprobs": top_logprobs,
+        "deployment_id": deployment_id,
+        "base_url": base_url,
+        "api_version": api_version,
+        "api_key": api_key,
+        "model_list": model_list,
+        "acompletion": True,  # assuming this is a required parameter
+    }
     try:
         # Use a partial function to pass your keyword arguments
-        func = partial(completion, *args, **kwargs)
+        func = partial(completion, **completion_kwargs, **kwargs)
@@ -182,9 +240,7 @@ async def acompletion(*args, **kwargs):
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
-        _, custom_llm_provider, _, _ = get_llm_provider(
-            model=model, api_base=kwargs.get("api_base", None)
-        )
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("base_url", None))
         if (
             custom_llm_provider == "openai"
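The partial/contextvars dance above exists so the synchronous completion() can run on a worker thread without losing context-local state; the corresponding loop.run_in_executor call sits just below this hunk and is not shown in the diff. A self-contained sketch of the same pattern, with placeholder names:

import asyncio
import contextvars
from functools import partial

def blocking_call(x):  # stand-in for the synchronous litellm.completion
    return x * 2

async def run_without_blocking():
    loop = asyncio.get_event_loop()
    func = partial(blocking_call, 21)
    # Copy the current context so ContextVar values survive the thread hop.
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    # Execute the blocking function in the default executor.
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(run_without_blocking()))  # 42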
@@ -3200,9 +3256,11 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     created = chunks[0]["created"]
     model = chunks[0]["model"]
     system_fingerprint = chunks[0].get("system_fingerprint", None)
     if isinstance(
         chunks[0]["choices"][0], litellm.utils.TextChoices
     ):  # route to the text completion logic
         return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
     role = chunks[0]["choices"][0]["delta"]["role"]
     finish_reason = chunks[-1]["choices"][0]["finish_reason"]
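For context, stream_chunk_builder folds a list of streamed chunks back into a single response object, pulling metadata from the first chunk and the finish reason from the last; the isinstance check above routes text-completion chunks to a separate rebuild path. A toy illustration of the core folding idea (not litellm's actual implementation):

chunks = [
    {"choices": [{"delta": {"role": "assistant", "content": "Hel"}}]},
    {"choices": [{"delta": {"content": "lo!"}}]},
]
role = chunks[0]["choices"][0]["delta"]["role"]
content = "".join(c["choices"][0]["delta"].get("content") or "" for c in chunks)
print({"role": role, "content": content})  # {'role': 'assistant', 'content': 'Hello!'}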


@@ -0,0 +1,23 @@
+import pytest
+from litellm import acompletion
+
+
+def test_acompletion_params():
+    import inspect
+    from litellm.types.completion import CompletionRequest
+
+    acompletion_params_odict = inspect.signature(acompletion).parameters
+    acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
+    completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
+
+    # remove kwargs
+    acompletion_params.pop("kwargs", None)
+
+    keys_acompletion = set(acompletion_params.keys())
+    keys_completion = set(completion_params.keys())
+
+    # Assert that the parameters are the same
+    if keys_acompletion != keys_completion:
+        pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
+
+# test_acompletion_params()
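The test compares acompletion's introspected parameters against litellm's CompletionRequest annotations; the mechanism is inspect.signature(), which only sees the real parameter list because @client now applies functools.wraps (see the utils.py hunks below). A standalone illustration of that mechanism, using a hypothetical function:

import inspect

def example(model: str, temperature: float = 0.0, **kwargs):
    ...

params = inspect.signature(example).parameters
print({name: p.annotation for name, p in params.items()})
# {'model': <class 'str'>, 'temperature': <class 'float'>, 'kwargs': <class 'inspect._empty'>}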


@@ -14,6 +14,7 @@ import subprocess, os
 import litellm, openai
 import itertools
 import random, uuid, requests
+from functools import wraps
 import datetime, time
 import tiktoken
 import uuid
@@ -1972,6 +1973,7 @@ def client(original_function):
             # [Non-Blocking Error]
             pass

+    @wraps(original_function)
     def wrapper(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None
@@ -2166,6 +2168,7 @@ def client(original_function):
                 e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
             raise e

+    @wraps(original_function)
     async def wrapper_async(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None
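The @wraps additions are what make the new signature test work: functools.wraps copies __name__ and __doc__ onto the wrapper and sets __wrapped__, which inspect.signature() follows by default, so the decorated acompletion reports its explicit parameters instead of the wrapper's (*args, **kwargs). A minimal sketch with a stand-in function:

import inspect
from functools import wraps

def client(original_function):
    @wraps(original_function)  # copies metadata and sets __wrapped__
    def wrapper(*args, **kwargs):
        return original_function(*args, **kwargs)
    return wrapper

@client
def acompletion_demo(model: str, temperature: float = 0.0):
    """Illustrative stand-in for litellm's acompletion."""

print(acompletion_demo.__name__)            # acompletion_demo, not 'wrapper'
print(inspect.signature(acompletion_demo))  # (model: str, temperature: float = 0.0)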