From 45eb4a5fcca34b9538dc3ce84129f1c63b967e13 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 19 Feb 2024 22:41:36 -0800 Subject: [PATCH 1/3] fix(gemini.py): fix async streaming + add native async completions --- litellm/llms/gemini.py | 145 +++++++++++++++++++++++++++++-- litellm/main.py | 1 + litellm/proxy/utils.py | 6 +- litellm/tests/test_completion.py | 13 +++ litellm/tests/test_streaming.py | 41 +++++++++ litellm/utils.py | 35 ++++++-- 6 files changed, 224 insertions(+), 17 deletions(-) diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py index 2db27aeba..03574559c 100644 --- a/litellm/llms/gemini.py +++ b/litellm/llms/gemini.py @@ -126,7 +126,9 @@ def completion( safety_settings_param = inference_params.pop("safety_settings", None) safety_settings = None if safety_settings_param: - safety_settings = [genai.types.SafetySettingDict(x) for x in safety_settings_param] + safety_settings = [ + genai.types.SafetySettingDict(x) for x in safety_settings_param + ] config = litellm.GeminiConfig.get_config() for k, v in config.items(): @@ -144,13 +146,29 @@ def completion( ## COMPLETION CALL try: _model = genai.GenerativeModel(f"models/{model}") - if stream != True: - response = _model.generate_content( - contents=prompt, - generation_config=genai.types.GenerationConfig(**inference_params), - safety_settings=safety_settings, - ) - else: + if stream == True: + if acompletion == True: + + async def async_streaming(): + response = await _model.generate_content_async( + contents=prompt, + generation_config=genai.types.GenerationConfig( + **inference_params + ), + safety_settings=safety_settings, + stream=True, + ) + + response = litellm.CustomStreamWrapper( + aiter(response), + model, + custom_llm_provider="gemini", + logging_obj=logging_obj, + ) + + return response + + return async_streaming() response = _model.generate_content( contents=prompt, generation_config=genai.types.GenerationConfig(**inference_params), @@ -158,6 +176,25 @@ def completion( stream=True, ) return response + elif acompletion == True: + return async_completion( + _model=_model, + model=model, + prompt=prompt, + inference_params=inference_params, + safety_settings=safety_settings, + logging_obj=logging_obj, + print_verbose=print_verbose, + model_response=model_response, + messages=messages, + encoding=encoding, + ) + else: + response = _model.generate_content( + contents=prompt, + generation_config=genai.types.GenerationConfig(**inference_params), + safety_settings=safety_settings, + ) except Exception as e: raise GeminiError( message=str(e), @@ -236,6 +273,98 @@ def completion( return model_response +async def async_completion( + _model, + model, + prompt, + inference_params, + safety_settings, + logging_obj, + print_verbose, + model_response, + messages, + encoding, +): + import google.generativeai as genai + + response = await _model.generate_content_async( + contents=prompt, + generation_config=genai.types.GenerationConfig(**inference_params), + safety_settings=safety_settings, + ) + + ## LOGGING + logging_obj.post_call( + input=prompt, + api_key="", + original_response=response, + additional_args={"complete_input_dict": {}}, + ) + print_verbose(f"raw model_response: {response}") + ## RESPONSE OBJECT + completion_response = response + try: + choices_list = [] + for idx, item in enumerate(completion_response.candidates): + if len(item.content.parts) > 0: + message_obj = Message(content=item.content.parts[0].text) + else: + message_obj = Message(content=None) + choice_obj = Choices(index=idx + 1, 
message=message_obj) + choices_list.append(choice_obj) + model_response["choices"] = choices_list + except Exception as e: + traceback.print_exc() + raise GeminiError( + message=traceback.format_exc(), status_code=response.status_code + ) + + try: + completion_response = model_response["choices"][0]["message"].get("content") + if completion_response is None: + raise Exception + except: + original_response = f"response: {response}" + if hasattr(response, "candidates"): + original_response = f"response: {response.candidates}" + if "SAFETY" in original_response: + original_response += ( + "\nThe candidate content was flagged for safety reasons." + ) + elif "RECITATION" in original_response: + original_response += ( + "\nThe candidate content was flagged for recitation reasons." + ) + raise GeminiError( + status_code=400, + message=f"No response received. Original response - {original_response}", + ) + + ## CALCULATING USAGE + prompt_str = "" + for m in messages: + if isinstance(m["content"], str): + prompt_str += m["content"] + elif isinstance(m["content"], list): + for content in m["content"]: + if content["type"] == "text": + prompt_str += content["text"] + prompt_tokens = len(encoding.encode(prompt_str)) + completion_tokens = len( + encoding.encode(model_response["choices"][0]["message"].get("content", "")) + ) + + model_response["created"] = int(time.time()) + model_response["model"] = "gemini/" + model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + model_response.usage = usage + return model_response + + def embedding(): # logic for parsing in - calling - parsing out model embedding calls pass diff --git a/litellm/main.py b/litellm/main.py index ec69f5f3a..1ee36504f 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -263,6 +263,7 @@ async def acompletion( or custom_llm_provider == "ollama" or custom_llm_provider == "ollama_chat" or custom_llm_provider == "vertex_ai" + or custom_llm_provider == "gemini" or custom_llm_provider == "sagemaker" or custom_llm_provider in litellm.openai_compatible_providers ): # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all. 
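A quick note on the routing change above: with "gemini" added to the native-async provider list in litellm/main.py, litellm.acompletion() now dispatches gemini/* models to the new async_completion() path in litellm/llms/gemini.py instead of wrapping the synchronous call. The snippet below is a minimal usage sketch, mirroring the test_acompletion_gemini test added later in this patch; it assumes Gemini credentials (e.g. a GEMINI_API_KEY environment variable) are already configured.

import asyncio

import litellm


async def main():
    # Routed through the native async Gemini path added in this patch.
    response = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)


asyncio.run(main())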
diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index bd0c0eaf7..78c1e4b63 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -681,11 +681,11 @@ class PrismaClient: return response elif table_name == "user_notification": if query_type == "find_unique": - response = await self.db.litellm_usernotifications.find_unique( + response = await self.db.litellm_usernotifications.find_unique( # type: ignore where={"user_id": user_id} # type: ignore ) elif query_type == "find_all": - response = await self.db.litellm_usernotifications.find_many() + response = await self.db.litellm_usernotifications.find_many() # type: ignore return response except Exception as e: print_verbose(f"LiteLLM Prisma Client Exception: {e}") @@ -795,7 +795,7 @@ class PrismaClient: elif table_name == "user_notification": db_data = self.jsonify_object(data=data) new_user_notification_row = ( - await self.db.litellm_usernotifications.upsert( + await self.db.litellm_usernotifications.upsert( # type: ignore where={"request_id": data["request_id"]}, data={ "create": {**db_data}, # type: ignore diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index f2186093a..605113d35 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1993,6 +1993,19 @@ def test_completion_gemini(): # test_completion_gemini() +@pytest.mark.asyncio +async def test_acompletion_gemini(): + litellm.set_verbose = True + model_name = "gemini/gemini-pro" + messages = [{"role": "user", "content": "Hey, how's it going?"}] + try: + response = await litellm.acompletion(model=model_name, messages=messages) + # Add any assertions here to check the response + print(f"response: {response}") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + # Palm tests def test_completion_palm(): litellm.set_verbose = True diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 4df4f20e9..ee6a187e2 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -429,6 +429,47 @@ def test_completion_gemini_stream(): pytest.fail(f"Error occurred: {e}") +@pytest.mark.asyncio +async def test_acompletion_gemini_stream(): + try: + litellm.set_verbose = True + print("Streaming gemini response") + messages = [ + {"role": "system", "content": "You are a helpful assistant."}, + { + "role": "user", + "content": "how does a court case get to the Supreme Court?", + }, + ] + print("testing gemini streaming") + response = await acompletion( + model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True + ) + print(f"type of response at the top: {response}") + complete_response = "" + idx = 0 + # Add any assertions here to check the response + async for chunk in response: + print(f"chunk in acompletion gemini: {chunk}") + print(chunk.choices[0].delta) + chunk, finished = streaming_format_tests(idx, chunk) + if idx > 5: + break + if finished: + break + print(f"chunk: {chunk}") + complete_response += chunk + idx += 1 + print(f"completion_response: {complete_response}") + if complete_response.strip() == "": + raise Exception("Empty response received") + except Exception as e: + pytest.fail(f"Error occurred: {e}") + + +# asyncio.run(test_acompletion_gemini_stream()) + + def test_completion_mistral_api_stream(): try: litellm.set_verbose = True diff --git a/litellm/utils.py b/litellm/utils.py index faa464448..c299a440d 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8417,7 +8417,28 @@ class CustomStreamWrapper: 
model_response.choices[0].finish_reason = "stop" self.sent_last_chunk = True elif self.custom_llm_provider == "gemini": - completion_obj["content"] = chunk.text + try: + if hasattr(chunk, "parts") == True: + try: + if len(chunk.parts) > 0: + completion_obj["content"] = chunk.parts[0].text + if hasattr(chunk.parts[0], "finish_reason"): + model_response.choices[0].finish_reason = ( + map_finish_reason(chunk.parts[0].finish_reason.name) + ) + except: + if chunk.parts[0].finish_reason.name == "SAFETY": + raise Exception( + f"The response was blocked by VertexAI. {str(chunk)}" + ) + else: + completion_obj["content"] = str(chunk) + except StopIteration as e: + if self.sent_last_chunk: + raise e + else: + model_response.choices[0].finish_reason = "stop" + self.sent_last_chunk = True elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"): try: if hasattr(chunk, "candidates") == True: @@ -8727,19 +8748,21 @@ class CustomStreamWrapper: or self.custom_llm_provider == "ollama_chat" or self.custom_llm_provider == "vertex_ai" or self.custom_llm_provider == "sagemaker" + or self.custom_llm_provider == "gemini" or self.custom_llm_provider in litellm.openai_compatible_endpoints ): - print_verbose( - f"value of async completion stream: {self.completion_stream}" - ) async for chunk in self.completion_stream: - print_verbose(f"value of async chunk: {chunk}") + print_verbose( + f"value of async chunk: {chunk.parts}; len(chunk.parts): {len(chunk.parts)}" + ) if chunk == "None" or chunk is None: raise Exception - + elif self.custom_llm_provider == "gemini" and len(chunk.parts) == 0: + continue # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks. # __anext__ also calls async_success_handler, which does logging print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}") + processed_chunk: Optional[ModelResponse] = self.chunk_creator( chunk=chunk ) From 7b641491a2fb1a20912f002bf67f490b6aac9d1a Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Mon, 19 Feb 2024 23:00:41 -0800 Subject: [PATCH 2/3] fix(utils.py): fix print statement --- litellm/utils.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/litellm/utils.py b/litellm/utils.py index c299a440d..982462e3f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8752,9 +8752,7 @@ class CustomStreamWrapper: or self.custom_llm_provider in litellm.openai_compatible_endpoints ): async for chunk in self.completion_stream: - print_verbose( - f"value of async chunk: {chunk.parts}; len(chunk.parts): {len(chunk.parts)}" - ) + print_verbose(f"value of async chunk: {chunk}") if chunk == "None" or chunk is None: raise Exception elif self.custom_llm_provider == "gemini" and len(chunk.parts) == 0: From 1d3bef2e9c38b7eac6ada2e67e4b17ae172a4c63 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 20 Feb 2024 17:10:51 -0800 Subject: [PATCH 3/3] fix(gemini.py): implement custom streamer --- litellm/llms/gemini.py | 26 +++++++++++++++++++++++--- litellm/tests/test_streaming.py | 4 +--- 2 files changed, 24 insertions(+), 6 deletions(-) diff --git a/litellm/llms/gemini.py b/litellm/llms/gemini.py index 03574559c..8d9994cb6 100644 --- a/litellm/llms/gemini.py +++ b/litellm/llms/gemini.py @@ -1,4 +1,4 @@ -import os, types, traceback, copy +import os, types, traceback, copy, asyncio import json from enum import Enum import time @@ -82,6 +82,27 @@ class GeminiConfig: } +class TextStreamer: + """ + A class designed to return an async stream from 
AsyncGenerateContentResponse object. + """ + + def __init__(self, response): + self.response = response + self._aiter = self.response.__aiter__() + + async def __aiter__(self): + while True: + try: + # This will manually advance the async iterator. + # In the case the next object doesn't exists, __anext__() will simply raise a StopAsyncIteration exception + next_object = await self._aiter.__anext__() + yield next_object + except StopAsyncIteration: + # After getting all items from the async iterator, stop iterating + break + + def completion( model: str, messages: list, @@ -160,12 +181,11 @@ def completion( ) response = litellm.CustomStreamWrapper( - aiter(response), + TextStreamer(response), model, custom_llm_provider="gemini", logging_obj=logging_obj, ) - return response return async_streaming() diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index ee6a187e2..58dc25fb0 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -438,7 +438,7 @@ async def test_acompletion_gemini_stream(): {"role": "system", "content": "You are a helpful assistant."}, { "role": "user", - "content": "how does a court case get to the Supreme Court?", + "content": "What do you know?", }, ] print("testing gemini streaming") @@ -453,8 +453,6 @@ async def test_acompletion_gemini_stream(): print(f"chunk in acompletion gemini: {chunk}") print(chunk.choices[0].delta) chunk, finished = streaming_format_tests(idx, chunk) - if idx > 5: - break if finished: break print(f"chunk: {chunk}")
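Taken together, the TextStreamer wrapper introduced in this third patch and the CustomStreamWrapper changes from the first patch let callers consume Gemini streams with plain async iteration. The snippet below is a minimal end-to-end sketch, mirroring the updated test_acompletion_gemini_stream test and again assuming Gemini credentials are configured in the environment.

import asyncio

import litellm


async def main():
    # stream=True returns a CustomStreamWrapper wrapping TextStreamer,
    # so the response can be consumed with a plain async for loop.
    response = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "What do you know?"}],
        max_tokens=50,
        stream=True,
    )
    async for chunk in response:
        print(chunk.choices[0].delta)


asyncio.run(main())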