Mirror of https://github.com/BerriAI/litellm.git
fix(gemini.py): support streaming
parent b59e67f099
commit b07677c6be
4 changed files with 67 additions and 13 deletions
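The new test added below drives this change end to end. For orientation, a minimal usage sketch of the streaming call it makes (assuming a Gemini API key is configured the usual way for litellm's gemini provider, e.g. via the GEMINI_API_KEY environment variable):

import litellm

# Stream a Gemini completion through litellm; the model string mirrors the new test below.
response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "how does a court case get to the Supreme Court?"}],
    stream=True,
)
for chunk in response:
    print(chunk)  # OpenAI-format streaming chunks emitted by CustomStreamWrapper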
@@ -120,9 +120,7 @@ def completion(
     ## Load Config
     inference_params = copy.deepcopy(optional_params)
-    inference_params.pop(
-        "stream", None
-    )  # palm does not support streaming, so we handle this by fake streaming in main.py
+    stream = inference_params.pop("stream", None)
     config = litellm.GeminiConfig.get_config()
     for k, v in config.items():
         if (

@@ -139,10 +137,18 @@ def completion(
     ## COMPLETION CALL
     try:
         _model = genai.GenerativeModel(f"models/{model}")
-        response = _model.generate_content(
-            contents=prompt,
-            generation_config=genai.types.GenerationConfig(**inference_params),
-        )
+        if stream != True:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+            )
+        else:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+                stream=True,
+            )
+            return response
     except Exception as e:
         raise GeminiError(
             message=str(e),

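The handler now forwards stream=True to the google-generativeai SDK and returns the raw chunk iterator instead of waiting for a full response. A rough sketch of the SDK behavior it relies on (illustrative only; prompt formatting and key handling are simplified and are not litellm code):

import google.generativeai as genai

genai.configure(api_key="...")  # litellm resolves the key elsewhere; placeholder here
_model = genai.GenerativeModel("models/gemini-pro")
response = _model.generate_content(
    contents="Explain streaming in one sentence.",
    stream=True,
)
for chunk in response:
    print(chunk.text)  # partial text pieces, later mapped into completion_obj["content"]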
@@ -184,9 +190,13 @@ def completion(
     if hasattr(response, "candidates"):
         original_response = f"response: {response.candidates}"
         if "SAFETY" in original_response:
-            original_response += "\nThe candidate content was flagged for safety reasons."
+            original_response += (
+                "\nThe candidate content was flagged for safety reasons."
+            )
         elif "RECITATION" in original_response:
-            original_response += "\nThe candidate content was flagged for recitation reasons."
+            original_response += (
+                "\nThe candidate content was flagged for recitation reasons."
+            )
     raise GeminiError(
         status_code=400,
         message=f"No response received. Original response - {original_response}",

@@ -1382,6 +1382,18 @@ def completion(
             acompletion=acompletion,
             custom_prompt_dict=custom_prompt_dict,
         )
+        if (
+            "stream" in optional_params
+            and optional_params["stream"] == True
+            and acompletion == False
+        ):
+            response = CustomStreamWrapper(
+                iter(model_response),
+                model,
+                custom_llm_provider="gemini",
+                logging_obj=logging,
+            )
+            return response
         response = model_response
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = litellm.vertex_project or get_secret("VERTEXAI_PROJECT")

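Wrapping the generator in CustomStreamWrapper means callers see the same OpenAI-style chunk objects as every other streaming provider. A consumption sketch (the choices[0].delta layout is taken from the commented line in the test below; treat the exact attributes as an assumption):

from litellm import completion

response = completion(
    model="gemini/gemini-pro",
    messages=[{"role": "user", "content": "Summarize the appeals process."}],
    stream=True,
)

complete_response = ""
for chunk in response:
    delta = chunk.choices[0].delta  # OpenAI-format delta per chunk
    if getattr(delta, "content", None):
        complete_response += delta.content
print(complete_response)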
@@ -398,6 +398,36 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
 
 
+def test_completion_gemini_stream():
+    try:
+        litellm.set_verbose = False
+        print("Streaming gemini response")
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        print("testing gemini streaming")
+        response = completion(model="gemini/gemini-pro", messages=messages, stream=True)
+        print(f"type of response at the top: {response}")
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            print(chunk)
+            # print(chunk.choices[0].delta)
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api_stream():
     try:
         litellm.set_verbose = True

@@ -7622,7 +7622,9 @@ class CustomStreamWrapper:
                     raise Exception("An unknown error occurred with the stream")
                 model_response.choices[0].finish_reason = "stop"
                 self.sent_last_chunk = True
-            elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
+            elif self.custom_llm_provider == "gemini":
+                completion_obj["content"] = chunk.text
+            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
                     # print(chunk)
                     if hasattr(chunk, "text"):

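The comment removed in the first hunk describes the old PaLM-style behavior: "fake streaming", where the already-complete text was replayed through the wrapper in main.py. The new "gemini" branch forwards real SDK chunks instead. A toy comparison of the two behaviors (function names are hypothetical, not litellm code):

def fake_stream(full_text):
    # Old approach: generate the whole completion first, then replay it as one chunk.
    yield full_text


def real_stream(sdk_response):
    # New approach: forward partial chunks as the SDK produces them;
    # the "gemini" branch above reads each chunk's .text attribute.
    for chunk in sdk_response:
        yield chunk.text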