mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
Add partial support of vertexai safety settings
This commit is contained in:
parent
9548334e2f
commit
2d15e5384b
1 changed files with 18 additions and 6 deletions
|
@ -181,6 +181,7 @@ def completion(
|
|||
from vertexai.preview.language_models import ChatModel, CodeChatModel, InputOutputTextPair
|
||||
from vertexai.language_models import TextGenerationModel, CodeGenerationModel
|
||||
from vertexai.preview.generative_models import GenerativeModel, Part, GenerationConfig
|
||||
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types
|
||||
|
||||
|
||||
vertexai.init(
|
||||
|
@ -193,6 +194,15 @@ def completion(
|
|||
if k not in optional_params:
|
||||
optional_params[k] = v
|
||||
|
||||
## Process safety settings into format expected by vertex AI
|
||||
if "safety_settings" in optional_params:
|
||||
safety_settings = optional_params.pop("safety_settings")
|
||||
if not isinstance(safety_settings, list):
|
||||
raise ValueError("safety_settings must be a list")
|
||||
if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict):
|
||||
raise ValueError("safety_settings must be a list of dicts")
|
||||
safety_settings=[gapic_content_types.SafetySetting(x) for x in safety_settings]
|
||||
|
||||
# vertexai does not use an API key, it looks for credentials.json in the environment
|
||||
|
||||
prompt = " ".join([message["content"] for message in messages if isinstance(message["content"], str)])
|
||||
|
@ -238,16 +248,16 @@ def completion(
|
|||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
stream = optional_params.pop("stream")
|
||||
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), stream=stream)
|
||||
model_response = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings, stream=stream)
|
||||
optional_params["stream"] = True
|
||||
return model_response
|
||||
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params})).text\n"
|
||||
request_str += f"chat.send_message({prompt}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}).text\n"
|
||||
## LOGGING
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params))
|
||||
response_obj = chat.send_message(prompt, generation_config=GenerationConfig(**optional_params), safety_settings=safety_settings)
|
||||
completion_response = response_obj.text
|
||||
response_obj = response_obj._raw_response
|
||||
elif mode == "vision":
|
||||
|
@ -258,12 +268,13 @@ def completion(
|
|||
content = [prompt] + images
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
stream = optional_params.pop("stream")
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), stream={stream})\n"
|
||||
request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n"
|
||||
logging_obj.pre_call(input=prompt, api_key=None, additional_args={"complete_input_dict": optional_params, "request_str": request_str})
|
||||
|
||||
model_response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
safety_settings=safety_settings,
|
||||
stream=True
|
||||
)
|
||||
optional_params["stream"] = True
|
||||
|
@ -276,7 +287,8 @@ def completion(
|
|||
## LLM Call
|
||||
response = llm_model.generate_content(
|
||||
contents=content,
|
||||
generation_config=GenerationConfig(**optional_params)
|
||||
generation_config=GenerationConfig(**optional_params),
|
||||
safety_settings=safety_settings,
|
||||
)
|
||||
completion_response = response.text
|
||||
response_obj = response._raw_response
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue