mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 18:54:30 +00:00)

fix(gemini.py): support streaming

commit b07677c6be (parent b59e67f099)
4 changed files with 67 additions and 13 deletions
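From the caller's side, the change means a gemini/ model can now be streamed through litellm's standard interface. The sketch below is illustrative only: the model name and prompt mirror the new test added in this commit, the API key handling is an assumption (a Google AI Studio key exported as GEMINI_API_KEY), and the chunk handling follows litellm's usual OpenAI-style delta format.

# Illustrative caller-side sketch of what this commit enables; not part of the diff.
# Assumption: a Google AI Studio key is exported as GEMINI_API_KEY before running.
import litellm

response = litellm.completion(
    model="gemini/gemini-pro",
    messages=[
        {"role": "user", "content": "how does a court case get to the Supreme Court?"}
    ],
    stream=True,
)

# With stream=True the call now returns an iterator of OpenAI-style chunks
# (produced by CustomStreamWrapper), so the text arrives incrementally.
for chunk in response:
    delta = chunk.choices[0].delta
    print(delta.content or "", end="", flush=True)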
@@ -120,9 +120,7 @@ def completion(
 
     ## Load Config
     inference_params = copy.deepcopy(optional_params)
-    inference_params.pop(
-        "stream", None
-    )  # palm does not support streaming, so we handle this by fake streaming in main.py
+    stream = inference_params.pop("stream", None)
     config = litellm.GeminiConfig.get_config()
     for k, v in config.items():
         if (
@@ -139,10 +137,18 @@ def completion(
     ## COMPLETION CALL
     try:
         _model = genai.GenerativeModel(f"models/{model}")
-        response = _model.generate_content(
-            contents=prompt,
-            generation_config=genai.types.GenerationConfig(**inference_params),
-        )
+        if stream != True:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+            )
+        else:
+            response = _model.generate_content(
+                contents=prompt,
+                generation_config=genai.types.GenerationConfig(**inference_params),
+                stream=True,
+            )
+            return response
     except Exception as e:
         raise GeminiError(
             message=str(e),
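For reference, here is a minimal standalone sketch of the two SDK call paths this hunk distinguishes, assuming the google-generativeai package is installed; the API key, model name, and prompt are placeholders rather than values taken from the diff.

# Minimal sketch of non-streaming vs. streaming calls with the google-generativeai
# SDK, mirroring the if/else added above. Key, model name, and prompt are placeholders.
import google.generativeai as genai

genai.configure(api_key="YOUR_GOOGLE_API_KEY")  # assumption: key passed directly
_model = genai.GenerativeModel("models/gemini-pro")

# Non-streaming: the full completion is available on the returned object.
response = _model.generate_content(contents="Say hello in one sentence.")
print(response.text)

# Streaming: stream=True returns an iterator of partial chunks, which is exactly
# what the new else-branch hands back unmodified for main.py to wrap.
stream = _model.generate_content(contents="Say hello in one sentence.", stream=True)
for chunk in stream:
    print(chunk.text, end="", flush=True)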
@@ -177,16 +183,20 @@ def completion(
 
     try:
         completion_response = model_response["choices"][0]["message"].get("content")
         if completion_response is None:
             raise Exception
     except:
         original_response = f"response: {response}"
         if hasattr(response, "candidates"):
             original_response = f"response: {response.candidates}"
             if "SAFETY" in original_response:
-                original_response += "\nThe candidate content was flagged for safety reasons."
+                original_response += (
+                    "\nThe candidate content was flagged for safety reasons."
+                )
             elif "RECITATION" in original_response:
-                original_response += "\nThe candidate content was flagged for recitation reasons."
+                original_response += (
+                    "\nThe candidate content was flagged for recitation reasons."
+                )
         raise GeminiError(
             status_code=400,
             message=f"No response received. Original response - {original_response}",
@@ -1382,6 +1382,18 @@ def completion(
             acompletion=acompletion,
             custom_prompt_dict=custom_prompt_dict,
         )
+        if (
+            "stream" in optional_params
+            and optional_params["stream"] == True
+            and acompletion == False
+        ):
+            response = CustomStreamWrapper(
+                iter(model_response),
+                model,
+                custom_llm_provider="gemini",
+                logging_obj=logging,
+            )
+            return response
         response = model_response
     elif custom_llm_provider == "vertex_ai":
         vertex_ai_project = litellm.vertex_project or get_secret("VERTEXAI_PROJECT")
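CustomStreamWrapper is what turns the raw SDK iterator returned by the gemini handler into litellm's usual chunk format. As a rough illustration of that idea only, here is a simplified, hypothetical adapter; the class name, parameters, and chunk shape are invented for this sketch and do not reflect litellm's actual implementation.

# Hypothetical, simplified adapter illustrating the wrapping idea above.
# SimpleStreamWrapper and its chunk shape are invented for this sketch and are
# NOT litellm's CustomStreamWrapper API.
from typing import Dict, Iterable, Iterator


class SimpleStreamWrapper:
    def __init__(self, raw_chunks: Iterable, model: str, provider: str):
        self.raw_chunks = iter(raw_chunks)
        self.model = model
        self.provider = provider

    def __iter__(self) -> Iterator[Dict]:
        for chunk in self.raw_chunks:
            # Gemini stream chunks expose their text on `.text`; other providers
            # would need provider-specific extraction here.
            text = getattr(chunk, "text", "")
            yield {
                "model": self.model,
                "choices": [{"delta": {"content": text}, "finish_reason": None}],
            }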
@@ -398,6 +398,36 @@ def test_completion_palm_stream():
 # test_completion_palm_stream()
 
 
+def test_completion_gemini_stream():
+    try:
+        litellm.set_verbose = False
+        print("Streaming gemini response")
+        messages = [
+            {"role": "system", "content": "You are a helpful assistant."},
+            {
+                "role": "user",
+                "content": "how does a court case get to the Supreme Court?",
+            },
+        ]
+        print("testing gemini streaming")
+        response = completion(model="gemini/gemini-pro", messages=messages, stream=True)
+        print(f"type of response at the top: {response}")
+        complete_response = ""
+        # Add any assertions here to check the response
+        for idx, chunk in enumerate(response):
+            print(chunk)
+            # print(chunk.choices[0].delta)
+            chunk, finished = streaming_format_tests(idx, chunk)
+            if finished:
+                break
+            complete_response += chunk
+        if complete_response.strip() == "":
+            raise Exception("Empty response received")
+        print(f"completion_response: {complete_response}")
+    except Exception as e:
+        pytest.fail(f"Error occurred: {e}")
+
+
 def test_completion_mistral_api_stream():
     try:
         litellm.set_verbose = True
@@ -7622,7 +7622,9 @@ class CustomStreamWrapper:
                     raise Exception("An unknown error occurred with the stream")
                 model_response.choices[0].finish_reason = "stop"
                 self.sent_last_chunk = True
-            elif self.custom_llm_provider and self.custom_llm_provider == "vertex_ai":
+            elif self.custom_llm_provider == "gemini":
+                completion_obj["content"] = chunk.text
+            elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
                 try:
                     # print(chunk)
                     if hasattr(chunk, "text"):