Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)
fix(gemini.py): fix async streaming + add native async completions

parent e4ae3e0ab6
commit 11c12e7381

6 changed files with 224 additions and 17 deletions
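
The new behaviour is easiest to see from the call side. A minimal usage sketch, mirroring the tests added further down in this commit (model name and messages are taken from those tests; error handling omitted):

import asyncio

import litellm


async def main():
    # Non-streaming: now served by the new async_completion() helper in gemini.py
    response = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)

    # Streaming: gemini.py now returns a CustomStreamWrapper over an async iterator
    stream = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        stream=True,
        max_tokens=50,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta)


asyncio.run(main())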
@@ -126,7 +126,9 @@ def completion(
     safety_settings_param = inference_params.pop("safety_settings", None)
     safety_settings = None
     if safety_settings_param:
-        safety_settings = [genai.types.SafetySettingDict(x) for x in safety_settings_param]
+        safety_settings = [
+            genai.types.SafetySettingDict(x) for x in safety_settings_param
+        ]
 
     config = litellm.GeminiConfig.get_config()
     for k, v in config.items():
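
The hunk above wraps a caller-supplied `safety_settings` list into the google-generativeai SDK's `SafetySettingDict` type. A hedged sketch of the kind of value that would flow through this path, assuming litellm's usual passthrough of provider-specific kwargs (the category/threshold strings follow the SDK's conventions and are illustrative, not taken from this diff):

import litellm

# Illustrative caller-side value; each dict is wrapped in genai.types.SafetySettingDict(...)
safety_settings = [
    {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_ONLY_HIGH"},
    {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_ONLY_HIGH"},
]

response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "Hello"}],
    safety_settings=safety_settings,
)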
@@ -144,13 +146,29 @@ def completion(
     ## COMPLETION CALL
     try:
         _model = genai.GenerativeModel(f"models/{model}")
-        if stream != True:
-            response = _model.generate_content(
-                contents=prompt,
-                generation_config=genai.types.GenerationConfig(**inference_params),
-                safety_settings=safety_settings,
-            )
-        else:
+        if stream == True:
+            if acompletion == True:
+
+                async def async_streaming():
+                    response = await _model.generate_content_async(
+                        contents=prompt,
+                        generation_config=genai.types.GenerationConfig(
+                            **inference_params
+                        ),
+                        safety_settings=safety_settings,
+                        stream=True,
+                    )
+
+                    response = litellm.CustomStreamWrapper(
+                        aiter(response),
+                        model,
+                        custom_llm_provider="gemini",
+                        logging_obj=logging_obj,
+                    )
+
+                    return response
+
+                return async_streaming()
             response = _model.generate_content(
                 contents=prompt,
                 generation_config=genai.types.GenerationConfig(**inference_params),
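
A detail worth calling out: `completion()` itself stays synchronous, and when `stream` and `acompletion` are both set it returns the un-awaited `async_streaming()` coroutine, which resolves to a `CustomStreamWrapper` built over `aiter(response)`. A stripped-down sketch of that control flow under illustrative names (none of these helpers exist in the diff):

import asyncio


def completion(stream: bool = False, acompletion: bool = False):
    # Sync entrypoint that hands back a coroutine when the caller wants async streaming.
    if stream and acompletion:

        async def async_streaming():
            async def fake_chunks():  # stands in for generate_content_async(..., stream=True)
                for part in ["Hello", ", ", "world"]:
                    yield part

            return fake_chunks()  # an async iterator, like aiter(response)

        return async_streaming()  # a coroutine object; the caller awaits it

    return "sync result"


async def main():
    stream = await completion(stream=True, acompletion=True)
    async for chunk in stream:
        print(chunk)


asyncio.run(main())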
@@ -158,6 +176,25 @@ def completion(
                 stream=True,
             )
             return response
+        elif acompletion == True:
+            return async_completion(
+                _model=_model,
+                model=model,
+                prompt=prompt,
+                inference_params=inference_params,
+                safety_settings=safety_settings,
+                logging_obj=logging_obj,
+                print_verbose=print_verbose,
+                model_response=model_response,
+                messages=messages,
+                encoding=encoding,
+            )
+        else:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+                safety_settings=safety_settings,
+            )
     except Exception as e:
         raise GeminiError(
             message=str(e),
@@ -236,6 +273,98 @@ def completion(
     return model_response
 
 
+async def async_completion(
+    _model,
+    model,
+    prompt,
+    inference_params,
+    safety_settings,
+    logging_obj,
+    print_verbose,
+    model_response,
+    messages,
+    encoding,
+):
+    import google.generativeai as genai
+
+    response = await _model.generate_content_async(
+        contents=prompt,
+        generation_config=genai.types.GenerationConfig(**inference_params),
+        safety_settings=safety_settings,
+    )
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=prompt,
+        api_key="",
+        original_response=response,
+        additional_args={"complete_input_dict": {}},
+    )
+    print_verbose(f"raw model_response: {response}")
+    ## RESPONSE OBJECT
+    completion_response = response
+    try:
+        choices_list = []
+        for idx, item in enumerate(completion_response.candidates):
+            if len(item.content.parts) > 0:
+                message_obj = Message(content=item.content.parts[0].text)
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(index=idx + 1, message=message_obj)
+            choices_list.append(choice_obj)
+        model_response["choices"] = choices_list
+    except Exception as e:
+        traceback.print_exc()
+        raise GeminiError(
+            message=traceback.format_exc(), status_code=response.status_code
+        )
+
+    try:
+        completion_response = model_response["choices"][0]["message"].get("content")
+        if completion_response is None:
+            raise Exception
+    except:
+        original_response = f"response: {response}"
+        if hasattr(response, "candidates"):
+            original_response = f"response: {response.candidates}"
+            if "SAFETY" in original_response:
+                original_response += (
+                    "\nThe candidate content was flagged for safety reasons."
+                )
+            elif "RECITATION" in original_response:
+                original_response += (
+                    "\nThe candidate content was flagged for recitation reasons."
+                )
+        raise GeminiError(
+            status_code=400,
+            message=f"No response received. Original response - {original_response}",
+        )
+
+    ## CALCULATING USAGE
+    prompt_str = ""
+    for m in messages:
+        if isinstance(m["content"], str):
+            prompt_str += m["content"]
+        elif isinstance(m["content"], list):
+            for content in m["content"]:
+                if content["type"] == "text":
+                    prompt_str += content["text"]
+    prompt_tokens = len(encoding.encode(prompt_str))
+    completion_tokens = len(
+        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+    )
+
+    model_response["created"] = int(time.time())
+    model_response["model"] = "gemini/" + model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    model_response.usage = usage
+    return model_response
+
+
 def embedding():
     # logic for parsing in - calling - parsing out model embedding calls
     pass
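
For orientation, `async_completion` fills in litellm's OpenAI-style `ModelResponse`. Roughly the shape it produces, with keys taken from the assignments above and values purely illustrative:

# Illustrative only; values are made up, keys mirror the assignments in async_completion().
{
    "choices": [
        {"index": 1, "message": {"role": "assistant", "content": "Doing well, thanks!"}}
    ],
    "created": 1700000000,
    "model": "gemini/gemini-pro",
    "usage": {"prompt_tokens": 12, "completion_tokens": 48, "total_tokens": 60},
}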
@@ -263,6 +263,7 @@ async def acompletion(
         or custom_llm_provider == "ollama"
         or custom_llm_provider == "ollama_chat"
         or custom_llm_provider == "vertex_ai"
+        or custom_llm_provider == "gemini"
         or custom_llm_provider == "sagemaker"
         or custom_llm_provider in litellm.openai_compatible_providers
     ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
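
Membership in this list is what makes `litellm.acompletion` await the provider's own coroutine instead of pushing the synchronous `completion()` call onto a thread executor. A hedged sketch of that dispatch decision (simplified, not the actual main.py control flow):

import asyncio

# Stand-in for the "or" chain above; illustrative, not a real litellm constant.
NATIVE_ASYNC_PROVIDERS = {"openai", "azure", "ollama", "ollama_chat", "vertex_ai", "gemini", "sagemaker"}


async def dispatch(custom_llm_provider, func, init_response):
    if custom_llm_provider in NATIVE_ASYNC_PROVIDERS:
        # The provider module already handed back a coroutine (e.g. async_completion()
        # or async_streaming() from gemini.py); just await it.
        if asyncio.iscoroutine(init_response):
            return await init_response
        return init_response
    # Otherwise fall back to running the blocking completion() call in a thread.
    return await asyncio.get_running_loop().run_in_executor(None, func)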
@@ -681,11 +681,11 @@ class PrismaClient:
                 return response
             elif table_name == "user_notification":
                 if query_type == "find_unique":
-                    response = await self.db.litellm_usernotifications.find_unique(
+                    response = await self.db.litellm_usernotifications.find_unique(  # type: ignore
                         where={"user_id": user_id}  # type: ignore
                     )
                 elif query_type == "find_all":
-                    response = await self.db.litellm_usernotifications.find_many()
+                    response = await self.db.litellm_usernotifications.find_many()  # type: ignore
                 return response
         except Exception as e:
             print_verbose(f"LiteLLM Prisma Client Exception: {e}")
@@ -795,7 +795,7 @@ class PrismaClient:
             elif table_name == "user_notification":
                 db_data = self.jsonify_object(data=data)
                 new_user_notification_row = (
-                    await self.db.litellm_usernotifications.upsert(
+                    await self.db.litellm_usernotifications.upsert(  # type: ignore
                         where={"request_id": data["request_id"]},
                         data={
                             "create": {**db_data},  # type: ignore
@@ -1993,6 +1993,19 @@ def test_completion_gemini():
 # test_completion_gemini()
 
 
+@pytest.mark.asyncio
+async def test_acompletion_gemini():
+    litellm.set_verbose = True
+    model_name = "gemini/gemini-pro"
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    try:
+        response = await litellm.acompletion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # Palm tests
 def test_completion_palm():
     litellm.set_verbose = True
@@ -429,6 +429,47 @@ def test_completion_gemini_stream():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.asyncio
+async def test_acompletion_gemini_stream():
+    try:
+        litellm.set_verbose = True
+        print("Streaming gemini response")
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        print("testing gemini streaming")
+        response = await acompletion(
+            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
+        )
+        print(f"type of response at the top: {response}")
+        complete_response = ""
+        idx = 0
+        # Add any assertions here to check the response
+        async for chunk in response:
+            print(f"chunk in acompletion gemini: {chunk}")
+            print(chunk.choices[0].delta)
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if idx > 5:
+                break
+            if finished:
+                break
+            print(f"chunk: {chunk}")
+            complete_response += chunk
+            idx += 1
+        print(f"completion_response: {complete_response}")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# asyncio.run(test_acompletion_gemini_stream())
+
+
 def test_completion_mistral_api_stream():
     try:
         litellm.set_verbose = True
@@ -8417,7 +8417,28 @@ class CustomStreamWrapper:
                         model_response.choices[0].finish_reason = "stop"
                         self.sent_last_chunk = True
             elif self.custom_llm_provider == "gemini":
-                completion_obj["content"] = chunk.text
+                try:
+                    if hasattr(chunk, "parts") == True:
+                        try:
+                            if len(chunk.parts) > 0:
+                                completion_obj["content"] = chunk.parts[0].text
+                            if hasattr(chunk.parts[0], "finish_reason"):
+                                model_response.choices[0].finish_reason = (
+                                    map_finish_reason(chunk.parts[0].finish_reason.name)
+                                )
+                        except:
+                            if chunk.parts[0].finish_reason.name == "SAFETY":
+                                raise Exception(
+                                    f"The response was blocked by VertexAI. {str(chunk)}"
+                                )
+                    else:
+                        completion_obj["content"] = str(chunk)
+                except StopIteration as e:
+                    if self.sent_last_chunk:
+                        raise e
+                    else:
+                        model_response.choices[0].finish_reason = "stop"
+                        self.sent_last_chunk = True
             elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
                     if hasattr(chunk, "candidates") == True:
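
The reworked branch reads text from `chunk.parts` instead of `chunk.text`, since the SDK's `.text` accessor can raise when a streamed candidate has no parts (for example when generation is cut off for safety). A hedged, standalone sketch of the same defensive read (the helper name is illustrative, not from the diff):

def read_gemini_chunk(chunk) -> str:
    # Prefer the first part's text; fall back to the chunk's string form,
    # mirroring the hasattr/else handling in the hunk above.
    if hasattr(chunk, "parts") and len(chunk.parts) > 0:
        return chunk.parts[0].text
    return str(chunk)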
@@ -8727,19 +8748,21 @@ class CustomStreamWrapper:
                 or self.custom_llm_provider == "ollama_chat"
                 or self.custom_llm_provider == "vertex_ai"
                 or self.custom_llm_provider == "sagemaker"
+                or self.custom_llm_provider == "gemini"
                 or self.custom_llm_provider in litellm.openai_compatible_endpoints
             ):
                 print_verbose(
                     f"value of async completion stream: {self.completion_stream}"
                 )
                 async for chunk in self.completion_stream:
-                    print_verbose(f"value of async chunk: {chunk}")
+                    print_verbose(
+                        f"value of async chunk: {chunk.parts}; len(chunk.parts): {len(chunk.parts)}"
+                    )
                     if chunk == "None" or chunk is None:
                         raise Exception
+                    elif self.custom_llm_provider == "gemini" and len(chunk.parts) == 0:
+                        continue
                     # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
                     # __anext__ also calls async_success_handler, which does logging
                     print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")
 
                     processed_chunk: Optional[ModelResponse] = self.chunk_creator(
                         chunk=chunk
                     )