fix(gemini.py): support streaming

Krrish Dholakia 2024-01-19 19:26:23 -08:00
parent b59e67f099
commit b07677c6be
4 changed files with 67 additions and 13 deletions

@@ -120,9 +120,7 @@ def completion(
     ## Load Config
     inference_params = copy.deepcopy(optional_params)
-    inference_params.pop(
-        "stream", None
-    )  # palm does not support streaming, so we handle this by fake streaming in main.py
+    stream = inference_params.pop("stream", None)
     config = litellm.GeminiConfig.get_config()
     for k, v in config.items():
         if (
@@ -139,10 +137,18 @@
     ## COMPLETION CALL
     try:
         _model = genai.GenerativeModel(f"models/{model}")
-        response = _model.generate_content(
-            contents=prompt,
-            generation_config=genai.types.GenerationConfig(**inference_params),
-        )
+        if stream != True:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+            )
+        else:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+                stream=True,
+            )
+            return response
     except Exception as e:
         raise GeminiError(
             message=str(e),
@@ -177,16 +183,20 @@ def completion(
     try:
         completion_response = model_response["choices"][0]["message"].get("content")
-        if completion_response is None:
+        if completion_response is None:
             raise Exception
     except:
         original_response = f"response: {response}"
-        if hasattr(response, "candidates"):
+        if hasattr(response, "candidates"):
             original_response = f"response: {response.candidates}"
-            if "SAFETY" in original_response:
-                original_response += "\nThe candidate content was flagged for safety reasons."
+            if "SAFETY" in original_response:
+                original_response += (
+                    "\nThe candidate content was flagged for safety reasons."
+                )
             elif "RECITATION" in original_response:
-                original_response += "\nThe candidate content was flagged for recitation reasons."
+                original_response += (
+                    "\nThe candidate content was flagged for recitation reasons."
+                )
         raise GeminiError(
             status_code=400,
             message=f"No response received. Original response - {original_response}",

@@ -1382,6 +1382,18 @@ def completion(
                 acompletion=acompletion,
                 custom_prompt_dict=custom_prompt_dict,
             )
+            if (
+                "stream" in optional_params
+                and optional_params["stream"] == True
+                and acompletion == False
+            ):
+                response = CustomStreamWrapper(
+                    iter(model_response),
+                    model,
+                    custom_llm_provider="gemini",
+                    logging_obj=logging,
+                )
+                return response
             response = model_response
         elif custom_llm_provider == "vertex_ai":
             vertex_ai_project = litellm.vertex_project or get_secret("VERTEXAI_PROJECT")
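
Caller-side, this branch means completion(..., stream=True) now hands back an iterable of OpenAI-style chunks for gemini. A short usage sketch (assumes GEMINI_API_KEY is exported in the environment), reading the same choices[0].delta field the test below prints:

import litellm

# assumes GEMINI_API_KEY is set in the environment
response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "Say hello"}],
    stream=True,
)
for chunk in response:
    delta = chunk.choices[0].delta
    if delta.content:
        print(delta.content, end="")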

@@ -398,6 +398,36 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
+
+
+def test_completion_gemini_stream():
+    try:
+        litellm.set_verbose = False
+        print("Streaming gemini response")
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        print("testing gemini streaming")
+        response = completion(model="gemini/gemini-pro", messages=messages, stream=True)
+        print(f"type of response at the top: {response}")
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            print(chunk)
+            # print(chunk.choices[0].delta)
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
 
 
 def test_completion_mistral_api_stream():
     try:
         litellm.set_verbose = True

@@ -7622,7 +7622,9 @@ class CustomStreamWrapper:
                             raise Exception("An unknown error occurred with the stream")
                         model_response.choices[0].finish_reason = "stop"
                         self.sent_last_chunk = True
-            elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
+            elif self.custom_llm_provider == "gemini":
+                completion_obj["content"] = chunk.text
+            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
                     # print(chunk)
                     if hasattr(chunk, "text"):
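
The distinction behind this commit: palm responses are fake-streamed (a completed response is chunked in main.py, per the comment removed in the first hunk), while gemini now forwards the SDK's native chunks through the wrapper. A simplified, self-contained illustration of the two approaches, not litellm's actual wrapper code:

from typing import Iterable, Iterator


def fake_stream(full_text: str, chunk_size: int = 20) -> Iterator[str]:
    # palm-style: the provider returned one complete response, so slice it into pieces
    for i in range(0, len(full_text), chunk_size):
        yield full_text[i : i + chunk_size]


def native_stream(sdk_chunks: Iterable) -> Iterator[str]:
    # gemini-style: the SDK already yields partial results; surface each chunk's .text
    for chunk in sdk_chunks:
        yield chunk.text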