feat(utils.py): support gemini/vertex ai streaming function param usage

Krrish Dholakia 2024-08-26 11:23:45 -07:00
parent aedc6652d4
commit c3db2d8bbf
2 changed files with 57 additions and 9 deletions
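
For context, the caller-facing behaviour this commit enables looks roughly like the sketch below. It is a minimal illustration distilled from the test changes in this diff, not part of the commit itself: the model name and the weather-function schema mirror the updated test, a configured Gemini API key is assumed, and the test suite's streaming_format_tests helper is omitted.

import litellm
from litellm import completion

# Function schema copied from the updated test below.
function1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {
                    "type": "string",
                    "description": "The city and state, e.g. San Francisco, CA",
                },
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]
messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

# With this commit, passing the legacy 'functions' param together with
# stream=True works for gemini/vertex ai models: tool-call chunks are
# surfaced as 'function_call' deltas instead of 'tool_calls'.
response = completion(
    model="gemini/gemini-1.5-flash",
    messages=messages,
    stream=True,
    functions=function1,
)

chunks = [chunk for chunk in response]

# stream_chunk_builder reassembles the streamed deltas into a full response,
# so the rebuilt message carries the function_call produced by the model.
full_response = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
print(full_response.choices[0].message.function_call)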


@@ -755,27 +755,40 @@ async def test_completion_gemini_stream(sync_mode):
     try:
         litellm.set_verbose = True
         print("Streaming gemini response")
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "Who was Alexander?",
-            },
+        function1 = [
+            {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
         ]
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
         print("testing gemini streaming")
         complete_response = ""
         # Add any assertions here to check the response
         non_empty_chunks = 0
+        chunks = []
         if sync_mode:
             response = completion(
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )
 
             for idx, chunk in enumerate(response):
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -787,11 +800,13 @@ async def test_completion_gemini_stream(sync_mode):
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )
 
             idx = 0
             async for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -800,10 +815,17 @@ async def test_completion_gemini_stream(sync_mode):
                 complete_response += chunk
                 idx += 1
 
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
+        # if complete_response.strip() == "":
+        #     raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
-        assert non_empty_chunks > 1
+
+        complete_response = litellm.stream_chunk_builder(
+            chunks=chunks, messages=messages
+        )
+
+        assert complete_response.choices[0].message.function_call is not None
+
+        # assert non_empty_chunks > 1
     except litellm.InternalServerError as e:
         pass
     except litellm.RateLimitError as e:


@@ -8771,6 +8771,7 @@ class CustomStreamWrapper:
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
+        self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)
 
     def __iter__(self):
         return self
@@ -8778,6 +8779,19 @@ class CustomStreamWrapper:
     def __aiter__(self):
         return self
 
+    def check_is_function_call(self, logging_obj) -> bool:
+        if hasattr(logging_obj, "optional_params") and isinstance(
+            logging_obj.optional_params, dict
+        ):
+            if (
+                "litellm_param_is_function_call" in logging_obj.optional_params
+                and logging_obj.optional_params["litellm_param_is_function_call"]
+                is not None
+            ):
+                return True
+
+        return False
+
     def process_chunk(self, chunk: str):
         """
         NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta.
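
As a rough illustration of what the new check_is_function_call helper keys off: when the caller used the legacy 'functions' param, litellm records an internal 'litellm_param_is_function_call' flag in the logging object's optional_params, and the stream wrapper flips into function_call mode. The stub below is hypothetical (the real logging object carries many more fields) and only mimics the shape the check relies on.

# Hypothetical stand-in for litellm's logging object; only the attribute
# inspected by check_is_function_call is modelled here.
class _FakeLoggingObj:
    def __init__(self, optional_params: dict):
        self.optional_params = optional_params


def check_is_function_call(logging_obj) -> bool:
    # Same logic as the method added above: return True only when the internal
    # flag is present on optional_params and is not None.
    if hasattr(logging_obj, "optional_params") and isinstance(
        logging_obj.optional_params, dict
    ):
        if logging_obj.optional_params.get("litellm_param_is_function_call") is not None:
            return True
    return False


assert check_is_function_call(_FakeLoggingObj({"litellm_param_is_function_call": True}))
assert not check_is_function_call(_FakeLoggingObj({}))
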
@@ -10275,6 +10289,12 @@ class CustomStreamWrapper:
 
             ## CHECK FOR TOOL USE
             if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
+                if self.is_function_call is True:  # user passed in 'functions' param
+                    completion_obj["function_call"] = completion_obj["tool_calls"][0][
+                        "function"
+                    ]
+                    completion_obj["tool_calls"] = None
+
                 self.tool_call = True
 
             ## RETURN ARG
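
The translation the hunk above performs can be pictured on a plain dict: when the user passed 'functions', the first streamed tool call is re-exposed under the OpenAI-style 'function_call' key and 'tool_calls' is blanked out. A simplified sketch follows; real deltas carry ids, types and indexes that are omitted here.

# Simplified delta dict, roughly what a gemini tool-call chunk reduces to.
completion_obj = {
    "content": "",
    "tool_calls": [
        {"function": {"name": "get_current_weather", "arguments": '{"location": "Boston"}'}}
    ],
}

is_function_call = True  # i.e. the caller used the legacy 'functions' param

if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
    if is_function_call is True:
        # Re-expose the first tool call as an OpenAI-style function_call delta
        # and drop tool_calls, matching the conversion done in the hunk above.
        completion_obj["function_call"] = completion_obj["tool_calls"][0]["function"]
        completion_obj["tool_calls"] = None

print(completion_obj["function_call"])  # {'name': 'get_current_weather', ...}
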
@@ -10286,8 +10306,13 @@ class CustomStreamWrapper:
                 )
                 or (
                     "tool_calls" in completion_obj
+                    and completion_obj["tool_calls"] is not None
                     and len(completion_obj["tool_calls"]) > 0
                 )
+                or (
+                    "function_call" in completion_obj
+                    and completion_obj["function_call"] is not None
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 self.safety_checker()
                 hold, model_response_str = self.check_special_tokens(
@@ -10347,6 +10372,7 @@ class CustomStreamWrapper:
                         if self.sent_first_chunk is False:
                             completion_obj["role"] = "assistant"
                             self.sent_first_chunk = True
+
                         model_response.choices[0].delta = Delta(**completion_obj)
                         if completion_obj.get("index") is not None:
                             model_response.choices[0].index = completion_obj.get(