Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-27 11:43:54 +00:00
feat: added explicit args to acomplete
This commit is contained in:
parent 04bbd0649f
commit b72d372aa7

1 changed file with 72 additions and 26 deletions
@@ -117,7 +117,31 @@ class Completions():
         return response
 
 @client
-async def acompletion(*args, **kwargs):
+async def acompletion(
+    model: str,
+    messages: List = [],
+    functions: Optional[List] = None,
+    function_call: Optional[str] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream: Optional[bool] = None,
+    stop=None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[Dict] = None,
+    user: Optional[str] = None,
+    metadata: Optional[Dict] = None,
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[List] = None,
+    mock_response: Optional[str] = None,
+    force_timeout: Optional[int] = None,
+    custom_llm_provider: Optional[str] = None,
+    **kwargs,
+):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
 
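For orientation, here is a minimal caller-side sketch of the explicit-argument signature introduced in the hunk above. It assumes litellm is installed and the relevant provider API key (e.g. OPENAI_API_KEY) is set in the environment; the model name, messages, and parameter values are illustrative only, not part of this commit.

    import asyncio
    import litellm

    async def main():
        # Explicit keyword arguments instead of the old *args/**kwargs signature.
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hello, world"}],
            temperature=0.2,
            max_tokens=64,
        )
        print(response)

    asyncio.run(main())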
@@ -157,19 +181,41 @@ async def acompletion(*args, **kwargs):
     - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     loop = asyncio.get_event_loop()
-    model = args[0] if len(args) > 0 else kwargs["model"]
-    ### PASS ARGS TO COMPLETION ###
-    kwargs["acompletion"] = True
-    custom_llm_provider = None
+    # Adjusted to use explicit arguments instead of *args and **kwargs
+    completion_kwargs = {
+        "model": model,
+        "messages": messages,
+        "functions": functions,
+        "function_call": function_call,
+        "temperature": temperature,
+        "top_p": top_p,
+        "n": n,
+        "stream": stream,
+        "stop": stop,
+        "max_tokens": max_tokens,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "logit_bias": logit_bias,
+        "user": user,
+        "metadata": metadata,
+        "api_base": api_base,
+        "api_version": api_version,
+        "api_key": api_key,
+        "model_list": model_list,
+        "mock_response": mock_response,
+        "force_timeout": force_timeout,
+        "custom_llm_provider": custom_llm_provider,
+        "acompletion": True  # assuming this is a required parameter
+    }
     try:
         # Use a partial function to pass your keyword arguments
-        func = partial(completion, *args, **kwargs)
+        func = partial(completion, **completion_kwargs)
 
         # Add the context to the function
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
 
-        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("api_base", None))
 
         if (custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
@@ -183,8 +229,8 @@ async def acompletion(*args, **kwargs):
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
-            if kwargs.get("stream", False):
-                response = completion(*args, **kwargs)
+            if completion_kwargs.get("stream", False):
+                response = completion(**completion_kwargs)
             else:
                 # Await normally
                 init_response = await loop.run_in_executor(None, func_with_context)
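The provider whitelist above keys off the custom_llm_provider string unpacked from get_llm_provider earlier in the function. A small sketch of how that string is typically derived from a prefixed model name, mirroring the diff's four-value unpacking; the model names and expected outputs are illustrative assumptions, not output from this commit.

    import litellm

    _, provider, _, _ = litellm.get_llm_provider(model="azure/my-gpt4-deployment")
    print(provider)  # expected: "azure"

    _, provider, _, _ = litellm.get_llm_provider(model="gpt-3.5-turbo")
    print(provider)  # expected: "openai"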
@@ -195,17 +241,17 @@ async def acompletion(*args, **kwargs):
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False): # return an async generator
-            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
+        if completion_kwargs.get("stream", False): # return an async generator
+            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, completion_kwargs=completion_kwargs)
         else:
             return response
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
-                model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+                model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs,
             )
 
-async def _async_streaming(response, model, custom_llm_provider, args):
+async def _async_streaming(response, model, custom_llm_provider, completion_kwargs):
     try:
         print_verbose(f"received response in _async_streaming: {response}")
         async for line in response:
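Per the docstring context above, when stream is True acompletion returns an async generator (the _async_streaming path). A caller-side sketch of consuming it, assuming litellm is installed and a provider key is configured; the model and prompt are illustrative.

    import asyncio
    import litellm

    async def stream_demo():
        # With stream=True the awaited result is an async generator of chunks.
        stream = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Count to three"}],
            stream=True,
        )
        async for chunk in stream:
            print(chunk)

    asyncio.run(stream_demo())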
@@ -214,7 +260,7 @@ async def _async_streaming(response, model, custom_llm_provider, args):
     except Exception as e:
         print_verbose(f"error raised _async_streaming: {traceback.format_exc()}")
         raise exception_type(
-                model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+                model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs,
             )
 
 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
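The trailing context line shows mock_completion, and the new explicit signature exposes mock_response directly. A sketch of exercising that path in a test, assuming mock_response short-circuits to mock_completion as the parameter name suggests; no provider key should be needed if so, and the messages below are illustrative.

    import asyncio
    import litellm

    async def mock_demo():
        # If mock_response is honored, this returns the canned string without a network call.
        response = await litellm.acompletion(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "ping"}],
            mock_response="pong",
        )
        print(response)

    asyncio.run(mock_demo())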