fix(gemini.py): support streaming

This commit is contained in:
Krrish Dholakia 2024-01-19 19:26:23 -08:00
parent b59e67f099
commit b07677c6be
4 changed files with 67 additions and 13 deletions

View file

@ -120,9 +120,7 @@ def completion(
## Load Config
inference_params = copy.deepcopy(optional_params)
inference_params.pop(
"stream", None
) # palm does not support streaming, so we handle this by fake streaming in main.py
stream = inference_params.pop("stream", None)
config = litellm.GeminiConfig.get_config()
for k, v in config.items():
if (
@ -139,10 +137,18 @@ def completion(
## COMPLETION CALL
try:
_model = genai.GenerativeModel(f"models/{model}")
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
)
if stream != True:
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
)
else:
response = _model.generate_content(
contents=prompt,
generation_config=genai.types.GenerationConfig(**inference_params),
stream=True,
)
return response
except Exception as e:
raise GeminiError(
message=str(e),
@ -177,16 +183,20 @@ def completion(
try:
completion_response = model_response["choices"][0]["message"].get("content")
if completion_response is None:
if completion_response is None:
raise Exception
except:
original_response = f"response: {response}"
if hasattr(response, "candidates"):
if hasattr(response, "candidates"):
original_response = f"response: {response.candidates}"
if "SAFETY" in original_response:
original_response += "\nThe candidate content was flagged for safety reasons."
if "SAFETY" in original_response:
original_response += (
"\nThe candidate content was flagged for safety reasons."
)
elif "RECITATION" in original_response:
original_response += "\nThe candidate content was flagged for recitation reasons."
original_response += (
"\nThe candidate content was flagged for recitation reasons."
)
raise GeminiError(
status_code=400,
message=f"No response received. Original response - {original_response}",