(fix) litellm.acompletion with type hints

ishaan-jaff 2024-01-11 10:47:12 +05:30
parent 6e1be43595
commit cea0d6c8b0
2 changed files with 40 additions and 53 deletions


@@ -131,37 +131,37 @@ class Completions:
 @client
 async def acompletion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
     messages: List = [],
     functions: Optional[List] = None,
     function_call: Optional[str] = None,
     timeout: Optional[Union[float, int]] = None,
     temperature: Optional[float] = None,
     top_p: Optional[float] = None,
     n: Optional[int] = None,
     stream: Optional[bool] = None,
     stop=None,
     max_tokens: Optional[float] = None,
     presence_penalty: Optional[float] = None,
     frequency_penalty: Optional[float] = None,
     logit_bias: Optional[dict] = None,
     user: Optional[str] = None,
     # openai v1.0+ new params
     response_format: Optional[dict] = None,
     seed: Optional[int] = None,
     tools: Optional[List] = None,
     tool_choice: Optional[str] = None,
     logprobs: Optional[bool] = None,
     top_logprobs: Optional[int] = None,
     deployment_id=None,
     # set api_base, api_version, api_key
     base_url: Optional[str] = None,
     api_version: Optional[str] = None,
     api_key: Optional[str] = None,
     model_list: Optional[list] = None,  # pass in a list of api_base,keys, etc.
     # Optional liteLLM function params
     **kwargs,
 ):
     """
     Asynchronously executes a litellm.completion() call for any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly)
@@ -213,7 +213,7 @@ async def acompletion
         "top_p": top_p,
         "n": n,
         "stream": stream,
         "stop": stop,
         "max_tokens": max_tokens,
         "presence_penalty": presence_penalty,
         "frequency_penalty": frequency_penalty,
@@ -230,17 +230,19 @@ async def acompletion(
         "api_version": api_version,
         "api_key": api_key,
         "model_list": model_list,
-        "acompletion": True # assuming this is a required parameter
+        "acompletion": True,  # assuming this is a required parameter
     }
     try:
         # Use a partial function to pass your keyword arguments
-        func = partial(completion, *args, **kwargs)
+        func = partial(completion, **completion_kwargs)
         # Add the context to the function
         ctx = contextvars.copy_context()
         func_with_context = partial(ctx.run, func)
-        _, custom_llm_provider, _, _ = get_llm_provider(model=model, api_base=completion_kwargs.get("base_url", None))
+        _, custom_llm_provider, _, _ = get_llm_provider(
+            model=model, api_base=completion_kwargs.get("base_url", None)
+        )
         if (
             custom_llm_provider == "openai"
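
The key fix in this hunk: `partial(completion, *args, **kwargs)` referenced names that no longer exist after the move to an explicit typed signature, so the bound call now uses the `completion_kwargs` dict built above. The surrounding lines show the standard pattern for dispatching a sync call from async code: bind the call with functools.partial, snapshot the caller's context with contextvars.copy_context(), and run the bound call through ctx.run so context variables survive the hop into a worker thread. A self-contained sketch of that pattern (`request_id` and `sync_work` are hypothetical names, not part of litellm):

import asyncio
import contextvars
from functools import partial

request_id = contextvars.ContextVar("request_id", default="unset")


def sync_work(prompt: str) -> str:
    # Runs in a worker thread, yet still sees the caller's context variables.
    return f"[{request_id.get()}] processed: {prompt}"


async def run_sync_in_executor() -> str:
    request_id.set("req-42")
    func = partial(sync_work, prompt="hello")
    ctx = contextvars.copy_context()            # snapshot the current context
    func_with_context = partial(ctx.run, func)  # run `func` inside that snapshot
    loop = asyncio.get_running_loop()
    # None selects the default ThreadPoolExecutor
    return await loop.run_in_executor(None, func_with_context)


print(asyncio.run(run_sync_in_executor()))  # prints "[req-42] processed: hello"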
@@ -284,7 +286,7 @@ async def acompletion
             model=model,
             custom_llm_provider=custom_llm_provider,
             original_exception=e,
-            completion_kwargs=args,
+            completion_kwargs=completion_kwargs,
         )
@@ -3260,7 +3262,6 @@ def stream_chunk_builder(chunks: list, messages: Optional[list] = None):
     if isinstance(
         chunks[0]["choices"][0], litellm.utils.TextChoices
     ):  # route to the text completion logic
         return stream_chunk_builder_text_completion(chunks=chunks, messages=messages)
-
     role = chunks[0]["choices"][0]["delta"]["role"]
     finish_reason = chunks[-1]["choices"][0]["finish_reason"]
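
For context on the stream_chunk_builder hunk above: the function reassembles one complete response from a list of streamed chunks, routing to the text-completion builder when the first chunk carries TextChoices. A hedged usage sketch, assuming a configured OpenAI key (model name and prompt are illustrative):

import litellm

messages = [{"role": "user", "content": "Write a haiku about the sea."}]
response = litellm.completion(model="gpt-3.5-turbo", messages=messages, stream=True)

# Collect the streamed delta chunks, then rebuild one complete response.
chunks = [chunk for chunk in response]
full_response = litellm.stream_chunk_builder(chunks, messages=messages)
print(full_response.choices[0].message.content)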


@@ -15,22 +15,6 @@ from litellm import completion, acompletion, acreate
 litellm.num_retries = 3
-def test_sync_response():
-    litellm.set_verbose = False
-    user_message = "Hello, how are you?"
-    messages = [{"content": user_message, "role": "user"}]
-    try:
-        response = completion(model="gpt-3.5-turbo", messages=messages, timeout=5)
-        print(f"response: {response}")
-    except litellm.Timeout as e:
-        pass
-    except Exception as e:
-        pytest.fail(f"An exception occurred: {e}")
-# test_sync_response()
 def test_sync_response_anyscale():
     litellm.set_verbose = False
     user_message = "Hello, how are you?"
@@ -197,6 +181,7 @@ def test_get_cloudflare_response_streaming():
     asyncio.run(test_async_call())
+
 @pytest.mark.asyncio
 async def test_hf_completion_tgi():
     # litellm.set_verbose=True
@@ -212,6 +197,7 @@ async def test_hf_completion_tgi():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
+
 # test_get_cloudflare_response_streaming()
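
A minimal async test in the same style as the ones in this file, exercising the acompletion signature the commit types; the test name and timeout value are illustrative, not part of the commit:

import pytest

import litellm
from litellm import acompletion


@pytest.mark.asyncio
async def test_async_response():
    litellm.set_verbose = False
    messages = [{"content": "Hello, how are you?", "role": "user"}]
    try:
        response = await acompletion(model="gpt-3.5-turbo", messages=messages, timeout=5)
        print(f"response: {response}")
    except litellm.Timeout:
        # a timeout is tolerated here, matching the other tests in this file
        pass
    except Exception as e:
        pytest.fail(f"An exception occurred: {e}")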