Merge branch 'main' into main

This commit is contained in:
Krish Dholakia 2023-12-18 17:54:34 -08:00 committed by GitHub
commit 408f232bd7
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 185 additions and 70 deletions

View file

@ -52,6 +52,7 @@ from .llms import (
cohere,
petals,
oobabooga,
openrouter,
palm,
vertex_ai,
maritalk)
@ -260,8 +261,8 @@ def completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
functions: Optional[List] = None,
function_call: Optional[str] = None,
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@ -1026,14 +1027,23 @@ def completion(
}
)
## Load Config
config = openrouter.OpenrouterConfig.get_config()
for k, v in config.items():
if k == "extra_body":
# we use openai 'extra_body' to pass openrouter specific params - transforms, route, models
if "extra_body" in optional_params:
optional_params[k].update(v)
else:
optional_params[k] = v
elif k not in optional_params:
optional_params[k] = v
data = {
"model": model,
"messages": messages,
**optional_params
}
## LOGGING
logging.pre_call(input=messages, api_key=openai.api_key, additional_args={"complete_input_dict": data, "headers": headers})
## COMPLETION CALL
## COMPLETION CALL
response = openai_chat_completions.completion(
@ -1510,8 +1520,8 @@ def batch_completion(
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
functions: Optional[List] = None,
function_call: Optional[str] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
@ -2193,10 +2203,8 @@ def text_completion(
if stream == True or kwargs.get("stream", False) == True:
response = TextCompletionStreamWrapper(completion_stream=response, model=model)
return response
if asyncio.iscoroutine(response):
response = asyncio.run(response)
if kwargs.get("acompletion", False) == True:
return response
transformed_logprobs = None
# only supported for TGI models
try: