forked from phoenix/litellm-mirror
Merge pull request #4579 from BerriAI/litellm_gemini_stream_tool_calling
fix(vertex_httpx.py): support tool calling w/ streaming for vertex ai + gemini
This commit is contained in:
commit
e835f7336a
5 changed files with 105 additions and 12 deletions
|
@ -630,7 +630,11 @@ class Logging:
|
|||
model_call_details=self.model_call_details
|
||||
),
|
||||
call_type=self.call_type,
|
||||
optional_params=self.optional_params,
|
||||
optional_params=(
|
||||
self.optional_params
|
||||
if hasattr(self, "optional_params")
|
||||
else {}
|
||||
),
|
||||
)
|
||||
)
|
||||
if self.dynamic_success_callbacks is not None and isinstance(
|
||||
|
|
|
@ -1330,17 +1330,30 @@ class ModelResponseIterator:
|
|||
|
||||
gemini_chunk = processed_chunk["candidates"][0]
|
||||
|
||||
if (
|
||||
"content" in gemini_chunk
|
||||
and "text" in gemini_chunk["content"]["parts"][0]
|
||||
):
|
||||
if "content" in gemini_chunk:
|
||||
if "text" in gemini_chunk["content"]["parts"][0]:
|
||||
text = gemini_chunk["content"]["parts"][0]["text"]
|
||||
elif "functionCall" in gemini_chunk["content"]["parts"][0]:
|
||||
function_call = ChatCompletionToolCallFunctionChunk(
|
||||
name=gemini_chunk["content"]["parts"][0]["functionCall"][
|
||||
"name"
|
||||
],
|
||||
arguments=json.dumps(
|
||||
gemini_chunk["content"]["parts"][0]["functionCall"]["args"]
|
||||
),
|
||||
)
|
||||
tool_use = ChatCompletionToolCallChunk(
|
||||
id=str(uuid.uuid4()),
|
||||
type="function",
|
||||
function=function_call,
|
||||
index=0,
|
||||
)
|
||||
|
||||
if "finishReason" in gemini_chunk:
|
||||
finish_reason = map_finish_reason(
|
||||
finish_reason=gemini_chunk["finishReason"]
|
||||
)
|
||||
## DO NOT SET 'finish_reason' = True
|
||||
## DO NOT SET 'is_finished' = True
|
||||
## GEMINI SETS FINISHREASON ON EVERY CHUNK!
|
||||
|
||||
if "usageMetadata" in processed_chunk:
|
||||
|
|
|
@ -12,6 +12,9 @@ from typing import Tuple
|
|||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import litellm.litellm_core_utils
|
||||
import litellm.litellm_core_utils.litellm_logging
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
|
@ -3033,8 +3036,11 @@ def test_completion_claude_3_function_call_with_streaming():
|
|||
pytest.fail(f"Error occurred: {e}")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"model", ["gemini/gemini-1.5-flash"]
|
||||
) # "claude-3-opus-20240229",
|
||||
@pytest.mark.asyncio
|
||||
async def test_acompletion_claude_3_function_call_with_streaming():
|
||||
async def test_acompletion_claude_3_function_call_with_streaming(model):
|
||||
litellm.set_verbose = True
|
||||
tools = [
|
||||
{
|
||||
|
@ -3065,7 +3071,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
|
|||
try:
|
||||
# test without max tokens
|
||||
response = await acompletion(
|
||||
model="claude-3-opus-20240229",
|
||||
model=model,
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
tool_choice="required",
|
||||
|
@ -3452,3 +3458,55 @@ def test_aamazing_unit_test_custom_stream_wrapper_n():
|
|||
assert (
|
||||
chunk_dict == chunks[idx]
|
||||
), f"idx={idx} translated chunk = {chunk_dict} != openai chunk = {chunks[idx]}"
|
||||
|
||||
|
||||
def test_unit_test_custom_stream_wrapper_function_call():
|
||||
"""
|
||||
Test if model returns a tool call, the finish reason is correctly set to 'tool_calls'
|
||||
"""
|
||||
from litellm.types.llms.openai import ChatCompletionDeltaChunk
|
||||
|
||||
litellm.set_verbose = False
|
||||
delta: ChatCompletionDeltaChunk = {
|
||||
"content": None,
|
||||
"role": "assistant",
|
||||
"tool_calls": [
|
||||
{
|
||||
"function": {"arguments": '"}'},
|
||||
"type": "function",
|
||||
"index": 0,
|
||||
}
|
||||
],
|
||||
}
|
||||
chunk = {
|
||||
"id": "chatcmpl-123",
|
||||
"object": "chat.completion.chunk",
|
||||
"created": 1694268190,
|
||||
"model": "gpt-3.5-turbo-0125",
|
||||
"system_fingerprint": "fp_44709d6fcb",
|
||||
"choices": [{"index": 0, "delta": delta, "finish_reason": "stop"}],
|
||||
}
|
||||
chunk = litellm.ModelResponse(**chunk, stream=True)
|
||||
|
||||
completion_stream = ModelResponseIterator(model_response=chunk)
|
||||
|
||||
response = litellm.CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
model="gpt-3.5-turbo",
|
||||
custom_llm_provider="cached_response",
|
||||
logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hey"}],
|
||||
stream=True,
|
||||
call_type="completion",
|
||||
start_time=time.time(),
|
||||
litellm_call_id="12345",
|
||||
function_id="1245",
|
||||
),
|
||||
)
|
||||
|
||||
finish_reason: Optional[str] = None
|
||||
for chunk in response:
|
||||
if chunk.choices[0].finish_reason is not None:
|
||||
finish_reason = chunk.choices[0].finish_reason
|
||||
assert finish_reason == "tool_calls"
|
||||
|
|
|
@ -300,7 +300,7 @@ class ListBatchRequest(TypedDict, total=False):
|
|||
timeout: Optional[float]
|
||||
|
||||
|
||||
class ChatCompletionToolCallFunctionChunk(TypedDict):
|
||||
class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
|
||||
name: Optional[str]
|
||||
arguments: str
|
||||
|
||||
|
@ -312,7 +312,7 @@ class ChatCompletionToolCallChunk(TypedDict):
|
|||
index: int
|
||||
|
||||
|
||||
class ChatCompletionDeltaToolCallChunk(TypedDict):
|
||||
class ChatCompletionDeltaToolCallChunk(TypedDict, total=False):
|
||||
id: str
|
||||
type: Literal["function"]
|
||||
function: ChatCompletionToolCallFunctionChunk
|
||||
|
|
|
@ -7970,6 +7970,7 @@ class CustomStreamWrapper:
|
|||
)
|
||||
self.messages = getattr(logging_obj, "messages", None)
|
||||
self.sent_stream_usage = False
|
||||
self.tool_call = False
|
||||
self.chunks: List = (
|
||||
[]
|
||||
) # keep track of the returned chunks - used for calculating the input/output tokens for stream options
|
||||
|
@ -9212,9 +9213,16 @@ class CustomStreamWrapper:
|
|||
"is_finished": True,
|
||||
"finish_reason": chunk.choices[0].finish_reason,
|
||||
"original_chunk": chunk,
|
||||
"tool_calls": (
|
||||
chunk.choices[0].delta.tool_calls
|
||||
if hasattr(chunk.choices[0].delta, "tool_calls")
|
||||
else None
|
||||
),
|
||||
}
|
||||
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["tool_calls"] is not None:
|
||||
completion_obj["tool_calls"] = response_obj["tool_calls"]
|
||||
print_verbose(f"completion obj content: {completion_obj['content']}")
|
||||
if hasattr(chunk, "id"):
|
||||
model_response.id = chunk.id
|
||||
|
@ -9372,6 +9380,10 @@ class CustomStreamWrapper:
|
|||
)
|
||||
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
|
||||
|
||||
## CHECK FOR TOOL USE
|
||||
if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
|
||||
self.tool_call = True
|
||||
|
||||
## RETURN ARG
|
||||
if (
|
||||
"content" in completion_obj
|
||||
|
@ -9550,6 +9562,12 @@ class CustomStreamWrapper:
|
|||
)
|
||||
else:
|
||||
model_response.choices[0].finish_reason = "stop"
|
||||
|
||||
## if tool use
|
||||
if (
|
||||
model_response.choices[0].finish_reason == "stop" and self.tool_call
|
||||
): # don't overwrite for other - potential error finish reasons
|
||||
model_response.choices[0].finish_reason = "tool_calls"
|
||||
return model_response
|
||||
|
||||
def __next__(self):
|
||||
|
@ -9603,7 +9621,7 @@ class CustomStreamWrapper:
|
|||
return response
|
||||
|
||||
except StopIteration:
|
||||
if self.sent_last_chunk == True:
|
||||
if self.sent_last_chunk is True:
|
||||
if (
|
||||
self.sent_stream_usage == False
|
||||
and self.stream_options is not None
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue