Add safety_settings parameter to gemini generate_content calls

This commit is contained in:
Andres Barbaro 2024-02-16 12:22:18 -06:00
parent 65e1d5a9e7
commit 1f054203bf

View file

@ -121,6 +121,13 @@ def completion(
## Load Config
inference_params = copy.deepcopy(optional_params)
stream = inference_params.pop("stream", None)
# Handle safety settings
safety_settings_param = inference_params.pop("safety_settings", None)
safety_settings = None
if safety_settings_param:
safety_settings = [genai.types.SafetySettingDict(x) for x in safety_settings_param]
config = litellm.GeminiConfig.get_config()
for k, v in config.items():
if (
@ -141,11 +148,13 @@ def completion(
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings,
)
else:
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
safety_settings=safety_settings,
stream=True,
)
return response