feat: added explicit args to acomplete

Mateo Cámara 2023-12-20 19:49:12 +01:00
parent b873833340
commit e60e1afa53

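The diff below replaces acompletion's generic *args, **kwargs signature with explicit keyword parameters and forwards them to the synchronous completion() call through a completion_kwargs dict. As a rough illustration of how a caller would use the new signature (not part of the commit; the model name, message, and sampling values are placeholders):

import asyncio
import litellm

async def main():
    # Placeholder model and message; any litellm-supported model is called the same way.
    response = await litellm.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Say hello in one sentence."}],
        temperature=0.2,
        max_tokens=64,
    )
    print(response)

asyncio.run(main())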

@@ -117,7 +117,31 @@ class Completions():
         return response

 @client
-async def acompletion(*args, **kwargs):
+async def acompletion(
+    model: str,
+    messages: List = [],
+    functions: Optional[List] = None,
+    function_call: Optional[str] = None,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream: Optional[bool] = None,
+    stop=None,
+    max_tokens: Optional[int] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float] = None,
+    logit_bias: Optional[Dict] = None,
+    user: Optional[str] = None,
+    metadata: Optional[Dict] = None,
+    api_base: Optional[str] = None,
+    api_version: Optional[str] = None,
+    api_key: Optional[str] = None,
+    model_list: Optional[List] = None,
+    mock_response: Optional[str] = None,
+    force_timeout: Optional[int] = None,
+    custom_llm_provider: Optional[str] = None,
+    **kwargs,
+):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -138,7 +162,7 @@ async def acompletion(*args, **kwargs):
         frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far.
         logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion.
         user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse.
         metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc.
         api_base (str, optional): Base URL for the API (default is None).
         api_version (str, optional): API version (default is None).
         api_key (str, optional): API key (default is None).
@@ -157,22 +181,44 @@ async def acompletion(*args, **kwargs):
         - If `stream` is True, the function returns an async generator that yields completion lines.
     """
     loop = asyncio.get_event_loop()
-    model = args[0] if len(args) > 0 else kwargs["model"]
-    ### PASS ARGS TO COMPLETION ###
-    kwargs["acompletion"] = True
-    custom_llm_provider = None
+    # Adjusted to use explicit arguments instead of *args and **kwargs
+    completion_kwargs = {
+        "model": model,
+        "messages": messages,
+        "functions": functions,
+        "function_call": function_call,
+        "temperature": temperature,
+        "top_p": top_p,
+        "n": n,
+        "stream": stream,
+        "stop": stop,
+        "max_tokens": max_tokens,
+        "presence_penalty": presence_penalty,
+        "frequency_penalty": frequency_penalty,
+        "logit_bias": logit_bias,
+        "user": user,
+        "metadata": metadata,
+        "api_base": api_base,
+        "api_version": api_version,
+        "api_key": api_key,
+        "model_list": model_list,
+        "mock_response": mock_response,
+        "force_timeout": force_timeout,
+        "custom_llm_provider": custom_llm_provider,
+        "acompletion": True  # assuming this is a required parameter
+    }
     try:
         # Use a partial function to pass your keyword arguments
-        func = partial(completion, *args, **kwargs)
+        func = partial(completion, **completion_kwargs)
         # Add the context to the function
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
-        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=kwargs.get("api_base", None))
+        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("api_base", None))
         if (custom_llm_provider == "openai"
             or custom_llm_provider == "azure"
             or custom_llm_provider == "custom_openai"
             or custom_llm_provider == "anyscale"
             or custom_llm_provider == "mistral"
@@ -182,39 +228,39 @@ async def acompletion(*args, **kwargs):
             or custom_llm_provider == "text-completion-openai"
             or custom_llm_provider == "huggingface"
             or custom_llm_provider == "ollama"
             or custom_llm_provider == "vertex_ai"): # currently implemented aiohttp calls for just azure and openai, soon all.
-            if kwargs.get("stream", False):
-                response = completion(*args, **kwargs)
+            if completion_kwargs.get("stream", False):
+                response = completion(**completion_kwargs)
             else:
                 # Await normally
                 init_response = await loop.run_in_executor(None, func_with_context)
                 if isinstance(init_response, dict) or isinstance(init_response, ModelResponse): ## CACHING SCENARIO
                     response = init_response
                 elif asyncio.iscoroutine(init_response):
                     response = await init_response
         else:
             # Call the synchronous function using run_in_executor
             response = await loop.run_in_executor(None, func_with_context)
-        if kwargs.get("stream", False): # return an async generator
-            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, args=args)
+        if completion_kwargs.get("stream", False): # return an async generator
+            return _async_streaming(response=response, model=model, custom_llm_provider=custom_llm_provider, completion_kwargs=completion_kwargs)
         else:
             return response
     except Exception as e:
         custom_llm_provider = custom_llm_provider or "openai"
         raise exception_type(
-            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs,
         )

-async def _async_streaming(response, model, custom_llm_provider, args):
+async def _async_streaming(response, model, custom_llm_provider, completion_kwargs):
     try:
         print_verbose(f"received response in _async_streaming: {response}")
         async for line in response:
             print_verbose(f"line in async streaming: {line}")
             yield line
     except Exception as e:
         print_verbose(f"error raised _async_streaming: {traceback.format_exc()}")
         raise exception_type(
-            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=args,
+            model=model, custom_llm_provider=custom_llm_provider, original_exception=e, completion_kwargs=completion_kwargs,
         )

 def mock_completion(model: str, messages: List, stream: Optional[bool] = False, mock_response: str = "This is a mock request", **kwargs):
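For context on the concurrency pattern both the old and new versions share: completion() is synchronous, so acompletion binds its arguments with functools.partial, copies the current contextvars context, and hands the callable to loop.run_in_executor so the event loop is not blocked. A minimal standalone sketch of that pattern, using a stand-in blocking function rather than litellm code:

import asyncio
import contextvars
from functools import partial

def blocking_work(x: int, y: int) -> int:
    # Stand-in for a blocking call such as completion()
    return x + y

async def run_without_blocking() -> int:
    loop = asyncio.get_event_loop()
    # Bind keyword arguments up front, as acompletion does with completion_kwargs
    func = partial(blocking_work, x=2, y=3)
    # Copy the current context so context variables survive the executor hop
    ctx = contextvars.copy_context()
    func_with_context = partial(ctx.run, func)
    # Run the synchronous callable in the default thread pool executor
    return await loop.run_in_executor(None, func_with_context)

print(asyncio.run(run_without_blocking()))  # prints 5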