Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-27 11:43:54 +00:00)

Merge pull request #2090 from BerriAI/litellm_gemini_streaming_fixes

fix(gemini.py): fix async streaming + add native async completions

Commit a9c3aeb9fa: 5 changed files with 237 additions and 14 deletions
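For orientation before the hunks below: the new native async path is driven through litellm.acompletion, the same way the test added in this PR calls it. A minimal sketch, not part of the diff; the model name and message shape are taken from that test, and it assumes Gemini credentials are already configured for litellm:

import asyncio

import litellm


async def main():
    # Non-streaming call: with this PR it dispatches to the new async_completion() path in gemini.py.
    response = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    print(response)


asyncio.run(main())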
@@ -1,4 +1,4 @@
-import os, types, traceback, copy
+import os, types, traceback, copy, asyncio
 import json
 from enum import Enum
 import time
@@ -82,6 +82,27 @@ class GeminiConfig:
        }


+class TextStreamer:
+    """
+    A class designed to return an async stream from an AsyncGenerateContentResponse object.
+    """
+
+    def __init__(self, response):
+        self.response = response
+        self._aiter = self.response.__aiter__()
+
+    async def __aiter__(self):
+        while True:
+            try:
+                # This will manually advance the async iterator.
+                # If the next object doesn't exist, __anext__() will simply raise a StopAsyncIteration exception.
+                next_object = await self._aiter.__anext__()
+                yield next_object
+            except StopAsyncIteration:
+                # After getting all items from the async iterator, stop iterating.
+                break
+
+
 def completion(
     model: str,
     messages: list,
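TextStreamer above is a thin adapter: it re-drives an existing async iterator and stops cleanly on StopAsyncIteration. The same pattern can be seen in isolation with a toy async generator standing in for google.generativeai's AsyncGenerateContentResponse (an illustrative sketch only, not part of the diff):

import asyncio


class TextStreamer:
    """Same adapter shape as above: re-expose an async iterable as an async generator."""

    def __init__(self, response):
        self.response = response
        self._aiter = self.response.__aiter__()

    async def __aiter__(self):
        while True:
            try:
                yield await self._aiter.__anext__()
            except StopAsyncIteration:
                break


async def fake_response():
    # Stand-in for AsyncGenerateContentResponse: yields a few chunk-like values.
    for part in ["Hello", ", ", "world"]:
        yield part


async def main():
    async for chunk in TextStreamer(fake_response()):
        print(chunk)


asyncio.run(main())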
@@ -126,7 +147,9 @@ def completion(
    safety_settings_param = inference_params.pop("safety_settings", None)
    safety_settings = None
    if safety_settings_param:
-        safety_settings = [genai.types.SafetySettingDict(x) for x in safety_settings_param]
+        safety_settings = [
+            genai.types.SafetySettingDict(x) for x in safety_settings_param
+        ]

    config = litellm.GeminiConfig.get_config()
    for k, v in config.items():
@@ -144,13 +167,28 @@ def completion(
    ## COMPLETION CALL
    try:
        _model = genai.GenerativeModel(f"models/{model}")
-        if stream != True:
-            response = _model.generate_content(
-                contents=prompt,
-                generation_config=genai.types.GenerationConfig(**inference_params),
-                safety_settings=safety_settings,
-            )
-        else:
+        if stream == True:
+            if acompletion == True:
+
+                async def async_streaming():
+                    response = await _model.generate_content_async(
+                        contents=prompt,
+                        generation_config=genai.types.GenerationConfig(
+                            **inference_params
+                        ),
+                        safety_settings=safety_settings,
+                        stream=True,
+                    )
+
+                    response = litellm.CustomStreamWrapper(
+                        TextStreamer(response),
+                        model,
+                        custom_llm_provider="gemini",
+                        logging_obj=logging_obj,
+                    )
+                    return response
+
+                return async_streaming()
            response = _model.generate_content(
                contents=prompt,
                generation_config=genai.types.GenerationConfig(**inference_params),
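With stream=True and acompletion=True, the coroutine returned by async_streaming() resolves to a CustomStreamWrapper, so the caller awaits litellm.acompletion and then iterates with async for. A minimal sketch mirroring the streaming test added later in this PR (not part of the diff; assumes Gemini credentials are configured):

import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "What do you know?"}],
        max_tokens=50,
        stream=True,
    )
    async for chunk in response:
        # Each chunk is an OpenAI-style streaming delta built by CustomStreamWrapper.
        print(chunk.choices[0].delta)


asyncio.run(main())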
@@ -158,6 +196,25 @@ def completion(
                stream=True,
            )
            return response
+        elif acompletion == True:
+            return async_completion(
+                _model=_model,
+                model=model,
+                prompt=prompt,
+                inference_params=inference_params,
+                safety_settings=safety_settings,
+                logging_obj=logging_obj,
+                print_verbose=print_verbose,
+                model_response=model_response,
+                messages=messages,
+                encoding=encoding,
+            )
+        else:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+                safety_settings=safety_settings,
+            )
    except Exception as e:
        raise GeminiError(
            message=str(e),
@@ -236,6 +293,98 @@ def completion(
    return model_response


+async def async_completion(
+    _model,
+    model,
+    prompt,
+    inference_params,
+    safety_settings,
+    logging_obj,
+    print_verbose,
+    model_response,
+    messages,
+    encoding,
+):
+    import google.generativeai as genai
+
+    response = await _model.generate_content_async(
+        contents=prompt,
+        generation_config=genai.types.GenerationConfig(**inference_params),
+        safety_settings=safety_settings,
+    )
+
+    ## LOGGING
+    logging_obj.post_call(
+        input=prompt,
+        api_key="",
+        original_response=response,
+        additional_args={"complete_input_dict": {}},
+    )
+    print_verbose(f"raw model_response: {response}")
+    ## RESPONSE OBJECT
+    completion_response = response
+    try:
+        choices_list = []
+        for idx, item in enumerate(completion_response.candidates):
+            if len(item.content.parts) > 0:
+                message_obj = Message(content=item.content.parts[0].text)
+            else:
+                message_obj = Message(content=None)
+            choice_obj = Choices(index=idx + 1, message=message_obj)
+            choices_list.append(choice_obj)
+        model_response["choices"] = choices_list
+    except Exception as e:
+        traceback.print_exc()
+        raise GeminiError(
+            message=traceback.format_exc(), status_code=response.status_code
+        )
+
+    try:
+        completion_response = model_response["choices"][0]["message"].get("content")
+        if completion_response is None:
+            raise Exception
+    except:
+        original_response = f"response: {response}"
+        if hasattr(response, "candidates"):
+            original_response = f"response: {response.candidates}"
+            if "SAFETY" in original_response:
+                original_response += (
+                    "\nThe candidate content was flagged for safety reasons."
+                )
+            elif "RECITATION" in original_response:
+                original_response += (
+                    "\nThe candidate content was flagged for recitation reasons."
+                )
+        raise GeminiError(
+            status_code=400,
+            message=f"No response received. Original response - {original_response}",
+        )
+
+    ## CALCULATING USAGE
+    prompt_str = ""
+    for m in messages:
+        if isinstance(m["content"], str):
+            prompt_str += m["content"]
+        elif isinstance(m["content"], list):
+            for content in m["content"]:
+                if content["type"] == "text":
+                    prompt_str += content["text"]
+    prompt_tokens = len(encoding.encode(prompt_str))
+    completion_tokens = len(
+        encoding.encode(model_response["choices"][0]["message"].get("content", ""))
+    )
+
+    model_response["created"] = int(time.time())
+    model_response["model"] = "gemini/" + model
+    usage = Usage(
+        prompt_tokens=prompt_tokens,
+        completion_tokens=completion_tokens,
+        total_tokens=prompt_tokens + completion_tokens,
+    )
+    model_response.usage = usage
+    return model_response
+
+
 def embedding():
    # logic for parsing in - calling - parsing out model embedding calls
    pass
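async_completion() above fills the usual litellm ModelResponse fields (choices, created, model, usage), so callers read the result the same way as for any other provider. A small illustrative sketch of that access, using the field names the function populates; the exact accessor style is an assumption, not part of the diff:

import asyncio

import litellm


async def main():
    response = await litellm.acompletion(
        model="gemini/gemini-pro",
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
    )
    # Fields populated by async_completion() above.
    print(response["choices"][0]["message"].get("content"))
    print(response["model"])  # e.g. "gemini/gemini-pro"
    print(response.usage.total_tokens)


asyncio.run(main())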
@@ -264,6 +264,7 @@ async def acompletion(
            or custom_llm_provider == "ollama"
            or custom_llm_provider == "ollama_chat"
            or custom_llm_provider == "vertex_ai"
+            or custom_llm_provider == "gemini"
            or custom_llm_provider == "sagemaker"
            or custom_llm_provider in litellm.openai_compatible_providers
        ):  # currently implemented aiohttp calls for just azure, openai, hf, ollama, vertex ai soon all.
@@ -1993,6 +1993,19 @@ def test_completion_gemini():
 # test_completion_gemini()


+@pytest.mark.asyncio
+async def test_acompletion_gemini():
+    litellm.set_verbose = True
+    model_name = "gemini/gemini-pro"
+    messages = [{"role": "user", "content": "Hey, how's it going?"}]
+    try:
+        response = await litellm.acompletion(model=model_name, messages=messages)
+        # Add any assertions here to check the response
+        print(f"response: {response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 # Palm tests
 def test_completion_palm():
    litellm.set_verbose = True
@@ -429,6 +429,45 @@ def test_completion_gemini_stream():
        pytest.fail(f"Error occurred: {e}")


+@pytest.mark.asyncio
+async def test_acompletion_gemini_stream():
+    try:
+        litellm.set_verbose = True
+        print("Streaming gemini response")
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "What do you know?",
+            },
+        ]
+        print("testing gemini streaming")
+        response = await acompletion(
+            model="gemini/gemini-pro", messages=messages, max_tokens=50, stream=True
+        )
+        print(f"type of response at the top: {response}")
+        complete_response = ""
+        idx = 0
+        # Add any assertions here to check the response
+        async for chunk in response:
+            print(f"chunk in acompletion gemini: {chunk}")
+            print(chunk.choices[0].delta)
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            print(f"chunk: {chunk}")
+            complete_response += chunk
+            idx += 1
+        print(f"completion_response: {complete_response}")
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
+# asyncio.run(test_acompletion_gemini_stream())
+
+
 def test_completion_mistral_api_stream():
    try:
        litellm.set_verbose = True
@@ -8417,7 +8417,28 @@ class CustomStreamWrapper:
                    model_response.choices[0].finish_reason = "stop"
                    self.sent_last_chunk = True
            elif self.custom_llm_provider == "gemini":
-                completion_obj["content"] = chunk.text
+                try:
+                    if hasattr(chunk, "parts") == True:
+                        try:
+                            if len(chunk.parts) > 0:
+                                completion_obj["content"] = chunk.parts[0].text
+                            if hasattr(chunk.parts[0], "finish_reason"):
+                                model_response.choices[0].finish_reason = (
+                                    map_finish_reason(chunk.parts[0].finish_reason.name)
+                                )
+                        except:
+                            if chunk.parts[0].finish_reason.name == "SAFETY":
+                                raise Exception(
+                                    f"The response was blocked by VertexAI. {str(chunk)}"
+                                )
+                    else:
+                        completion_obj["content"] = str(chunk)
+                except StopIteration as e:
+                    if self.sent_last_chunk:
+                        raise e
+                    else:
+                        model_response.choices[0].finish_reason = "stop"
+                        self.sent_last_chunk = True
            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                try:
                    if hasattr(chunk, "candidates") == True:
@@ -8727,19 +8748,19 @@ class CustomStreamWrapper:
                or self.custom_llm_provider == "ollama_chat"
                or self.custom_llm_provider == "vertex_ai"
                or self.custom_llm_provider == "sagemaker"
+                or self.custom_llm_provider == "gemini"
                or self.custom_llm_provider in litellm.openai_compatible_endpoints
            ):
-                print_verbose(
-                    f"value of async completion stream: {self.completion_stream}"
-                )
                async for chunk in self.completion_stream:
                    print_verbose(f"value of async chunk: {chunk}")
                    if chunk == "None" or chunk is None:
                        raise Exception
+                    elif self.custom_llm_provider == "gemini" and len(chunk.parts) == 0:
+                        continue
                    # chunk_creator() does logging/stream chunk building. We need to let it know its being called in_async_func, so we don't double add chunks.
                    # __anext__ also calls async_success_handler, which does logging
                    print_verbose(f"PROCESSED ASYNC CHUNK PRE CHUNK CREATOR: {chunk}")

                    processed_chunk: Optional[ModelResponse] = self.chunk_creator(
                        chunk=chunk
                    )