Add safety_settings parameter to gemini generate_content calls

This commit is contained in:
Andres Barbaro 2024-02-16 12:22:18 -06:00
parent 65e1d5a9e7
commit 1f054203bf

View file

@ -121,6 +121,13 @@ def completion(
## Load Config ## Load Config
inference_params = copy.deepcopy(optional_params) inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", None) stream = inference_params.pop("stream", None)
# Handle safety settings
safety_settings_param = inference_params.pop("safety_settings", None)
safety_settings = None
if safety_settings_param:
safety_settings = [genai.types.SafetySettingDict(x) for x in safety_settings_param]
config = litellm.GeminiConfig.get_config() config = litellm.GeminiConfig.get_config()
for k, v in config.items(): for k, v in config.items():
if ( if (
@ -141,11 +148,13 @@ def completion(
response = _model.generate_content( response = _model.generate_content(
contents=prompt, contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params), generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings,
) )
else: else:
response = _model.generate_content( response = _model.generate_content(
contents=prompt, contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params), generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings,
stream=True, stream=True,
) )
return response return response