From b07677c6bee87f95108cad8cf8392e5d7297af82 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 19 Jan 2024 19:26:23 -0800
Subject: [PATCH] fix(gemini.py): support streaming

---
 litellm/llms/gemini.py          | 34 +++++++++++++++++++++------------
 litellm/main.py                 | 12 ++++++++++++
 litellm/tests/test_streaming.py | 30 +++++++++++++++++++++++++++++
 litellm/utils.py                |  4 +++-
 4 files changed, 67 insertions(+), 13 deletions(-)

diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py
index 863fb4baf..7e98345b3 100644
--- a/litellm/llms/gemini.py
+++ b/litellm/llms/gemini.py
@@ -120,9 +120,7 @@ def completion(
 
     ## Load Config
     inference_params = copy.deepcopy(optional_params)
-    inference_params.pop(
-        "stream", None
-    )  # palm does not support streaming, so we handle this by fake streaming in main.py
+    stream = inference_params.pop("stream", None)
     config = litellm.GeminiConfig.get_config()
     for k, v in config.items():
         if (
@@ -139,10 +137,18 @@
     ## COMPLETION CALL
     try:
         _model = genai.GenerativeModel(f"models/{model}")
-        response = _model.generate_content(
-            contents=prompt,
-            generation_config=genai.types.GenerationConfig(**inference_params),
-        )
+        if stream != True:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+            )
+        else:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+                stream=True,
+            )
+            return response
     except Exception as e:
         raise GeminiError(
             message=str(e),
@@ -177,16 +183,20 @@ def completion(
 
     try:
         completion_response = model_response["choices"][0]["message"].get("content")
-        if completion_response is None:
+        if completion_response is None:
             raise Exception
     except:
         original_response = f"response: {response}"
-        if hasattr(response, "candidates"):
+        if hasattr(response, "candidates"):
             original_response = f"response: {response.candidates}"
-            if "SAFETY" in original_response:
-                original_response += "\nThe candidate content was flagged for safety reasons."
+            if "SAFETY" in original_response:
+                original_response += (
+                    "\nThe candidate content was flagged for safety reasons."
+                )
             elif "RECITATION" in original_response:
-                original_response += "\nThe candidate content was flagged for recitation reasons."
+                original_response += (
+                    "\nThe candidate content was flagged for recitation reasons."
+                )
         raise GeminiError(
             status_code=400,
             message=f"No response received. Original response - {original_response}",
diff --git a/litellm/main.py b/litellm/main.py
index 2fef048a6..271c54e51 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -1382,6 +1382,18 @@ def completion(
             acompletion=acompletion,
             custom_prompt_dict=custom_prompt_dict,
         )
+        if (
+            "stream" in optional_params
+            and optional_params["stream"] == True
+            and acompletion == False
+        ):
+            response = CustomStreamWrapper(
+                iter(model_response),
+                model,
+                custom_llm_provider="gemini",
+                logging_obj=logging,
+            )
+            return response
         response = model_response
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = litellm.vertex_project or get_secret("VERTEXAI_PROJECT")
diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py
index 28f2271d7..c3e0b68fa 100644
--- a/litellm/tests/test_streaming.py
+++ b/litellm/tests/test_streaming.py
@@ -398,6 +398,36 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
 
 
+def test_completion_gemini_stream():
+    try:
+        litellm.set_verbose = False
+        print("Streaming gemini response")
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        print("testing gemini streaming")
+        response = completion(model="gemini/gemini-pro", messages=messages, stream=True)
+        print(f"type of response at the top: {response}")
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            print(chunk)
+            # print(chunk.choices[0].delta)
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api_stream():
     try:
         litellm.set_verbose = True
diff --git a/litellm/utils.py b/litellm/utils.py
index a70db73a1..36bf5c9d6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -7622,7 +7622,9 @@ class CustomStreamWrapper:
                         raise Exception("An unknown error occurred with the stream")
                 model_response.choices[0].finish_reason = "stop"
                 self.sent_last_chunk = True
-            elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
+            elif self.custom_llm_provider == "gemini":
+                completion_obj["content"] = chunk.text
+            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
                     # print(chunk)
                     if hasattr(chunk, "text"):
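
A minimal usage sketch of the streaming path this patch enables, mirroring the
test_completion_gemini_stream test added above. The model name and prompt come
from that test; the GEMINI_API_KEY environment variable and the OpenAI-style
delta access are assumptions about litellm's standard streaming interface,
not something introduced by this patch.

    from litellm import completion

    # Assumes GEMINI_API_KEY is set in the environment for the gemini/ provider.
    response = completion(
        model="gemini/gemini-pro",
        messages=[
            {"role": "user", "content": "how does a court case get to the Supreme Court?"}
        ],
        stream=True,
    )

    # With this patch, stream=True returns a CustomStreamWrapper; chunks follow
    # litellm's OpenAI-style delta format, and content may be None on the final chunk.
    for chunk in response:
        print(chunk.choices[0].delta.content or "", end="")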