add completion types

This commit is contained in:
ishaan-jaff 2023-09-05 14:57:01 -07:00
parent 470eba90b6
commit 5cbe7a941d

View file

@@ -70,16 +70,16 @@ async def acompletion(*args, **kwargs):
 def completion(
     model: str,
     # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
-    messages=[],
-    functions=[],
-    function_call="",  # optional params
+    messages: List = [],
+    functions: List = [],
+    function_call: str = "",  # optional params
     temperature: float = 1,
     top_p: float = 1,
     n: int = 1,
     stream: bool = False,
     stop=None,
-    max_tokens=float("inf"),
-    presence_penalty=0,
+    max_tokens: float = float("inf"),
+    presence_penalty: float = 0,
     frequency_penalty=0,
     logit_bias: dict = {},
     user: str = "",
@@ -360,10 +360,12 @@ def completion(
         # set replicate key
         os.environ["REPLICATE_API_TOKEN"] = str(replicate_key)
         prompt = " ".join([message["content"] for message in messages])
-        input = {"prompt": prompt}
+        input = {
+            "prompt": prompt
+        }
         if "max_tokens" in optional_params:
-            input["max_length"] = max_tokens  # for t5 models
-            input["max_new_tokens"] = max_tokens  # for llama2 models
+            input["max_length"] = optional_params['max_tokens']  # for t5 models
+            input["max_new_tokens"] = optional_params['max_tokens']  # for llama2 models
         ## LOGGING
         logging.pre_call(
             input=prompt,