raise exception if optional param is not mapped to model

This commit is contained in:
Krrish Dholakia 2023-10-02 11:17:44 -07:00
parent 49f65b7eb8
commit 1cae080eb2
8 changed files with 155 additions and 112 deletions

View file

@ -154,14 +154,14 @@ def completion(
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
temperature: float = 1,
top_p: float = 1,
n: int = 1,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: bool = False,
stop=None,
max_tokens: float = float("inf"),
presence_penalty: float = 0,
frequency_penalty=0,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float]=None,
logit_bias: dict = {},
user: str = "",
deployment_id = None,
@ -214,14 +214,15 @@ def completion(
litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
id = kwargs.get('id', None)
metadata = kwargs.get('metadata', None)
request_timeout = kwargs.get('request_timeout', 0)
fallbacks = kwargs.get('fallbacks', [])
######## end of unpacking kwargs ###########
args = locals()
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "metadata"]
litellm_params = ["return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "id", "metadata", "fallbacks"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
return mock_completion(model, messages, stream=stream, mock_response=mock_response)
args = locals()
try:
logging = litellm_logging_obj
if fallbacks != []:
@ -260,10 +261,7 @@ def completion(
# params to identify the model
model=model,
custom_llm_provider=custom_llm_provider,
top_k=kwargs.get('top_k', 40),
task=kwargs.get('task', "text-generation-inference"),
remove_input=kwargs.get('remove_input', True),
return_full_text=kwargs.get('return_full_text', False),
**non_default_params
)
# For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params(
@ -822,8 +820,10 @@ def completion(
# vertexai does not use an API key, it looks for credentials.json in the environment
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging.pre_call(input=prompt, api_key=None)
# contains any default values we need to pass to the provider
VertexAIConfig = {
"top_k": 40 # override by setting kwarg in completion() - e.g. completion(..., top_k=20)
}
if model in litellm.vertex_chat_models:
chat_model = ChatModel.from_pretrained(model)
else: # vertex_code_chat_models
@ -831,7 +831,16 @@ def completion(
chat = chat_model.start_chat()
if stream:
## Load Config
for k, v in VertexAIConfig.items():
if k not in optional_params:
optional_params[k] = v
print(f"optional_params: {optional_params}")
## LOGGING
logging.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})
if "stream" in optional_params and optional_params["stream"] == True:
model_response = chat.send_message_streaming(prompt, **optional_params)
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
@ -875,16 +884,27 @@ def completion(
)
# vertexai does not use an API key, it looks for credentials.json in the environment
# contains any default values we need to pass to the provider
VertexAIConfig = {
"top_k": 40 # override by setting kwarg in completion() - e.g. completion(..., top_k=20)
}
prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging.pre_call(input=prompt, api_key=None)
if model in litellm.vertex_text_models:
vertex_model = TextGenerationModel.from_pretrained(model)
else:
vertex_model = CodeGenerationModel.from_pretrained(model)
## Load Config
for k, v in VertexAIConfig.items():
if k not in optional_params:
optional_params[k] = v
if stream:
## LOGGING
logging.pre_call(input=prompt, api_key=None)
if "stream" in optional_params and optional_params["stream"] == True:
model_response = vertex_model.predict_streaming(prompt, **optional_params)
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertexai", logging_obj=logging