fix(gemini.py): support streaming
commit b07677c6be (parent b59e67f099)
4 changed files with 67 additions and 13 deletions
gemini.py:

@@ -120,9 +120,7 @@ def completion(
     ## Load Config
     inference_params = copy.deepcopy(optional_params)
-    inference_params.pop(
-        "stream", None
-    ) # palm does not support streaming, so we handle this by fake streaming in main.py
+    stream = inference_params.pop("stream", None)
     config = litellm.GeminiConfig.get_config()
     for k, v in config.items():
         if (
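The deleted lines discarded the "stream" flag entirely, since PaLM had no native streaming and main.py faked it; the new line keeps the value so the completion call below can stream natively. Popping the key is also what keeps it out of the GenerationConfig kwargs. A minimal sketch of that constraint, with made-up parameter values:

    import google.generativeai as genai

    optional_params = {"temperature": 0.7, "max_output_tokens": 256, "stream": True}
    inference_params = dict(optional_params)

    # "stream" is a transport flag, not a generation setting; leaving it in
    # would make GenerationConfig(**inference_params) raise a TypeError.
    stream = inference_params.pop("stream", None)
    generation_config = genai.types.GenerationConfig(**inference_params)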
@@ -139,10 +137,18 @@ def completion(
     ## COMPLETION CALL
     try:
         _model = genai.GenerativeModel(f"models/{model}")
-        response = _model.generate_content(
-            contents=prompt,
-            generation_config=genai.types.GenerationConfig(**inference_params),
-        )
+        if stream != True:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+            )
+        else:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+                stream=True,
+            )
+            return response
     except Exception as e:
         raise GeminiError(
             message=str(e),
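With stream=True, generate_content() returns an iterable of partial responses rather than a single finished one, which is why the else branch hands it straight back to the caller instead of falling through to the non-streaming post-processing. A minimal consumption sketch, assuming google-generativeai's streaming interface and a made-up model name:

    import google.generativeai as genai

    _model = genai.GenerativeModel("models/gemini-pro")
    response = _model.generate_content(
        contents="Write a haiku about code review.",
        stream=True,
    )
    for chunk in response:  # each chunk is a partial GenerateContentResponse
        print(chunk.text, end="")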
@@ -177,16 +183,20 @@ def completion(
 
     try:
         completion_response = model_response["choices"][0]["message"].get("content")
-        if completion_response is None: 
+        if completion_response is None:
             raise Exception
     except:
         original_response = f"response: {response}"
-        if hasattr(response, "candidates"): 
+        if hasattr(response, "candidates"):
             original_response = f"response: {response.candidates}"
-            if "SAFETY" in original_response: 
-                original_response += "\nThe candidate content was flagged for safety reasons."
+            if "SAFETY" in original_response:
+                original_response += (
+                    "\nThe candidate content was flagged for safety reasons."
+                )
             elif "RECITATION" in original_response:
-                original_response += "\nThe candidate content was flagged for recitation reasons."
+                original_response += (
+                    "\nThe candidate content was flagged for recitation reasons."
+                )
         raise GeminiError(
             status_code=400,
             message=f"No response received. Original response - {original_response}",
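End to end, a caller would exercise this through litellm's OpenAI-compatible surface. A hypothetical usage sketch; the model alias and chunk shape follow litellm's usual conventions rather than anything this diff shows:

    import litellm

    response = litellm.completion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
    )
    for chunk in response:
        delta = chunk.choices[0].delta.content
        if delta is not None:
            print(delta, end="")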