Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)

Merge pull request #5368 from BerriAI/litellm_vertex_function_support
feat(vertex_httpx.py): support 'functions' param for gemini google ai studio + vertex ai

Commit: c503ff435e
7 changed files with 263 additions and 84 deletions
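This PR wires the legacy OpenAI `functions` parameter through to Gemini (Google AI Studio) and Vertex AI: on the request side `functions` is mapped to Gemini function declarations the same way `tools` is, and on the response side the model's `functionCall` is surfaced as `message.function_call` instead of `message.tool_calls`. A minimal usage sketch, modeled on the tests added in this PR (the model name and weather schema are illustrative):

    import litellm

    function1 = [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        }
    ]

    response = litellm.completion(
        model="gemini/gemini-1.5-pro",
        messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
        functions=function1,
    )
    # With this PR, the Gemini functionCall is surfaced on the legacy field:
    print(response.choices[0].message.function_call)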
@@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionToolParamFunctionChunk,
     ChatCompletionUsageBlock,
 )
 from litellm.types.llms.vertex_ai import (
@@ -296,11 +297,50 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
             "stream",
             "tools",
             "tool_choice",
+            "functions",
             "response_format",
             "n",
             "stop",
         ]

+    def _map_function(self, value: List[dict]) -> List[Tools]:
+        gtool_func_declarations = []
+        googleSearchRetrieval: Optional[dict] = None
+
+        for tool in value:
+            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
+                None
+            )
+            if "function" in tool:  # tools list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
+                    **tool["function"]
+                )
+            elif "name" in tool:  # functions list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore
+
+            # check if grounding
+            if tool.get("googleSearchRetrieval", None) is not None:
+                googleSearchRetrieval = tool["googleSearchRetrieval"]
+            elif openai_function_object is not None:
+                gtool_func_declaration = FunctionDeclaration(
+                    name=openai_function_object["name"],
+                    description=openai_function_object.get("description", ""),
+                    parameters=openai_function_object.get("parameters", {}),
+                )
+                gtool_func_declarations.append(gtool_func_declaration)
+            else:
+                # assume it's a provider-specific param
+                verbose_logger.warning(
+                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
+                )
+
+        _tools = Tools(
+            function_declarations=gtool_func_declarations,
+        )
+        if googleSearchRetrieval is not None:
+            _tools["googleSearchRetrieval"] = googleSearchRetrieval
+        return [_tools]
+
     def map_tool_choice_values(
         self, model: str, tool_choice: Union[str, dict]
     ) -> Optional[ToolConfig]:
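The new `_map_function` helper accepts both input shapes: entries from an OpenAI-style `tools` list (function object nested under a "function" key) and bare entries from a legacy `functions` list, and folds either into a single Gemini `Tools` entry with `function_declarations` (plus `googleSearchRetrieval` when a grounding tool is passed). A hedged sketch of the two accepted shapes (values illustrative):

    # OpenAI 'tools' entry: the function object is nested under "function"
    tool_entry = {
        "type": "function",
        "function": {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {"type": "object", "properties": {}},
        },
    }

    # Legacy 'functions' entry: the bare function object
    function_entry = {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {"type": "object", "properties": {}},
    }

    # Either form ends up as one FunctionDeclaration inside a single
    # Tools(function_declarations=[...]) entry on the Gemini request.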
@@ -363,26 +403,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
                 if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
                     optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
-            if param == "tools" and isinstance(value, list):
-                gtool_func_declarations = []
-                for tool in value:
-                    _parameters = tool.get("function", {}).get("parameters", {})
-                    _properties = _parameters.get("properties", {})
-                    if isinstance(_properties, dict):
-                        for _, _property in _properties.items():
-                            if "enum" in _property and "format" not in _property:
-                                _property["format"] = "enum"
-
-                    gtool_func_declaration = FunctionDeclaration(
-                        name=tool["function"]["name"],
-                        description=tool["function"].get("description", ""),
-                    )
-                    if len(_parameters.keys()) > 0:
-                        gtool_func_declaration["parameters"] = _parameters
-                    gtool_func_declarations.append(gtool_func_declaration)
-                    optional_params["tools"] = [
-                        Tools(function_declarations=gtool_func_declarations)
-                    ]
+            if (param == "tools" or param == "functions") and isinstance(value, list):
+                optional_params["tools"] = self._map_function(value=value)
+                optional_params["litellm_param_is_function_call"] = (
+                    True if param == "functions" else False
+                )
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
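Note the second assignment: when the caller used `functions` rather than `tools`, the mapping also records `litellm_param_is_function_call` in `optional_params`. This is a litellm-internal marker, not a Google API field; later in this diff it is moved into `litellm_params` and stripped from the outgoing request. A rough sketch of what `optional_params` might hold after mapping a `functions` input (dict shapes assumed from the TypedDicts used above, values illustrative):

    optional_params = {
        "tools": [
            {
                "function_declarations": [
                    {
                        "name": "get_current_weather",
                        "description": "Get the current weather in a given location",
                        "parameters": {"type": "object", "properties": {}},
                    }
                ]
            }
        ],
        # internal flag: True because the caller passed 'functions'; never sent to the API
        "litellm_param_is_function_call": True,
    }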
@@ -506,6 +531,7 @@ class VertexGeminiConfig:
             "max_tokens",
             "stream",
             "tools",
+            "functions",
             "tool_choice",
             "response_format",
             "n",
@@ -541,6 +567,44 @@ class VertexGeminiConfig:
                 status_code=400,
             )

+    def _map_function(self, value: List[dict]) -> List[Tools]:
+        gtool_func_declarations = []
+        googleSearchRetrieval: Optional[dict] = None
+
+        for tool in value:
+            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
+                None
+            )
+            if "function" in tool:  # tools list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
+                    **tool["function"]
+                )
+            elif "name" in tool:  # functions list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore
+
+            # check if grounding
+            if tool.get("googleSearchRetrieval", None) is not None:
+                googleSearchRetrieval = tool["googleSearchRetrieval"]
+            elif openai_function_object is not None:
+                gtool_func_declaration = FunctionDeclaration(
+                    name=openai_function_object["name"],
+                    description=openai_function_object.get("description", ""),
+                    parameters=openai_function_object.get("parameters", {}),
+                )
+                gtool_func_declarations.append(gtool_func_declaration)
+            else:
+                # assume it's a provider-specific param
+                verbose_logger.warning(
+                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
+                )
+
+        _tools = Tools(
+            function_declarations=gtool_func_declarations,
+        )
+        if googleSearchRetrieval is not None:
+            _tools["googleSearchRetrieval"] = googleSearchRetrieval
+        return [_tools]
+
     def map_openai_params(
         self,
         model: str,
@@ -582,33 +646,11 @@ class VertexGeminiConfig:
                 optional_params["frequency_penalty"] = value
             if param == "presence_penalty":
                 optional_params["presence_penalty"] = value
-            if param == "tools" and isinstance(value, list):
-                gtool_func_declarations = []
-                googleSearchRetrieval: Optional[dict] = None
-                provider_specific_tools: List[dict] = []
-                for tool in value:
-                    # check if grounding
-                    try:
-                        gtool_func_declaration = FunctionDeclaration(
-                            name=tool["function"]["name"],
-                            description=tool["function"].get("description", ""),
-                            parameters=tool["function"].get("parameters", {}),
-                        )
-                        gtool_func_declarations.append(gtool_func_declaration)
-                    except KeyError:
-                        if tool.get("googleSearchRetrieval", None) is not None:
-                            googleSearchRetrieval = tool["googleSearchRetrieval"]
-                        else:
-                            # assume it's a provider-specific param
-                            verbose_logger.warning(
-                                "Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
-                            )
-                _tools = Tools(
-                    function_declarations=gtool_func_declarations,
-                )
-                if googleSearchRetrieval is not None:
-                    _tools["googleSearchRetrieval"] = googleSearchRetrieval
-                optional_params["tools"] = [_tools] + provider_specific_tools
+            if (param == "tools" or param == "functions") and isinstance(value, list):
+                optional_params["tools"] = self._map_function(value=value)
+                optional_params["litellm_param_is_function_call"] = (
+                    True if param == "functions" else False
+                )
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
@@ -780,6 +822,7 @@ class VertexLLM(BaseLLM):
         model_response: ModelResponse,
         logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
         optional_params: dict,
+        litellm_params: dict,
         api_key: str,
         data: Union[dict, str],
         messages: List,
@@ -796,7 +839,6 @@ class VertexLLM(BaseLLM):
         )

         print_verbose(f"raw model_response: {response.text}")
-
         ## RESPONSE OBJECT
         try:
             completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
@@ -904,6 +946,7 @@ class VertexLLM(BaseLLM):
         chat_completion_message = {"role": "assistant"}
         content_str = ""
         tools: List[ChatCompletionToolCallChunk] = []
+        functions: Optional[ChatCompletionToolCallFunctionChunk] = None
         for idx, candidate in enumerate(completion_response["candidates"]):
             if "content" not in candidate:
                 continue
@@ -926,18 +969,25 @@ class VertexLLM(BaseLLM):
                         candidate["content"]["parts"][0]["functionCall"]["args"]
                     ),
                 )
-                _tool_response_chunk = ChatCompletionToolCallChunk(
-                    id=f"call_{str(uuid.uuid4())}",
-                    type="function",
-                    function=_function_chunk,
-                    index=candidate.get("index", idx),
-                )
-                tools.append(_tool_response_chunk)
+                if litellm_params.get("litellm_param_is_function_call") is True:
+                    functions = _function_chunk
+                else:
+                    _tool_response_chunk = ChatCompletionToolCallChunk(
+                        id=f"call_{str(uuid.uuid4())}",
+                        type="function",
+                        function=_function_chunk,
+                        index=candidate.get("index", idx),
+                    )
+                    tools.append(_tool_response_chunk)

             chat_completion_message["content"] = (
                 content_str if len(content_str) > 0 else None
             )
-            chat_completion_message["tool_calls"] = tools
+            if len(tools) > 0:
+                chat_completion_message["tool_calls"] = tools
+
+            if functions is not None:
+                chat_completion_message["function_call"] = functions

             choice = litellm.Choices(
                 finish_reason=candidate.get("finishReason", "stop"),
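With the new `litellm_params` plumbing, the same Gemini `functionCall` part is surfaced in one of two ways when the response is converted to an OpenAI-shaped message: as a `tool_calls` entry by default, or as a top-level `function_call` when the request used `functions`. A hedged sketch of the two resulting assistant messages (the call id and argument values are illustrative):

    # default path ('tools' request): functionCall becomes an OpenAI tool call
    tool_call_message = {
        "role": "assistant",
        "content": None,
        "tool_calls": [
            {
                "id": "call_123",
                "type": "function",
                "function": {
                    "name": "get_current_weather",
                    "arguments": '{"location": "Boston, MA"}',
                },
                "index": 0,
            }
        ],
    }

    # 'functions' path (litellm_param_is_function_call is True): same payload,
    # exposed on the legacy function_call field instead
    function_call_message = {
        "role": "assistant",
        "content": None,
        "function_call": {
            "name": "get_current_weather",
            "arguments": '{"location": "Boston, MA"}',
        },
    }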
@@ -1235,7 +1285,7 @@ class VertexLLM(BaseLLM):
         logging_obj,
         stream,
         optional_params: dict,
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
@@ -1269,6 +1319,7 @@ class VertexLLM(BaseLLM):
             messages=messages,
             print_verbose=print_verbose,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             encoding=encoding,
         )

@@ -1290,7 +1341,7 @@ class VertexLLM(BaseLLM):
         vertex_location: Optional[str],
         vertex_credentials: Optional[str],
         gemini_api_key: Optional[str],
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1302,7 +1353,6 @@ class VertexLLM(BaseLLM):
             optional_params=optional_params
         )

-        print_verbose("Incoming Vertex Args - {}".format(locals()))
         auth_header, url = self._get_token_and_url(
             model=model,
             gemini_api_key=gemini_api_key,
@@ -1314,7 +1364,6 @@ class VertexLLM(BaseLLM):
             api_base=api_base,
             should_use_v1beta1_features=should_use_v1beta1_features,
         )
-        print_verbose("Updated URL - {}".format(url))

         ## TRANSFORMATION ##
         try:
@@ -1358,6 +1407,18 @@ class VertexLLM(BaseLLM):
                 )
                 optional_params.pop("response_schema")

+        # Check for any 'litellm_param_*' set during optional param mapping
+
+        remove_keys = []
+        for k, v in optional_params.items():
+            if k.startswith("litellm_param_"):
+                litellm_params.update({k: v})
+                remove_keys.append(k)
+
+        optional_params = {
+            k: v for k, v in optional_params.items() if k not in remove_keys
+        }
+
         try:
             content = _gemini_convert_messages_with_history(messages=messages)
             tools: Optional[Tools] = optional_params.pop("tools", None)
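Any key that the param mapping prefixed with `litellm_param_` is scrubbed here: it is copied into `litellm_params` (so the response and streaming code can still read it) and removed from `optional_params` (so it never reaches the Google API, which the new test below asserts). A self-contained sketch of the same filter with illustrative contents:

    optional_params = {"temperature": 0.2, "litellm_param_is_function_call": True}
    litellm_params: dict = {}

    remove_keys = []
    for k, v in optional_params.items():
        if k.startswith("litellm_param_"):
            litellm_params.update({k: v})
            remove_keys.append(k)

    optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}

    print(optional_params)  # {'temperature': 0.2}
    print(litellm_params)   # {'litellm_param_is_function_call': True}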
@@ -1491,6 +1552,7 @@ class VertexLLM(BaseLLM):
             model_response=model_response,
             logging_obj=logging_obj,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             api_key="",
             data=data,  # type: ignore
             messages=messages,
@@ -2009,7 +2009,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=new_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 vertex_location=vertex_ai_location,
@@ -2096,7 +2096,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=new_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 vertex_location=vertex_ai_location,
@@ -2691,8 +2691,61 @@ def test_completion_hf_model_no_provider():
 # test_completion_hf_model_no_provider()


-@pytest.mark.skip(reason="anyscale stopped serving public api endpoints")
-def test_completion_anyscale_with_functions():
+def gemini_mock_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json = MagicMock(
+        return_value={
+            "candidates": [
+                {
+                    "content": {
+                        "parts": [
+                            {
+                                "functionCall": {
+                                    "name": "get_current_weather",
+                                    "args": {"location": "Boston, MA"},
+                                }
+                            }
+                        ],
+                        "role": "model",
+                    },
+                    "finishReason": "STOP",
+                    "index": 0,
+                    "safetyRatings": [
+                        {
+                            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HARASSMENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HATE_SPEECH",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                    ],
+                }
+            ],
+            "usageMetadata": {
+                "promptTokenCount": 86,
+                "candidatesTokenCount": 19,
+                "totalTokenCount": 105,
+            },
+        }
+    )
+
+    return mock_response
+
+
+@pytest.mark.asyncio
+async def test_completion_functions_param():
+    litellm.set_verbose = True
     function1 = [
         {
             "name": "get_current_weather",
@@ -2711,18 +2764,33 @@ def test_completion_anyscale_with_functions()
         }
     ]
     try:
-        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
-        response = completion(
-            model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
-            messages=messages,
-            functions=function1,
-        )
-        # Add any assertions here to check the response
-        print(response)
-
-        cost = litellm.completion_cost(completion_response=response)
-        print("cost to make anyscale completion=", cost)
-        assert cost > 0.0
+        from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+
+        client = AsyncHTTPHandler(concurrent_limit=1)
+
+        with patch.object(client, "post", side_effect=gemini_mock_post) as mock_client:
+            response: litellm.ModelResponse = await litellm.acompletion(
+                model="gemini/gemini-1.5-pro",
+                messages=messages,
+                functions=function1,
+                client=client,
+            )
+            print(response)
+            # Add any assertions here to check the response
+            mock_client.assert_called()
+            print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
+            assert "tools" in mock_client.call_args.kwargs["json"]
+            assert (
+                "litellm_param_is_function_call"
+                not in mock_client.call_args.kwargs["json"]
+            )
+            assert (
+                "litellm_param_is_function_call"
+                not in mock_client.call_args.kwargs["json"]["generationConfig"]
+            )
+            assert response.choices[0].message.function_call is not None
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

@@ -755,27 +755,40 @@ async def test_completion_gemini_stream(sync_mode):
     try:
         litellm.set_verbose = True
         print("Streaming gemini response")
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "Who was Alexander?",
-            },
-        ]
+        function1 = [
+            {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        ]
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
         print("testing gemini streaming")
         complete_response = ""
         # Add any assertions here to check the response
         non_empty_chunks = 0
+        chunks = []
         if sync_mode:
             response = completion(
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )

             for idx, chunk in enumerate(response):
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -787,11 +800,13 @@ async def test_completion_gemini_stream(sync_mode):
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )

             idx = 0
             async for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -800,10 +815,17 @@ async def test_completion_gemini_stream(sync_mode):
                 complete_response += chunk
                 idx += 1

-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
+        # if complete_response.strip() == "":
+        #     raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
-        assert non_empty_chunks > 1
+
+        complete_response = litellm.stream_chunk_builder(
+            chunks=chunks, messages=messages
+        )
+
+        assert complete_response.choices[0].message.function_call is not None
+
+        # assert non_empty_chunks > 1
     except litellm.InternalServerError as e:
         pass
     except litellm.RateLimitError as e:
@@ -449,6 +449,7 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
     content: Optional[str]
     tool_calls: List[ChatCompletionToolCallChunk]
     role: Literal["assistant"]
+    function_call: ChatCompletionToolCallFunctionChunk


 class ChatCompletionUsageBlock(TypedDict):
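The response-message TypedDict gains an optional `function_call` field alongside `tool_calls` (the class is `total=False`, so either may be omitted). A small sketch of a message built with the new field, assuming the key names shown in this diff:

    from litellm.types.llms.openai import ChatCompletionResponseMessage

    message: ChatCompletionResponseMessage = {
        "role": "assistant",
        "content": None,
        "function_call": {
            "name": "get_current_weather",
            "arguments": '{"location": "Boston, MA"}',
        },
    }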
@@ -90,7 +90,7 @@ class Schema(TypedDict, total=False):
 class FunctionDeclaration(TypedDict, total=False):
     name: Required[str]
     description: str
-    parameters: Schema
+    parameters: Union[Schema, dict]
     response: Schema


@@ -8779,6 +8779,7 @@ class CustomStreamWrapper:
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
+        self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)

     def __iter__(self):
         return self
@@ -8786,6 +8787,19 @@ class CustomStreamWrapper:
     def __aiter__(self):
         return self

+    def check_is_function_call(self, logging_obj) -> bool:
+        if hasattr(logging_obj, "optional_params") and isinstance(
+            logging_obj.optional_params, dict
+        ):
+            if (
+                "litellm_param_is_function_call" in logging_obj.optional_params
+                and logging_obj.optional_params["litellm_param_is_function_call"]
+                is True
+            ):
+                return True
+
+        return False
+
     def process_chunk(self, chunk: str):
         """
         NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta.
@@ -10283,6 +10297,12 @@ class CustomStreamWrapper:

             ## CHECK FOR TOOL USE
             if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
+                if self.is_function_call is True:  # user passed in 'functions' param
+                    completion_obj["function_call"] = completion_obj["tool_calls"][0][
+                        "function"
+                    ]
+                    completion_obj["tool_calls"] = None
+
                 self.tool_call = True

             ## RETURN ARG
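For streaming, `check_is_function_call` reads the marker off the logging object's `optional_params`, and the chunk handling above then rewrites the first streamed tool call into a `function_call` delta. A rough end-to-end consumption sketch based on the updated streaming test (model name and schema illustrative):

    import litellm

    function1 = [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {"location": {"type": "string"}},
                "required": ["location"],
            },
        }
    ]
    messages = [{"role": "user", "content": "What is the weather like in Boston?"}]

    chunks = []
    response = litellm.completion(
        model="gemini/gemini-1.5-flash",
        messages=messages,
        stream=True,
        functions=function1,
    )
    for chunk in response:
        chunks.append(chunk)

    # stream_chunk_builder reassembles the deltas; with 'functions' the rebuilt
    # message should carry function_call rather than tool_calls.
    rebuilt = litellm.stream_chunk_builder(chunks=chunks, messages=messages)
    assert rebuilt.choices[0].message.function_call is not None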
@@ -10294,8 +10314,13 @@ class CustomStreamWrapper:
                 )
                 or (
                     "tool_calls" in completion_obj
+                    and completion_obj["tool_calls"] is not None
                     and len(completion_obj["tool_calls"]) > 0
                 )
+                or (
+                    "function_call" in completion_obj
+                    and completion_obj["function_call"] is not None
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 self.safety_checker()
                 hold, model_response_str = self.check_special_tokens(
@@ -10355,6 +10380,7 @@ class CustomStreamWrapper:
                 if self.sent_first_chunk is False:
                     completion_obj["role"] = "assistant"
                     self.sent_first_chunk = True
+
                 model_response.choices[0].delta = Delta(**completion_obj)
                 if completion_obj.get("index") is not None:
                     model_response.choices[0].index = completion_obj.get(