Mirror of https://github.com/BerriAI/litellm.git
raise exception if optional param is not mapped to model
commit 1cae080eb2 (parent 49f65b7eb8)
8 changed files with 155 additions and 112 deletions
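The hunks below all touch the completion() entrypoint: the optional OpenAI-style params now default to None, and any remaining caller-supplied kwargs are collected into non_default_params and handed to the provider-mapping layer, which is where the commit message's "raise exception if optional param is not mapped to model" behaviour would live (that check is in the other changed files, not shown in this excerpt). A minimal, hypothetical sketch of that behaviour; SUPPORTED_PARAMS and validate_optional_params are illustrative names, not litellm's:

# Hypothetical sketch of the commit's intent - not litellm's actual implementation.
SUPPORTED_PARAMS = {
    "vertex_ai": {"temperature", "top_p", "top_k", "max_output_tokens", "stream"},
    "openai": {"temperature", "top_p", "n", "stream", "stop", "max_tokens",
               "presence_penalty", "frequency_penalty", "logit_bias", "user"},
}

def validate_optional_params(non_default_params: dict, custom_llm_provider: str) -> None:
    """Raise if a caller-supplied optional param cannot be mapped to the provider."""
    supported = SUPPORTED_PARAMS.get(custom_llm_provider, set())
    unsupported = set(non_default_params) - supported
    if unsupported:
        raise ValueError(
            f"{sorted(unsupported)} not supported by provider '{custom_llm_provider}'"
        )

validate_optional_params({"top_k": 20}, "vertex_ai")   # ok
# validate_optional_params({"top_k": 20}, "openai")    # would raise ValueError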
@@ -154,14 +154,14 @@ def completion(
messages: List = [],
functions: List = [],
function_call: str = "", # optional params
temperature: float = 1,
top_p: float = 1,
n: int = 1,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stream: bool = False,
stop=None,
max_tokens: float = float("inf"),
presence_penalty: float = 0,
frequency_penalty=0,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float]=None,
logit_bias: dict = {},
user: str = "",
deployment_id = None,
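Changing the signature defaults from concrete values (temperature=1, top_p=1, ...) to Optional[...] = None lets completion() tell "the caller never set this" apart from "the caller explicitly asked for the old default", so only params the caller actually supplied need to be mapped, and validated, for the chosen provider. A small illustration of the difference (not code from the commit):

from typing import Optional

def old_style(temperature: float = 1) -> dict:
    # With a concrete default, every call looks like the caller asked for temperature=1.
    return {"temperature": temperature}

def new_style(temperature: Optional[float] = None) -> dict:
    # With None as the default, unset params can simply be dropped before provider mapping.
    return {k: v for k, v in {"temperature": temperature}.items() if v is not None}

print(old_style())  # {'temperature': 1}  - always forwarded
print(new_style())  # {}                  - nothing to map, nothing to complain about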
@@ -214,14 +214,15 @@ def completion(
litellm_logging_obj = kwargs.get('litellm_logging_obj', None)
id = kwargs.get('id', None)
metadata = kwargs.get('metadata', None)
request_timeout = kwargs.get('request_timeout', 0)
fallbacks = kwargs.get('fallbacks', [])
######## end of unpacking kwargs ###########

args = locals()
openai_params = ["functions", "function_call", "temperature", "temperature", "top_p", "n", "stream", "stop", "max_tokens", "presence_penalty", "frequency_penalty", "logit_bias", "user", "metadata"]
litellm_params = ["return_async", "mock_response", "api_key", "api_version", "api_base", "force_timeout", "logger_fn", "verbose", "custom_llm_provider", "litellm_logging_obj", "litellm_call_id", "id", "metadata", "fallbacks"]
default_params = openai_params + litellm_params
non_default_params = {k: v for k,v in kwargs.items() if k not in default_params} # model-specific params - pass them straight to the model/provider
if mock_response:
return mock_completion(model, messages, stream=stream, mock_response=mock_response)

args = locals()
try:
logging = litellm_logging_obj
if fallbacks != []:
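The new non_default_params comprehension splits kwargs into the known OpenAI/litellm params and everything else; the leftovers are model- or provider-specific and are forwarded as-is. The same filtering in isolation, with shortened lists for readability (illustrative, not the full lists from the diff):

openai_params = ["functions", "function_call", "temperature", "top_p", "n", "stream",
                 "stop", "max_tokens", "presence_penalty", "frequency_penalty",
                 "logit_bias", "user", "metadata"]
litellm_params = ["mock_response", "api_key", "api_version", "api_base",
                  "custom_llm_provider", "fallbacks", "metadata"]
default_params = openai_params + litellm_params

def split_kwargs(kwargs: dict) -> dict:
    # model-specific params - everything the standard lists don't cover
    return {k: v for k, v in kwargs.items() if k not in default_params}

print(split_kwargs({"temperature": 0.2, "top_k": 20, "task": "text-generation-inference"}))
# {'top_k': 20, 'task': 'text-generation-inference'}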
@@ -260,10 +261,7 @@ def completion(
# params to identify the model
model=model,
custom_llm_provider=custom_llm_provider,
top_k=kwargs.get('top_k', 40),
task=kwargs.get('task', "text-generation-inference"),
remove_input=kwargs.get('remove_input', True),
return_full_text=kwargs.get('return_full_text', False),
**non_default_params
)
# For logging - save the values of the litellm-specific params passed in
litellm_params = get_litellm_params(
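Instead of hard-coding Hugging Face-flavoured defaults (top_k, task, remove_input, return_full_text) at this call site, the filtered non_default_params dict is now splatted into the enclosing call (its name sits just above this hunk and is not visible in the excerpt), so the mapping layer sees exactly what the caller passed and can reject anything it cannot translate. A sketch of that forwarding pattern, with a hypothetical receiver:

def map_optional_params(model: str, custom_llm_provider: str, **provider_params) -> dict:
    # Hypothetical stand-in for the real mapping function: translate or reject each key.
    print(f"mapping {sorted(provider_params)} for {custom_llm_provider}/{model}")
    return provider_params

non_default_params = {"top_k": 20, "return_full_text": False}
map_optional_params("bigcode/starcoder", "huggingface", **non_default_params)
# mapping ['return_full_text', 'top_k'] for huggingface/bigcode/starcoder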
@@ -822,8 +820,10 @@ def completion(
# vertexai does not use an API key, it looks for credentials.json in the environment

prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging.pre_call(input=prompt, api_key=None)
# contains any default values we need to pass to the provider
VertexAIConfig = {
"top_k": 40 # override by setting kwarg in completion() - e.g. completion(..., top_k=20)
}
if model in litellm.vertex_chat_models:
chat_model = ChatModel.from_pretrained(model)
else: # vertex_code_chat_models
@@ -831,7 +831,16 @@ def completion(

chat = chat_model.start_chat()

if stream:
## Load Config
for k, v in VertexAIConfig.items():
if k not in optional_params:
optional_params[k] = v

print(f"optional_params: {optional_params}")
## LOGGING
logging.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params})

if "stream" in optional_params and optional_params["stream"] == True:
model_response = chat.send_message_streaming(prompt, **optional_params)
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertex_ai", logging_obj=logging
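The "## Load Config" loop copies each VertexAIConfig default into optional_params only when the key is absent, so an explicit kwarg such as completion(..., top_k=20) wins over the built-in top_k=40. The same merge pattern in isolation:

VertexAIConfig = {"top_k": 40}                        # provider defaults
optional_params = {"top_k": 20, "temperature": 0.3}   # what the caller asked for

for k, v in VertexAIConfig.items():
    if k not in optional_params:   # never override an explicit caller value
        optional_params[k] = v

print(optional_params)  # {'top_k': 20, 'temperature': 0.3} - caller's top_k is kept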
@@ -875,16 +884,27 @@ def completion(
)
# vertexai does not use an API key, it looks for credentials.json in the environment

# contains any default values we need to pass to the provider
VertexAIConfig = {
"top_k": 40 # override by setting kwarg in completion() - e.g. completion(..., top_k=20)
}

prompt = " ".join([message["content"] for message in messages])
## LOGGING
logging.pre_call(input=prompt, api_key=None)

if model in litellm.vertex_text_models:
vertex_model = TextGenerationModel.from_pretrained(model)
else:
vertex_model = CodeGenerationModel.from_pretrained(model)

## Load Config
for k, v in VertexAIConfig.items():
if k not in optional_params:
optional_params[k] = v

if stream:
## LOGGING
logging.pre_call(input=prompt, api_key=None)

if "stream" in optional_params and optional_params["stream"] == True:
model_response = vertex_model.predict_streaming(prompt, **optional_params)
response = CustomStreamWrapper(
model_response, model, custom_llm_provider="vertexai", logging_obj=logging
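Both Vertex AI branches wrap the provider's streaming generator in CustomStreamWrapper when streaming is requested, so the caller iterates over the returned object as with any other litellm stream. A hedged usage sketch (the model name is illustrative and assumes Vertex AI credentials are configured):

import litellm

# Usage sketch; "chat-bison" is an illustrative Vertex AI chat model name.
response = litellm.completion(
    model="chat-bison",
    messages=[{"role": "user", "content": "Hello from Vertex AI"}],
    stream=True,
    top_k=20,  # overrides the VertexAIConfig default of 40 shown in the diff
)
for chunk in response:
    print(chunk)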