fix: bug fix when n>1 passed in

Krrish Dholakia 2023-10-09 16:46:18 -07:00
parent 2004b449e8
commit 253e8d27db
8 changed files with 119 additions and 43 deletions


@@ -97,6 +97,15 @@ last_fetched_at_keys = None
 # 'usage': {'prompt_tokens': 18, 'completion_tokens': 23, 'total_tokens': 41}
 # }
+class UnsupportedParamsError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
 def _generate_id(): # private helper function
     return 'chatcmpl-' + str(uuid.uuid4())
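
The class above replaces string-matched ValueErrors with a typed exception that carries both a status code and a message. A minimal sketch of what that gives calling code (illustrative only, not part of the commit, using the class exactly as defined in this diff):

# Illustrative sketch: catching the new typed error instead of parsing a ValueError string
try:
    raise UnsupportedParamsError(status_code=500, message="provider X does not support parameters: {'n': 2}")
except UnsupportedParamsError as e:
    print(e.status_code)  # 500, can be mapped onto an HTTP response
    print(e.message)      # human-readable detail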
@@ -1008,7 +1017,7 @@ def get_optional_params( # use the openai defaults
         if litellm.add_function_to_prompt: # if user opts to add it to prompt instead
             optional_params["functions_unsupported_model"] = non_default_params.pop("functions")
         else:
-            raise ValueError(f"LiteLLM.Exception: Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")
+            raise UnsupportedParamsError(status_code=500, message=f"Function calling is not supported by {custom_llm_provider}. To add it to the prompt, set `litellm.add_function_to_prompt = True`.")
 
     def _check_valid_arg(supported_params):
         print_verbose(f"checking params for {model}")
@@ -1025,7 +1034,7 @@ def get_optional_params( # use the openai defaults
             else:
                 unsupported_params[k] = non_default_params[k]
         if unsupported_params and not litellm.drop_params:
-            raise ValueError(f"LiteLLM.Exception: {custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")
+            raise UnsupportedParamsError(status_code=500, message=f"{custom_llm_provider} does not support parameters: {unsupported_params}. To drop these, set `litellm.drop_params=True`.")
 
     ## raise exception if provider doesn't support passed in param
     if custom_llm_provider == "anthropic":
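
Any other unsupported parameter (the commit title's n>1 is one example of a parameter some providers don't accept) goes through the same check: either the caller opts into litellm.drop_params, or the typed error is raised. A sketch of the drop_params path (provider and model are placeholders, not from this diff):

import litellm

# Illustrative sketch: silently drop parameters the chosen provider can't handle
litellm.drop_params = True
response = litellm.completion(
    model="some-provider/some-model",  # placeholder model
    messages=[{"role": "user", "content": "hi"}],
    n=2,  # dropped here instead of raising UnsupportedParamsError(status_code=500, ...)
)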
@@ -1163,7 +1172,7 @@ def get_optional_params( # use the openai defaults
         if stop:
             optional_params["stopSequences"] = stop
         if max_tokens:
-            optional_params["maxOutputTokens"] = max_tokens
+            optional_params["max_output_tokens"] = max_tokens
     elif (
         custom_llm_provider == "vertex_ai"
     ):