diff --git a/litellm/main.py b/litellm/main.py
index b51000d8c..aff8f9bea 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -130,7 +130,39 @@ class Completions:
 
 
 @client
-async def acompletion(*args, **kwargs):
+async def acompletion(
+    model: str,
+    # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
+    messages: List = [],
+    functions: Optional[List] = None,
+    function_call: Optional[str] = None,
+    timeout: Optional[Union[float, int]] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream: Optional[bool] = None,
+    stop=None,
+    max_tokens: Optional[float] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[dict] = None,
+    user: Optional[str] = None,
+    # openai v1.0+ new params
+    response_format: Optional[dict] = None,
+    seed: Optional[int] = None,
+    tools: Optional[List] = None,
+    tool_choice: Optional[str] = None,
+    logprobs: Optional[bool] = None,
+    top_logprobs: Optional[int] = None,
+    deployment_id=None,
+    # set api_base, api_version, api_key
+    base_url: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
+    # Optional liteLLM function params
+    **kwargs,
+):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
 
@@ -170,10 +202,36 @@ async def acompletion(*args, **kwargs):
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     loop = asyncio.get_event_loop()
-    model = args[0] if len(args) > 0 else kwargs["model"]
-    ### PASS ARGS TO COMPLETION ###
-    kwargs["acompletion"] = True
-    custom_llm_provider = None
+    # Adjusted to use explicit arguments instead of *args and **kwargs
+    completion_kwargs = {
+        "model": model,
+        "messages": messages,
+        "functions": functions,
+        "function_call": function_call,
+        "timeout": timeout,
+        "temperature": temperature,
+        "top_p": top_p,
+        "n": n,
+        "stream": stream,
+        "stop": stop,
+        "max_tokens": max_tokens,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "logit_bias": logit_bias,
+        "user": user,
+        "response_format": response_format,
+        "seed": seed,
+        "tools": tools,
+        "tool_choice": tool_choice,
+        "logprobs": logprobs,
+        "top_logprobs": top_logprobs,
+        "deployment_id": deployment_id,
+        "base_url": base_url,
+        "api_version": api_version,
+        "api_key": api_key,
+        "model_list": model_list,
+        "acompletion": True  # assuming this is a required parameter
+    }
     try:
         # Use a partial function to pass your keyword arguments
         func = partial(completion, *args, **kwargs)
@@ -182,9 +240,7 @@ async def acompletion(*args, **kwargs):
 
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
-        _, custom_llm_provider, _, _ = get_llm_provider(
-            model=model, api_base=kwargs.get("api_base", None)
-        )
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("base_url", None))
 
         if (
             custom_llm_provider == "openai"
@@ -3200,9 +3256,11 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     created = chunks[0]["created"]
     model = chunks[0]["model"]
     system_fingerprint = chunks[0].get("system_fingerprint", None)
+
     if isinstance(
         chunks[0]["choices"][0], litellm.utils.TextChoices
     ):  # route to the text completion logic
+        return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
     role = chunks[0]["choices"][0]["delta"]["role"]
     finish_reason = chunks[-1]["choices"][0]["finish_reason"]
diff --git a/litellm/tests/test_acompletion.py b/litellm/tests/test_acompletion.py
new file mode 100644
index 000000000..e5c09b9b7
--- /dev/null
+++ b/litellm/tests/test_acompletion.py
@@ -0,0 +1,23 @@
+import pytest
+from litellm import acompletion
+
+
+def test_acompletion_params():
+    import inspect
+    from litellm.types.completion import CompletionRequest
+
+    acompletion_params_odict = inspect.signature(acompletion).parameters
+    acompletion_params = {name: param.annotation for name, param in acompletion_params_odict.items()}
+    completion_params = {field_name: field_type for field_name, field_type in CompletionRequest.__annotations__.items()}
+
+    # remove kwargs
+    acompletion_params.pop("kwargs", None)
+
+    keys_acompletion = set(acompletion_params.keys())
+    keys_completion = set(completion_params.keys())
+
+    # Assert that the parameters are the same
+    if keys_acompletion != keys_completion:
+        pytest.fail("The parameters of the acompletion function and the CompletionRequest class are not the same.")
+
+# test_acompletion_params()
diff --git a/litellm/utils.py b/litellm/utils.py
index f3e743ec4..fcf6e9dea 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -14,6 +14,7 @@ import subprocess, os
 import litellm, openai
 import itertools
 import random, uuid, requests
+from functools import wraps
 import datetime, time
 import tiktoken
 import uuid
@@ -1972,6 +1973,7 @@ def client(original_function):
                 # [Non-Blocking Error]
                 pass
 
+    @wraps(original_function)
     def wrapper(*args, **kwargs):
         start_time = datetime.datetime.now()
         result = None
@@ -2166,6 +2168,7 @@ def client(original_function):
                 e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
                 raise e
 
+    @wraps(original_function)
     async def wrapper_async(*args, **kwargs):
        start_time = datetime.datetime.now()
        result = None
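
For reference, a minimal sketch of how the new explicit-parameter acompletion signature would be called; the model name, message content, and sampling values here are illustrative, and a provider API key is assumed to be configured in the environment:

import asyncio
from litellm import acompletion


async def main():
    # Standard OpenAI options are now named parameters on acompletion;
    # provider-specific extras still flow through **kwargs.
    response = await acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello, how are you?"}],
        temperature=0.2,
        max_tokens=100,
    )
    print(response)


asyncio.run(main())

The new test_acompletion.py keeps this signature in sync with litellm.types.completion.CompletionRequest, so a field added to one must also be added to the other.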