fix(gemini.py): fix async streaming + add native async completions

Krrish Dholakia 2024-02-19 22:41:36 -08:00
parent e4ae3e0ab6
commit 11c12e7381
6 changed files with 224 additions and 17 deletions
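A minimal usage sketch of the new native async path, mirroring the tests added in this commit (model name and prompt are illustrative; assumes a Gemini API key is configured, e.g. via GEMINI_API_KEY):

import asyncio
import litellm

async def main():
    messages = [{"role": "user", "content": "Hey, how's it going?"}]

    # native async completion (no streaming)
    response = await litellm.acompletion(model="gemini/gemini-pro", messages=messages)
    print(response)

    # native async streaming
    stream = await litellm.acompletion(
        model="gemini/gemini-pro", messages=messages, stream=True
    )
    async for chunk in stream:
        print(chunk.choices[0].delta)

asyncio.run(main())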


@@ -126,7 +126,9 @@ def completion(
    safety_settings_param = inference_params.pop("safety_settings", None)
    safety_settings = None
    if safety_settings_param:
        safety_settings = [genai.types.SafetySettingDict(x) for x in safety_settings_param]
        safety_settings = [
            genai.types.SafetySettingDict(x) for x in safety_settings_param
        ]

    config = litellm.GeminiConfig.get_config()
    for k, v in config.items():
@@ -144,13 +146,29 @@ def completion(
    ## COMPLETION CALL
    try:
        _model = genai.GenerativeModel(f"models/{model}")
        if stream != True:
            response = _model.generate_content(
        if stream == True:
            if acompletion == True:

                async def async_streaming():
                    response = await _model.generate_content_async(
                        contents=prompt,
                        generation_config=genai.types.GenerationConfig(**inference_params),
                        generation_config=genai.types.GenerationConfig(
                            **inference_params
                        ),
                        safety_settings=safety_settings,
                        stream=True,
                    )
        else:

                    response = litellm.CustomStreamWrapper(
                        aiter(response),
                        model,
                        custom_llm_provider="gemini",
                        logging_obj=logging_obj,
                    )
                    return response

                return async_streaming()
            response = _model.generate_content(
                contents=prompt,
                generation_config=genai.types.GenerationConfig(**inference_params),
@@ -158,6 +176,25 @@ def completion(
                stream=True,
            )
            return response
        elif acompletion == True:
            return async_completion(
                _model=_model,
                model=model,
                prompt=prompt,
                inference_params=inference_params,
                safety_settings=safety_settings,
                logging_obj=logging_obj,
                print_verbose=print_verbose,
                model_response=model_response,
                messages=messages,
                encoding=encoding,
            )
        else:
            response = _model.generate_content(
                contents=prompt,
                generation_config=genai.types.GenerationConfig(**inference_params),
                safety_settings=safety_settings,
            )
    except Exception as e:
        raise GeminiError(
            message=str(e),
@@ -236,6 +273,98 @@ def completion(
    return model_response


async def async_completion(
    _model,
    model,
    prompt,
    inference_params,
    safety_settings,
    logging_obj,
    print_verbose,
    model_response,
    messages,
    encoding,
):
    import google.generativeai as genai

    response = await _model.generate_content_async(
        contents=prompt,
        generation_config=genai.types.GenerationConfig(**inference_params),
        safety_settings=safety_settings,
    )

    ## LOGGING
    logging_obj.post_call(
        input=prompt,
        api_key="",
        original_response=response,
        additional_args={"complete_input_dict": {}},
    )
    print_verbose(f"raw model_response: {response}")
    ## RESPONSE OBJECT
    completion_response = response
    try:
        choices_list = []
        for idx, item in enumerate(completion_response.candidates):
            if len(item.content.parts) > 0:
                message_obj = Message(content=item.content.parts[0].text)
            else:
                message_obj = Message(content=None)
            choice_obj = Choices(index=idx + 1, message=message_obj)
            choices_list.append(choice_obj)
        model_response["choices"] = choices_list
    except Exception as e:
        traceback.print_exc()
        raise GeminiError(
            message=traceback.format_exc(), status_code=response.status_code
        )
    try:
        completion_response = model_response["choices"][0]["message"].get("content")
        if completion_response is None:
            raise Exception
    except:
        original_response = f"response: {response}"
        if hasattr(response, "candidates"):
            original_response = f"response: {response.candidates}"
            if "SAFETY" in original_response:
                original_response += (
                    "\nThe candidate content was flagged for safety reasons."
                )
            elif "RECITATION" in original_response:
                original_response += (
                    "\nThe candidate content was flagged for recitation reasons."
                )
        raise GeminiError(
            status_code=400,
            message=f"No response received. Original response - {original_response}",
        )
    ## CALCULATING USAGE
    prompt_str = ""
    for m in messages:
        if isinstance(m["content"], str):
            prompt_str += m["content"]
        elif isinstance(m["content"], list):
            for content in m["content"]:
                if content["type"] == "text":
                    prompt_str += content["text"]
    prompt_tokens = len(encoding.encode(prompt_str))
    completion_tokens = len(
        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
    )
    model_response["created"] = int(time.time())
    model_response["model"] = "gemini/" + model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    model_response.usage = usage
    return model_response


def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
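The completion handler above pops safety_settings out of the provider params and wraps each entry in genai.types.SafetySettingDict. A rough sketch of how a caller could pass them through litellm (the category/threshold strings follow google-generativeai's HarmCategory/HarmBlockThreshold naming and are assumptions, not taken from this diff):

import litellm

# hypothetical settings; values follow google-generativeai naming conventions
safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
]

response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    safety_settings=safety_settings,
)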


@@ -263,6 +263,7 @@ async def acompletion(
            or custom_llm_provider == "ollama"
            or custom_llm_provider == "ollama_chat"
            or custom_llm_provider == "vertex_ai"
            or custom_llm_provider == "gemini"
            or custom_llm_provider == "sagemaker"
            or custom_llm_provider in litellm.openai_compatible_providers
        ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.


@@ -681,11 +681,11 @@ class PrismaClient:
                return response
            elif table_name == "user_notification":
                if query_type == "find_unique":
                    response = await self.db.litellm_usernotifications.find_unique(
                    response = await self.db.litellm_usernotifications.find_unique(  # type: ignore
                        where={"user_id": user_id}  # type: ignore
                    )
                elif query_type == "find_all":
                    response = await self.db.litellm_usernotifications.find_many()
                    response = await self.db.litellm_usernotifications.find_many()  # type: ignore
                return response
        except Exception as e:
            print_verbose(f"LiteLLM Prisma Client Exception: {e}")
@@ -795,7 +795,7 @@ class PrismaClient:
            elif table_name == "user_notification":
                db_data = self.jsonify_object(data=data)
                new_user_notification_row = (
                    await self.db.litellm_usernotifications.upsert(
                    await self.db.litellm_usernotifications.upsert(  # type: ignore
                        where={"request_id": data["request_id"]},
                        data={
                            "create": {**db_data},  # type: ignore


@@ -1993,6 +1993,19 @@ def test_completion_gemini():
# test_completion_gemini()


@pytest.mark.asyncio
async def test_acompletion_gemini():
    litellm.set_verbose = True
    model_name = "gemini/gemini-pro"
    messages = [{"role": "user", "content": "Hey, how's it going?"}]
    try:
        response = await litellm.acompletion(model=model_name, messages=messages)
        # Add any assertions here to check the response
        print(f"response: {response}")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# Palm tests
def test_completion_palm():
    litellm.set_verbose = True


@@ -429,6 +429,47 @@ def test_completion_gemini_stream():
        pytest.fail(f"Error occurred: {e}")


@pytest.mark.asyncio
async def test_acompletion_gemini_stream():
    try:
        litellm.set_verbose = True
        print("Streaming gemini response")
        messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {
                "role": "user",
                "content": "how does a court case get to the Supreme Court?",
            },
        ]
        print("testing gemini streaming")
        response = await acompletion(
            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
        )
        print(f"type of response at the top: {response}")
        complete_response = ""
        idx = 0
        # Add any assertions here to check the response
        async for chunk in response:
            print(f"chunk in acompletion gemini: {chunk}")
            print(chunk.choices[0].delta)
            chunk, finished = streaming_format_tests(idx, chunk)
            if idx > 5:
                break
            if finished:
                break
            print(f"chunk: {chunk}")
            complete_response += chunk
            idx += 1
        print(f"completion_response: {complete_response}")
        if complete_response.strip() == "":
            raise Exception("Empty response received")
    except Exception as e:
        pytest.fail(f"Error occurred: {e}")


# asyncio.run(test_acompletion_gemini_stream())


def test_completion_mistral_api_stream():
    try:
        litellm.set_verbose = True


@@ -8417,7 +8417,28 @@ class CustomStreamWrapper:
                        model_response.choices[0].finish_reason = "stop"
                        self.sent_last_chunk = True
            elif self.custom_llm_provider == "gemini":
                completion_obj["content"] = chunk.text
                try:
                    if hasattr(chunk, "parts") == True:
                        try:
                            if len(chunk.parts) > 0:
                                completion_obj["content"] = chunk.parts[0].text
                            if hasattr(chunk.parts[0], "finish_reason"):
                                model_response.choices[0].finish_reason = (
                                    map_finish_reason(chunk.parts[0].finish_reason.name)
                                )
                        except:
                            if chunk.parts[0].finish_reason.name == "SAFETY":
                                raise Exception(
                                    f"The response was blocked by VertexAI. {str(chunk)}"
                                )
                    else:
                        completion_obj["content"] = str(chunk)
                except StopIteration as e:
                    if self.sent_last_chunk:
                        raise e
                    else:
                        model_response.choices[0].finish_reason = "stop"
                        self.sent_last_chunk = True
            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                try:
                    if hasattr(chunk, "candidates") == True:
@@ -8727,19 +8748,21 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "ollama_chat"
                or self.custom_llm_provider == "vertex_ai"
                or self.custom_llm_provider == "sagemaker"
                or self.custom_llm_provider == "gemini"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
            ):
                print_verbose(
                    f"value of async completion stream: {self.completion_stream}"
                )
                async for chunk in self.completion_stream:
                    print_verbose(f"value of async chunk: {chunk}")
                    print_verbose(
                        f"value of async chunk: {chunk.parts}; len(chunk.parts): {len(chunk.parts)}"
                    )
                    if chunk == "None" or chunk is None:
                        raise Exception
                    elif self.custom_llm_provider == "gemini" and len(chunk.parts) == 0:
                        continue
                    # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
                    # __anext__ also calls async_success_handler, which does logging
                    print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
                    processed_chunk: Optional[ModelResponse] = self.chunk_creator(
                        chunk=chunk
                    )
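For context, the raw google-generativeai async stream that CustomStreamWrapper iterates over yields chunks whose parts list can be empty (e.g. metadata-only chunks), which is why the async loop above skips them. A rough sketch of the underlying SDK call, using the same shape as the gemini.py handler (models/<name> plus generate_content_async(stream=True)); the prompt is taken from the test above:

import asyncio
import os

import google.generativeai as genai

async def raw_gemini_stream():
    # assumes GEMINI_API_KEY is set in the environment
    genai.configure(api_key=os.environ["GEMINI_API_KEY"])
    model = genai.GenerativeModel("models/gemini-pro")
    response = await model.generate_content_async(
        contents="how does a court case get to the Supreme Court?",
        stream=True,
    )
    async for chunk in response:
        if len(chunk.parts) == 0:
            # safety/metadata-only chunk; litellm's async loop skips these
            continue
        print(chunk.parts[0].text)

asyncio.run(raw_gemini_stream())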