add types to completion()

This commit is contained in:
ishaan-jaff 2023-09-05 14:42:09 -07:00
parent 4b98feec36
commit d0d5ef505d

View file

@ -29,6 +29,7 @@ from .llms import aleph_alpha
from .llms import baseten from .llms import baseten
import tiktoken import tiktoken
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Callable, List, Optional, Dict
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import ( from litellm.utils import (
@ -67,34 +68,34 @@ async def acompletion(*args, **kwargs):
600 600
) ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout` ) ## set timeouts, in case calls hang (e.g. Azure) - default is 600s, override with `force_timeout`
def completion( def completion(
model, model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages=[], messages=[],
functions=[], functions=[],
function_call="", # optional params function_call="", # optional params
temperature=1, temperature: float = 1,
top_p=1, top_p: float = 1,
n=1, n: int = 1,
stream=False, stream: bool = False,
stop=None, stop=None,
max_tokens=float("inf"), max_tokens=float("inf"),
presence_penalty=0, presence_penalty=0,
frequency_penalty=0, frequency_penalty=0,
logit_bias={}, logit_bias: dict = {},
user="", user: str = "",
deployment_id=None, deployment_id = None,
# Optional liteLLM function params # Optional liteLLM function params
*, *,
return_async=False, return_async=False,
api_key=None, api_key: Optional[str] = None,
api_version=None, api_version: Optional[str] = None,
api_base: Optional[str] = None,
force_timeout=600, force_timeout=600,
num_beams=1, num_beams=1,
logger_fn=None, logger_fn=None,
verbose=False, verbose=False,
azure=False, azure=False,
custom_llm_provider=None, custom_llm_provider=None,
api_base=None,
litellm_call_id=None, litellm_call_id=None,
litellm_logging_obj=None, litellm_logging_obj=None,
use_client=False, use_client=False,