diff --git a/litellm/main.py b/litellm/main.py
index 45537eb49..0b58c0aab 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1139,26 +1139,20 @@ def batch_completion(
     messages: List = [],
     functions: List = [],
     function_call: str = "", # optional params
-    temperature: float = 1,
-    top_p: float = 1,
-    n: int = 1,
-    stream: bool = False,
+    temperature: Optional[float] = None,
+    top_p: Optional[float] = None,
+    n: Optional[int] = None,
+    stream: Optional[bool] = None,
     stop=None,
-    max_tokens: float = float("inf"),
-    presence_penalty: float = 0,
-    frequency_penalty=0,
+    max_tokens: Optional[float] = None,
+    presence_penalty: Optional[float] = None,
+    frequency_penalty: Optional[float]=None,
     logit_bias: dict = {},
     user: str = "",
+    deployment_id = None,
+    request_timeout: Optional[int] = None,
     # Optional liteLLM function params
-    *,
-    return_async=False,
-    api_key: Optional[str] = None,
-    api_version: Optional[str] = None,
-    api_base: Optional[str] = None,
-    force_timeout=600,
-    # used by text-bison only
-    top_k=40,
-    custom_llm_provider=None,):
+    **kwargs):
     args = locals()
     batch_messages = messages
     completions = []
@@ -1183,10 +1177,10 @@ def batch_completion(
             user=user,
             # params to identify the model
             model=model,
-            custom_llm_provider=custom_llm_provider,
-            top_k=top_k,
+            custom_llm_provider=custom_llm_provider
         )
         results = vllm.batch_completions(model=model, messages=batch_messages, custom_prompt_dict=litellm.custom_prompt_dict, optional_params=optional_params)
+    # all non VLLM models for batch completion models
     else:
         def chunks(lst, n):
             """Yield successive n-sized chunks from lst."""
@@ -1195,9 +1189,12 @@ def batch_completion(
         with ThreadPoolExecutor(max_workers=100) as executor:
             for sub_batch in chunks(batch_messages, 100):
                 for message_list in sub_batch:
-                    kwargs_modified = args
+                    kwargs_modified = args.copy()
                     kwargs_modified["messages"] = message_list
-                    future = executor.submit(completion, **kwargs_modified)
+                    original_kwargs = {}
+                    if "kwargs" in kwargs_modified:
+                        original_kwargs = kwargs_modified.pop("kwargs")
+                    future = executor.submit(completion, **kwargs_modified, **original_kwargs)
                     completions.append(future)
 
     # Retrieve the results from the futures
diff --git a/pyproject.toml b/pyproject.toml
index ee80ec21c..bd85f7f7d 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.2.6"
+version = "0.3.0"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
@@ -25,7 +25,7 @@ requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.commitizen]
-version = "0.2.6"
+version = "0.3.0"
 version_files = [
     "pyproject.toml:^version"
 ]
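
Note on the args.copy() change (a reviewer-style sketch, not part of the diff): args = locals() now also captures the **kwargs dict under the key "kwargs". With a plain alias (kwargs_modified = args), the first loop iteration's pop("kwargs") would strip the caller's extra params out of the one shared dict, so every later submit would silently drop them; copying per iteration keeps each submission intact. A minimal runnable sketch below, where "sk-hypothetical" and the model string are placeholder values standing in for whatever the caller passes through:

    # `args` stands in for `locals()` inside batch_completion; the
    # "kwargs" entry holds whatever the caller passed via **kwargs.
    args = {"model": "gpt-3.5-turbo", "messages": None,
            "kwargs": {"api_key": "sk-hypothetical"}}

    # Buggy aliasing: one shared dict across loop iterations.
    shared = args
    assert shared.pop("kwargs", {}) == {"api_key": "sk-hypothetical"}
    assert shared.pop("kwargs", {}) == {}  # second iteration: extras lost

    # Fixed: each iteration pops from its own snapshot of args.
    args["kwargs"] = {"api_key": "sk-hypothetical"}  # restore for the demo
    for _ in range(2):
        kwargs_modified = args.copy()
        original_kwargs = kwargs_modified.pop("kwargs", {})
        assert original_kwargs == {"api_key": "sk-hypothetical"}  # kept every time

With the fix, completion(**kwargs_modified, **original_kwargs) receives the extra params flattened on every call, rather than as a nested kwargs={...} keyword argument on the first call and nothing on the rest.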