fixing optional param mapping

Krrish Dholakia 2023-10-02 14:14:30 -07:00
parent 7cec308a2c
commit 8b60d797e1
6 changed files with 32 additions and 21 deletions

@@ -202,6 +202,7 @@ def completion(
- If 'mock_response' is provided, a mock completion response is returned for testing or debugging.
"""
######### unpacking kwargs #####################
+args = locals()
return_async = kwargs.get('return_async', False)
mock_response = kwargs.get('mock_response', None)
api_key = kwargs.get('api_key', None)
@@ -216,9 +217,8 @@ def completion(
metadata = kwargs.get('metadata', None)
fallbacks = kwargs.get('fallbacks', [])
######## end of unpacking kwargs ###########
-args = locals()
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "metadata"]
litellm_params = ["caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "metadata", "fallbacks"]
litellm_params = ["acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "metadata", "fallbacks"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
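Note on the hunk above: any kwarg that appears in neither openai_params nor litellm_params ends up in non_default_params and is forwarded straight to the model/provider, so a litellm-internal flag such as acompletion must be listed in litellm_params or it leaks into the provider call. A minimal sketch of that filtering, with abbreviated lists and a hypothetical set of kwargs (not litellm's full lists):

# Abbreviated version of the split shown in the diff (example kwargs only).
openai_params = ["temperature", "top_p", "max_tokens", "stream"]
litellm_params = ["acompletion", "caching", "mock_response", "api_key"]
default_params = openai_params + litellm_params

kwargs = {"acompletion": True, "top_k": 40}   # "top_k" stands in for a provider-specific param
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
print(non_default_params)  # {'top_k': 40} -- the internal flag no longer leaks through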
@@ -797,7 +797,7 @@ def completion(
logging_obj=logging
)
# fake palm streaming
-if stream == True:
+if "stream" in optional_params and optional_params["stream"] == True:
# fake streaming for palm
resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper(
@@ -836,7 +836,6 @@ def completion(
if k not in optional_params:
optional_params[k] = v
print(f"optional_params: {optional_params}")
## LOGGING
logging.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
@@ -979,7 +978,7 @@ def completion(
logging_obj=logging
)
-if stream==True: ## [BETA]
+if "stream" in optional_params and optional_params["stream"]==True: ## [BETA]
# sagemaker does not support streaming as of now so we're faking streaming:
# https://discuss.huggingface.co/t/streaming-output-text-when-deploying-on-sagemaker/39611
# "SageMaker is currently not supporting streaming responses."
@@ -1009,7 +1008,7 @@ def completion(
)
-if stream == True:
+if "stream" in optional_params and optional_params["stream"] == True:
# don't try to access stream object,
response = CustomStreamWrapper(
iter(model_response), model, custom_llm_provider="bedrock", logging_obj=logging
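The last three hunks make the same change: instead of reading a local stream variable, the code checks the streaming flag in optional_params, which is where the mapped parameters actually live after the unpacking step. A minimal sketch of the guard, using a made-up optional_params dict rather than the exact litellm call sites:

# optional_params is an ordinary dict; "stream" is only present when the caller asked for it.
optional_params = {"max_tokens": 256}

if "stream" in optional_params and optional_params["stream"] == True:
    print("wrap the response in a fake-streaming iterator")
else:
    print("return the full response object")

An equivalent, slightly more idiomatic spelling would be optional_params.get("stream", False), which folds the membership check and the value check into one lookup.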