fix(main.py): accepting azure deployment_id

Krrish Dholakia 2023-11-09 18:15:54 -08:00
parent 523c540051
commit 249cde3d40
3 changed files with 8 additions and 8 deletions

litellm/main.py

@@ -301,7 +301,7 @@ def completion(
     eos_token = kwargs.get("eos_token", None)
     acompletion = kwargs.get("acompletion", False)
     ######## end of unpacking kwargs ###########
-    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key"]
+    openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "request_timeout", "api_base", "api_version", "api_key", "deployment_id"]
     litellm_params = ["metadata", "acompletion", "caching", "return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "use_client", "id", "fallbacks", "azure", "headers", "model_list", "num_retries", "context_window_fallback_dict", "roles", "final_prompt_value", "bos_token", "eos_token", "request_timeout", "complete_response"]
     default_params = openai_params + litellm_params
     non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
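For context, a minimal sketch of how this filter behaves after the change (abridged lists, not the full litellm implementation): any kwarg named in default_params, now including deployment_id, is held back from non_default_params instead of being forwarded to the provider as a model-specific param.

openai_params = ["temperature", "max_tokens", "api_key", "deployment_id"]  # abridged
litellm_params = ["metadata", "caching"]  # abridged
default_params = openai_params + litellm_params

kwargs = {"deployment_id": "chatgpt-v-2", "my_custom_param": True}
non_default_params = {k: v for k, v in kwargs.items() if k not in default_params}
print(non_default_params)  # {'my_custom_param': True} -- deployment_id is recognized, not passed through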
@@ -366,7 +366,6 @@ def completion(
             frequency_penalty=frequency_penalty,
             logit_bias=logit_bias,
             user=user,
-            deployment_id=deployment_id,
             # params to identify the model
             model=model,
             custom_llm_provider=custom_llm_provider,
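With deployment_id listed in openai_params, completion() can pick it up from kwargs directly (the same way eos_token and acompletion are unpacked above) instead of threading it through this call; a hedged sketch of that pattern, with the Azure routing step labeled as an assumption, not litellm's exact code:

def completion_sketch(model, messages, **kwargs):
    # Unpack deployment_id from kwargs, like eos_token/acompletion above.
    deployment_id = kwargs.get("deployment_id", None)
    if deployment_id is not None:
        model = deployment_id  # assumption: Azure routes the call by deployment name
    return model

print(completion_sketch("gpt-3.5-turbo", [], deployment_id="chatgpt-v-2"))  # chatgpt-v-2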

litellm/tests/test_completion.py

@@ -445,7 +445,7 @@ def test_completion_openai_litellm_key():
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-test_completion_openai_litellm_key()
+# test_completion_openai_litellm_key()


 def test_completion_openrouter1():
     try:
@@ -540,6 +540,8 @@ def test_completion_openai_with_more_optional_params():
             pytest.fail(f"Error occurred: {e}")
         if type(response_str_2) != str:
             pytest.fail(f"Error occurred: {e}")
+    except Timeout as e:
+        pass
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
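The new except clause lets this test tolerate a genuinely slow upstream call instead of failing; a minimal sketch of the pattern (assuming Timeout is the timeout exception the test module imports, e.g. the one litellm re-exports):

from litellm import Timeout, completion  # assumption: Timeout is re-exported by litellm

try:
    response = completion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello"}],
        request_timeout=0.01,  # deliberately tiny to make a timeout plausible
    )
except Timeout:
    pass  # a timeout is an acceptable outcome here, not a test failure
except Exception as e:
    raise AssertionError(f"Error occurred: {e}")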
@@ -721,6 +723,7 @@ def test_completion_azure_with_litellm_key():

 def test_completion_azure_deployment_id():
     try:
+        litellm.set_verbose = True
         response = completion(
             deployment_id="chatgpt-v-2",
             model="gpt-3.5-turbo",
@@ -730,7 +733,7 @@ def test_completion_azure_deployment_id():
         print(response)
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")
-# test_completion_azure_deployment_id()
+test_completion_azure_deployment_id()

 # Only works for local endpoint
 # def test_completion_anthropic_openai_proxy():
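Running the re-enabled Azure deployment test against a live endpoint needs Azure credentials in the environment; the variable names below are litellm's conventional ones, assumed here rather than shown in this diff, and the values are placeholders:

import os

# Assumed environment for an Azure OpenAI deployment (placeholders, not real values).
os.environ["AZURE_API_KEY"] = "<your-azure-key>"
os.environ["AZURE_API_BASE"] = "https://<your-resource>.openai.azure.com"
os.environ["AZURE_API_VERSION"] = "2023-07-01-preview"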

litellm/utils.py

@@ -1363,7 +1363,6 @@ def get_optional_params( # use the openai defaults
     frequency_penalty=0,
     logit_bias={},
     user="",
-    deployment_id=None,
     model=None,
     custom_llm_provider="",
     **kwargs
@@ -1386,7 +1385,6 @@ def get_optional_params( # use the openai defaults
         "frequency_penalty":None,
         "logit_bias":{},
         "user":"",
-        "deployment_id":None,
         "model":None,
         "custom_llm_provider":"",
     }
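For context, a default table like this lets get_optional_params isolate the values a caller actually set; a minimal sketch of that idea (illustrative only, not the exact implementation):

default_params = {"temperature": None, "logit_bias": {}, "user": ""}
passed_params = {"temperature": 0.7, "logit_bias": {}, "user": ""}

# Keep only the values that differ from their defaults.
non_default = {k: v for k, v in passed_params.items() if v != default_params.get(k)}
print(non_default)  # {'temperature': 0.7}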
@@ -1762,7 +1760,7 @@ def get_optional_params( # use the openai defaults
         if stream:
             optional_params["stream"] = stream
     elif custom_llm_provider == "deepinfra":
-        supported_params = ["temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "deployment_id"]
+        supported_params = ["temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user"]
         _check_valid_arg(supported_params=supported_params)
         optional_params = non_default_params
         if temperature != None:
@@ -1770,7 +1768,7 @@ def get_optional_params( # use the openai defaults
             temperature = 0.0001 # close to 0
         optional_params["temperature"] = temperature
     else: # assume passing in params for openai/azure openai
-        supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "deployment_id"]
+        supported_params = ["functions", "function_call", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user"]
         _check_valid_arg(supported_params=supported_params)
         optional_params = non_default_params
         # if user passed in non-default kwargs for specific providers/models, pass them along
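Removing "deployment_id" from these supported_params lists is safe because, after the main.py change, deployment_id sits in openai_params and is filtered out of non_default_params before this check runs. A hedged sketch of what a guard like _check_valid_arg can look like (not litellm's exact code):

def _check_valid_arg_sketch(supported_params, non_default_params):
    # Reject any model-specific param the current provider does not support.
    unsupported = [k for k in non_default_params if k not in supported_params]
    if unsupported:
        raise ValueError(f"unsupported params for this provider: {unsupported}")

supported = ["temperature", "top_p", "n", "stream", "stop", "max_tokens",
             "presence_penalty", "frequency_penalty", "logit_bias", "user"]
_check_valid_arg_sketch(supported, {"temperature": 0.5})   # passes silently
# _check_valid_arg_sketch(supported, {"deployment_id": "x"})  # would raise ValueError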