diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py index 7e98345b37..2db27aebaf 100644 --- a/litellm/llms/gemini.py +++ b/litellm/llms/gemini.py @@ -121,6 +121,13 @@ def completion( ## Load Config inference_params = copy.deepcopy(optional_params) stream = inference_params.pop("stream", None) + + # Handle safety settings + safety_settings_param = inference_params.pop("safety_settings", None) + safety_settings = None + if safety_settings_param: + safety_settings = [genai.types.SafetySettingDict(x) for x in safety_settings_param] + config = litellm.GeminiConfig.get_config() for k, v in config.items(): if ( @@ -141,11 +148,13 @@ def completion( response = _model.generate_content( contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params), + safety_settings=safety_settings, ) else: response = _model.generate_content( contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params), + safety_settings=safety_settings, stream=True, ) return response