Added the new acompletion parameters based on CompletionRequest attributes

This commit is contained in:
Mateo Cámara 2024-01-09 12:05:31 +01:00
parent 178a57492b
commit 48b2f69c93

View file

@ -118,29 +118,37 @@ class Completions():
@client @client
async def acompletion( async def acompletion(
model: str, model: str,
messages: List = [], # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
functions: Optional[List] = None, messages: List = [],
function_call: Optional[str] = None, functions: Optional[List] = None,
temperature: Optional[float] = None, function_call: Optional[str] = None,
top_p: Optional[float] = None, timeout: Optional[Union[float, int]] = None,
n: Optional[int] = None, temperature: Optional[float] = None,
stream: Optional[bool] = None, top_p: Optional[float] = None,
stop=None, n: Optional[int] = None,
max_tokens: Optional[int] = None, stream: Optional[bool] = None,
presence_penalty: Optional[float] = None, stop=None,
frequency_penalty: Optional[float] = None, max_tokens: Optional[float] = None,
logit_bias: Optional[Dict] = None, presence_penalty: Optional[float] = None,
user: Optional[str] = None, frequency_penalty: Optional[float] = None,
metadata: Optional[Dict] = None, logit_bias: Optional[dict] = None,
api_base: Optional[str] = None, user: Optional[str] = None,
api_version: Optional[str] = None, # openai v1.0+ new params
api_key: Optional[str] = None, response_format: Optional[dict] = None,
model_list: Optional[List] = None, seed: Optional[int] = None,
mock_response: Optional[str] = None, tools: Optional[List] = None,
force_timeout: Optional[int] = None, tool_choice: Optional[str] = None,
custom_llm_provider: Optional[str] = None, logprobs: Optional[bool] = None,
**kwargs, top_logprobs: Optional[int] = None,
deployment_id=None,
# set api_base, api_version, api_key
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
model_list: Optional[list] = None, # pass in a list of api_base,keys, etc.
# Optional liteLLM function params
**kwargs,
): ):
""" """
Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@ -187,24 +195,28 @@ async def acompletion(
"messages": messages, "messages": messages,
"functions": functions, "functions": functions,
"function_call": function_call, "function_call": function_call,
"timeout": timeout,
"temperature": temperature, "temperature": temperature,
"top_p": top_p, "top_p": top_p,
"n": n, "n": n,
"stream": stream, "stream": stream,
"stop": stop, "stop": stop,
"max_tokens": max_tokens, "max_tokens": max_tokens,
"presence_penalty": presence_penalty, "presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty, "frequency_penalty": frequency_penalty,
"logit_bias": logit_bias, "logit_bias": logit_bias,
"user": user, "user": user,
"metadata": metadata, "response_format": response_format,
"api_base": api_base, "seed": seed,
"tools": tools,
"tool_choice": tool_choice,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"deployment_id": deployment_id,
"base_url": base_url,
"api_version": api_version, "api_version": api_version,
"api_key": api_key, "api_key": api_key,
"model_list": model_list, "model_list": model_list,
"mock_response": mock_response,
"force_timeout": force_timeout,
"custom_llm_provider": custom_llm_provider,
"acompletion": True # assuming this is a required parameter "acompletion": True # assuming this is a required parameter
} }
try: try:
@ -215,7 +227,7 @@ async def acompletion(
ctx = contextvars.copy_context() ctx = contextvars.copy_context()
func_with_context = partial(ctx.run, func) func_with_context = partial(ctx.run, func)
_, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("api_base", None)) _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("base_url", None))
if (custom_llm_provider == "openai" if (custom_llm_provider == "openai"
or custom_llm_provider == "azure" or custom_llm_provider == "azure"