Merge pull request #4579 from BerriAI/litellm_gemini_stream_tool_calling
fix(vertex_httpx.py): support tool calling w/ streaming for vertex ai + gemini
commit e835f7336a
5 changed files with 105 additions and 12 deletions
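For orientation, here is a minimal usage sketch of what this PR enables: streaming a Gemini response through litellm and reading OpenAI-style tool-call deltas off each chunk. The model name, tools format, and tool_choice="required" mirror the parametrized test added below; the weather tool schema and the accumulation loop are illustrative assumptions, not code from this commit, and a configured Gemini API key is assumed.

import litellm

# Hypothetical tool definition for illustration (standard OpenAI function-tool schema).
tools = [
    {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        },
    }
]

response = litellm.completion(
    model="gemini/gemini-1.5-flash",  # same model the new test parametrizes over
    messages=[{"role": "user", "content": "What's the weather like in Boston?"}],
    tools=tools,
    tool_choice="required",
    stream=True,
)

tool_name, tool_args, finish_reason = None, "", None
for chunk in response:
    delta = chunk.choices[0].delta
    if delta.tool_calls:  # now populated for Gemini streams
        call = delta.tool_calls[0]
        tool_name = call.function.name or tool_name  # the name may only arrive once
        tool_args += call.function.arguments or ""   # argument JSON arrives in fragments
    if chunk.choices[0].finish_reason is not None:
        finish_reason = chunk.choices[0].finish_reason

print(tool_name, tool_args, finish_reason)  # finish_reason should end up "tool_calls"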
@@ -630,7 +630,11 @@ class Logging:
                         model_call_details=self.model_call_details
                     ),
                     call_type=self.call_type,
-                    optional_params=self.optional_params,
+                    optional_params=(
+                        self.optional_params
+                        if hasattr(self, "optional_params")
+                        else {}
+                    ),
                 )
             )
         if self.dynamic_success_callbacks is not None and isinstance(
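One small note on the hunk above: the new conditional is just a defensive fallback to an empty dict when the Logging object was constructed without optional_params. A purely illustrative restatement (not code from the commit):

# Illustrative only: the hasattr guard behaves like getattr with a default.
class _Example:
    pass

obj = _Example()
optional_params = obj.optional_params if hasattr(obj, "optional_params") else {}
assert optional_params == getattr(obj, "optional_params", {}) == {}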
@@ -1330,17 +1330,30 @@ class ModelResponseIterator:
             gemini_chunk = processed_chunk["candidates"][0]

-            if (
-                "content" in gemini_chunk
-                and "text" in gemini_chunk["content"]["parts"][0]
-            ):
-                text = gemini_chunk["content"]["parts"][0]["text"]
+            if "content" in gemini_chunk:
+                if "text" in gemini_chunk["content"]["parts"][0]:
+                    text = gemini_chunk["content"]["parts"][0]["text"]
+                elif "functionCall" in gemini_chunk["content"]["parts"][0]:
+                    function_call = ChatCompletionToolCallFunctionChunk(
+                        name=gemini_chunk["content"]["parts"][0]["functionCall"][
+                            "name"
+                        ],
+                        arguments=json.dumps(
+                            gemini_chunk["content"]["parts"][0]["functionCall"]["args"]
+                        ),
+                    )
+                    tool_use = ChatCompletionToolCallChunk(
+                        id=str(uuid.uuid4()),
+                        type="function",
+                        function=function_call,
+                        index=0,
+                    )

             if "finishReason" in gemini_chunk:
                 finish_reason = map_finish_reason(
                     finish_reason=gemini_chunk["finishReason"]
                 )
-                ## DO NOT SET 'finish_reason' = True
+                ## DO NOT SET 'is_finished' = True
                 ## GEMINI SETS FINISHREASON ON EVERY CHUNK!

             if "usageMetadata" in processed_chunk:
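To make the new branch concrete, here is a self-contained sketch of the translation it performs: a Gemini streaming candidate whose first part is a functionCall becomes an OpenAI-style tool-call chunk with JSON-encoded arguments. The payload values are invented for illustration; only the key names (candidates, content, parts, functionCall, name, args) come from the code above, and the plain dict stands in for ChatCompletionToolCallChunk.

import json
import uuid

# Invented example of a raw Gemini streaming chunk that carries a function call.
processed_chunk = {
    "candidates": [
        {
            "content": {
                "parts": [
                    {
                        "functionCall": {
                            "name": "get_current_weather",
                            "args": {"location": "Boston"},
                        }
                    }
                ]
            },
            "finishReason": "STOP",
        }
    ]
}

part = processed_chunk["candidates"][0]["content"]["parts"][0]
if "functionCall" in part:
    # Same shape the hunk builds via ChatCompletionToolCallFunctionChunk / ChatCompletionToolCallChunk.
    tool_use = {
        "id": str(uuid.uuid4()),
        "type": "function",
        "function": {
            "name": part["functionCall"]["name"],
            "arguments": json.dumps(part["functionCall"]["args"]),
        },
        "index": 0,
    }
    print(tool_use)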
@@ -12,6 +12,9 @@ from typing import Tuple
 import pytest
 from pydantic import BaseModel

+import litellm.litellm_core_utils
+import litellm.litellm_core_utils.litellm_logging
+
 sys.path.insert(
     0, os.path.abspath("../..")
 )  # Adds the parent directory to the system path
@@ -3033,8 +3036,11 @@ def test_completion_claude_3_function_call_with_streaming():
         pytest.fail(f"Error occurred: {e}")


+@pytest.mark.parametrize(
+    "model", ["gemini/gemini-1.5-flash"]
+)  # "claude-3-opus-20240229",
 @pytest.mark.asyncio
-async def test_acompletion_claude_3_function_call_with_streaming():
+async def test_acompletion_claude_3_function_call_with_streaming(model):
     litellm.set_verbose = True
     tools = [
         {
@@ -3065,7 +3071,7 @@ async def test_acompletion_claude_3_function_call_with_streaming():
     try:
         # test without max tokens
         response = await acompletion(
-            model="claude-3-opus-20240229",
+            model=model,
             messages=messages,
             tools=tools,
             tool_choice="required",
@@ -3452,3 +3458,55 @@ def test_aamazing_unit_test_custom_stream_wrapper_n():
         assert (
             chunk_dict == chunks[idx]
         ), f"idx={idx} translated chunk = {chunk_dict} != openai chunk = {chunks[idx]}"
+
+
+def test_unit_test_custom_stream_wrapper_function_call():
+    """
+    Test if model returns a tool call, the finish reason is correctly set to 'tool_calls'
+    """
+    from litellm.types.llms.openai import ChatCompletionDeltaChunk
+
+    litellm.set_verbose = False
+    delta: ChatCompletionDeltaChunk = {
+        "content": None,
+        "role": "assistant",
+        "tool_calls": [
+            {
+                "function": {"arguments": '"}'},
+                "type": "function",
+                "index": 0,
+            }
+        ],
+    }
+    chunk = {
+        "id": "chatcmpl-123",
+        "object": "chat.completion.chunk",
+        "created": 1694268190,
+        "model": "gpt-3.5-turbo-0125",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [{"index": 0, "delta": delta, "finish_reason": "stop"}],
+    }
+    chunk = litellm.ModelResponse(**chunk, stream=True)
+
+    completion_stream = ModelResponseIterator(model_response=chunk)
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=completion_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.litellm_core_utils.litellm_logging.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    finish_reason: Optional[str] = None
+    for chunk in response:
+        if chunk.choices[0].finish_reason is not None:
+            finish_reason = chunk.choices[0].finish_reason
+    assert finish_reason == "tool_calls"
@@ -300,7 +300,7 @@ class ListBatchRequest(TypedDict, total=False):
    timeout: Optional[float]


-class ChatCompletionToolCallFunctionChunk(TypedDict):
+class ChatCompletionToolCallFunctionChunk(TypedDict, total=False):
     name: Optional[str]
     arguments: str

@@ -312,7 +312,7 @@ class ChatCompletionToolCallChunk(TypedDict):
     index: int


-class ChatCompletionDeltaToolCallChunk(TypedDict):
+class ChatCompletionDeltaToolCallChunk(TypedDict, total=False):
     id: str
     type: Literal["function"]
     function: ChatCompletionToolCallFunctionChunk
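Both TypedDicts gain total=False because a streaming delta often carries only part of a tool call; the new unit test above feeds an arguments fragment ('"}') with no name and no id. A minimal sketch of why the relaxation matters for type checking, using stand-in classes rather than litellm's own definitions:

from typing import Optional

from typing_extensions import Literal, TypedDict


class ToolCallFunctionChunk(TypedDict, total=False):
    name: Optional[str]
    arguments: str


class DeltaToolCallChunk(TypedDict, total=False):
    id: str
    type: Literal["function"]
    function: ToolCallFunctionChunk
    index: int


# With total=False every key is optional, so a partial streaming delta type-checks:
fragment: DeltaToolCallChunk = {"function": {"arguments": '"}'}, "index": 0}
# With total=True (the old definitions), a type checker would flag the missing "id", "type", and "name" keys.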
@@ -7970,6 +7970,7 @@ class CustomStreamWrapper:
         )
         self.messages = getattr(logging_obj, "messages", None)
         self.sent_stream_usage = False
+        self.tool_call = False
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
@@ -9212,9 +9213,16 @@ class CustomStreamWrapper:
                     "is_finished": True,
                     "finish_reason": chunk.choices[0].finish_reason,
                     "original_chunk": chunk,
+                    "tool_calls": (
+                        chunk.choices[0].delta.tool_calls
+                        if hasattr(chunk.choices[0].delta, "tool_calls")
+                        else None
+                    ),
                 }

                 completion_obj["content"] = response_obj["text"]
+                if response_obj["tool_calls"] is not None:
+                    completion_obj["tool_calls"] = response_obj["tool_calls"]
                 print_verbose(f"completion obj content: {completion_obj['content']}")
                 if hasattr(chunk, "id"):
                     model_response.id = chunk.id
@@ -9372,6 +9380,10 @@ class CustomStreamWrapper:
                 )
             print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")

+            ## CHECK FOR TOOL USE
+            if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
+                self.tool_call = True
+
             ## RETURN ARG
             if (
                 "content" in completion_obj
@@ -9550,6 +9562,12 @@ class CustomStreamWrapper:
                 )
             else:
                 model_response.choices[0].finish_reason = "stop"
+
+            ## if tool use
+            if (
+                model_response.choices[0].finish_reason == "stop" and self.tool_call
+            ):  # don't overwrite for other - potential error finish reasons
+                model_response.choices[0].finish_reason = "tool_calls"
             return model_response

     def __next__(self):
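The net effect of the wrapper changes: tool-call deltas are passed through, self.tool_call is latched once any are seen, and the final chunk's "stop" finish reason is rewritten to "tool_calls" while other or error finish reasons are left alone. A purely illustrative restatement of that last rule, not code from this PR:

# Illustrative only: how the final finish_reason is chosen once the stream ends.
def final_finish_reason(raw_finish_reason: str, saw_tool_call: bool) -> str:
    if raw_finish_reason == "stop" and saw_tool_call:
        return "tool_calls"
    return raw_finish_reason  # don't overwrite other / potential error finish reasons


assert final_finish_reason("stop", True) == "tool_calls"
assert final_finish_reason("stop", False) == "stop"
assert final_finish_reason("length", True) == "length"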
@@ -9603,7 +9621,7 @@ class CustomStreamWrapper:
                 return response

         except StopIteration:
-            if self.sent_last_chunk == True:
+            if self.sent_last_chunk is True:
                 if (
                     self.sent_stream_usage == False
                     and self.stream_options is not None