Merge pull request #4579 from BerriAI/litellm_gemini_stream_tool_calling

fix(vertex_httpx.py): support tool calling w/ streaming for vertex ai + gemini
Krish Dholakia 2024-07-06 19:07:37 -07:00 committed by GitHub
commit e835f7336a
5 changed files with 105 additions and 12 deletions
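
With this change, streaming tool calling against Gemini works end to end. Below is a hedged usage sketch, not part of the diff: the model name and tool schema mirror the parametrized test added in this PR, and Gemini/Vertex credentials are assumed to already be configured.

import litellm

tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

# Stream a tool-calling request against Gemini; per this PR, tool-call
# fragments surface in the deltas and the final finish_reason becomes
# "tool_calls" instead of "stop".
response = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "What's the weather in Boston?"}],
    tools=tools,
    tool_choice="required",
    stream=True,
)
for chunk in response:
    delta = chunk.choices[0].delta
    if delta.tool_calls:
        print(delta.tool_calls)
    if chunk.choices[0].finish_reason is not None:
        print("finish_reason:", chunk.choices[0].finish_reason)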


@@ -630,7 +630,11 @@ class Logging:
                         model_call_details=self.model_call_details
                     ),
                     call_type=self.call_type,
-                    optional_params=self.optional_params,
+                    optional_params=(
+                        self.optional_params
+                        if hasattr(self, "optional_params")
+                        else {}
+                    ),
                 )
             )
         if self.dynamic_success_callbacks is not None and isinstance(
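
A small standalone illustration of the defensive default added above (the `_PartialLogging` class is hypothetical, for demonstration only): logging objects that have not yet been given `optional_params` no longer raise `AttributeError`.

class _PartialLogging:
    pass

obj = _PartialLogging()
# Same pattern as the diff: fall back to an empty dict when the attribute is missing.
optional_params = obj.optional_params if hasattr(obj, "optional_params") else {}
print(optional_params)  # prints {}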


@@ -1330,17 +1330,30 @@ class ModelResponseIterator:
             gemini_chunk = processed_chunk["candidates"][0]
 
-            if (
-                "content" in gemini_chunk
-                and "text" in gemini_chunk["content"]["parts"][0]
-            ):
-                text = gemini_chunk["content"]["parts"][0]["text"]
+            if "content" in gemini_chunk:
+                if "text" in gemini_chunk["content"]["parts"][0]:
+                    text = gemini_chunk["content"]["parts"][0]["text"]
+                elif "functionCall" in gemini_chunk["content"]["parts"][0]:
+                    function_call = ChatCompletionToolCallFunctionChunk(
+                        name=gemini_chunk["content"]["parts"][0]["functionCall"][
+                            "name"
+                        ],
+                        arguments=json.dumps(
+                            gemini_chunk["content"]["parts"][0]["functionCall"]["args"]
+                        ),
+                    )
+                    tool_use = ChatCompletionToolCallChunk(
+                        id=str(uuid.uuid4()),
+                        type="function",
+                        function=function_call,
+                        index=0,
+                    )
 
             if "finishReason" in gemini_chunk:
                 finish_reason = map_finish_reason(
                     finish_reason=gemini_chunk["finishReason"]
                 )
                 ## DO NOT SET 'finish_reason' = True
                 ## DO NOT SET 'is_finished' = True
                 ## GEMINI SETS FINISHREASON ON EVERY CHUNK!
 
             if "usageMetadata" in processed_chunk:


@@ -12,6 +12,9 @@ from typing import Tuple
 import pytest
 from pydantic import BaseModel
 
+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
@@ -3033,8 +3036,11 @@ def test_completion_claude_3_function_call_with_streaming():
         pytest.fail(f"Error occurred: {e}")
 
 
+@pytest.mark.parametrize(
+    "model", ["gemini/gemini-1.5-flash"]
+)  # "claude-3-opus-20240229",
 @pytest.mark.asyncio
-async def test_acompletion_claude_3_function_call_with_streaming():
+async def test_acompletion_claude_3_function_call_with_streaming(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -3065,7 +3071,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
     try:
         # test without max tokens
         response = await acompletion(
-            model="claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="required",
@@ -3452,3 +3458,55 @@ def test_aamazing_unit_test_custom_stream_wrapper_n():
         assert (
             chunk_dict == chunks[idx]
         ), f"idx={idx} translated chunk = {chunk_dict} != openai chunk = {chunks[idx]}"
+
+
+def test_unit_test_custom_stream_wrapper_function_call():
+    """
+    Test if model returns a tool call, the finish reason is correctly set to 'tool_calls'
+    """
+    from litellm.types.llms.openai import ChatCompletionDeltaChunk
+    litellm.set_verbose = False
+    delta: ChatCompletionDeltaChunk = {
+        "content": None,
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "function": {"arguments": '"}'},
+                "type": "function",
+                "index": 0,
+            }
+        ],
+    }
+    chunk = {
+        "id": "chatcmpl-123",
+        "object": "chat.completion.chunk",
+        "created": 1694268190,
+        "model": "gpt-3.5-turbo-0125",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{"index": 0, "delta": delta, "finish_reason": "stop"}],
+    }
+    chunk = litellm.ModelResponse(**chunk, stream=True)
+
+    completion_stream = ModelResponseIterator(model_response=chunk)
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=completion_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    finish_reason: Optional[str] = None
+    for chunk in response:
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason = chunk.choices[0].finish_reason
+
+    assert finish_reason == "tool_calls"
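
As a follow-up illustration, not part of the test file: a hedged sketch of how a caller might reassemble tool-call arguments that arrive in fragments like the '"}' delta above, keying on the fragment's index and on finish_reason == "tool_calls". The chunk objects are assumed to follow the OpenAI-style streaming shape used throughout this diff.

from typing import Dict


def collect_tool_calls(stream) -> Dict[int, dict]:
    """Accumulate streamed tool-call fragments into complete calls."""
    calls: Dict[int, dict] = {}
    for chunk in stream:
        choice = chunk.choices[0]
        for fragment in choice.delta.tool_calls or []:
            call = calls.setdefault(
                fragment.index, {"id": None, "name": None, "arguments": ""}
            )
            if fragment.id:
                call["id"] = fragment.id
            if fragment.function.name:
                call["name"] = fragment.function.name
            if fragment.function.arguments:
                # Argument text streams in pieces; concatenate in order.
                call["arguments"] += fragment.function.arguments
        if choice.finish_reason == "tool_calls":
            break
    return calls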


@@ -300,7 +300,7 @@ class ListBatchRequest(TypedDict, total=False):
     timeout: Optional[float]
 
 
-class ChatCompletionToolCallFunctionChunk(TypedDict):
+class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
     name: Optional[str]
     arguments: str
 
@@ -312,7 +312,7 @@ class ChatCompletionToolCallChunk(TypedDict):
     index: int
 
 
-class ChatCompletionDeltaToolCallChunk(TypedDict):
+class ChatCompletionDeltaToolCallChunk(TypedDict, total=False):
     id: str
     type: Literal["function"]
     function: ChatCompletionToolCallFunctionChunk
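
A standalone sketch of why total=False matters for these chunk types: continuation fragments in a stream often carry only the next piece of the arguments string, with no name at all, so the TypedDict must allow keys to be omitted. The example values are illustrative.

from typing import Optional, TypedDict


class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
    name: Optional[str]
    arguments: str


# With total=False, a continuation fragment may omit "name" entirely,
# exactly like the '"}' argument fragment in the unit test above.
first: ChatCompletionToolCallFunctionChunk = {
    "name": "get_current_weather",
    "arguments": '{"location": "Boston',
}
continuation: ChatCompletionToolCallFunctionChunk = {"arguments": '"}'}
print(first, continuation)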


@@ -7970,6 +7970,7 @@ class CustomStreamWrapper:
         )
         self.messages = getattr(logging_obj, "messages", None)
         self.sent_stream_usage = False
+        self.tool_call = False
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
@@ -9212,9 +9213,16 @@
                     "is_finished": True,
                     "finish_reason": chunk.choices[0].finish_reason,
                     "original_chunk": chunk,
+                    "tool_calls": (
+                        chunk.choices[0].delta.tool_calls
+                        if hasattr(chunk.choices[0].delta, "tool_calls")
+                        else None
+                    ),
                 }
 
                 completion_obj["content"] = response_obj["text"]
+                if response_obj["tool_calls"] is not None:
+                    completion_obj["tool_calls"] = response_obj["tool_calls"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if hasattr(chunk, "id"):
                     model_response.id = chunk.id
@@ -9372,6 +9380,10 @@
                 )
             print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
 
+            ## CHECK FOR TOOL USE
+            if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
+                self.tool_call = True
+
             ## RETURN ARG
             if (
                 "content" in completion_obj
@@ -9550,6 +9562,12 @@
                     )
                 else:
                     model_response.choices[0].finish_reason = "stop"
+
+                ## if tool use
+                if (
+                    model_response.choices[0].finish_reason == "stop" and self.tool_call
+                ):  # don't overwrite for other - potential error finish reasons
+                    model_response.choices[0].finish_reason = "tool_calls"
                 return model_response
 
     def __next__(self):
@@ -9603,7 +9621,7 @@
                 return response
         except StopIteration:
-            if self.sent_last_chunk == True:
+            if self.sent_last_chunk is True:
                 if (
                     self.sent_stream_usage == False
                     and self.stream_options is not None