Merge pull request #5368 from BerriAI/litellm_vertex_function_support

feat(vertex_httpx.py): support 'functions' param for gemini google ai studio + vertex ai
Krish Dholakia 2024-08-26 22:11:42 -07:00 committed by GitHub
commit c503ff435e
7 changed files with 263 additions and 84 deletions
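
For orientation, this is the call pattern the commit enables, sketched minimally after the new test added below (model name and function schema are illustrative, not prescribed by the commit):

# Example (not part of the diff): legacy OpenAI-style 'functions' param,
# now translated into Gemini tool declarations for both Google AI Studio
# and Vertex AI routes.
import litellm

response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    functions=[
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City and state"},
                    "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
                },
                "required": ["location"],
            },
        }
    ],
)
print(response.choices[0].message.function_call)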


@@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
ChatCompletionResponseMessage,
ChatCompletionToolCallChunk,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolParamFunctionChunk,
ChatCompletionUsageBlock,
)
from litellm.types.llms.vertex_ai import (
@@ -296,11 +297,50 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
"stream",
"tools",
"tool_choice",
"functions",
"response_format",
"n",
"stop",
]
def _map_function(self, value: List[dict]) -> List[Tools]:
gtool_func_declarations = []
googleSearchRetrieval: Optional[dict] = None
for tool in value:
openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
None
)
if "function" in tool: # tools list
openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore
**tool["function"]
)
elif "name" in tool: # functions list
openai_function_object = ChatCompletionToolParamFunctionChunk(**tool) # type: ignore
# check if grounding
if tool.get("googleSearchRetrieval", None) is not None:
googleSearchRetrieval = tool["googleSearchRetrieval"]
elif openai_function_object is not None:
gtool_func_declaration = FunctionDeclaration(
name=openai_function_object["name"],
description=openai_function_object.get("description", ""),
parameters=openai_function_object.get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
else:
# assume it's a provider-specific param
verbose_logger.warning(
    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request.".format(tool)
)
_tools = Tools(
function_declarations=gtool_func_declarations,
)
if googleSearchRetrieval is not None:
_tools["googleSearchRetrieval"] = googleSearchRetrieval
return [_tools]
def map_tool_choice_values(
self, model: str, tool_choice: Union[str, dict]
) -> Optional[ToolConfig]:
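
A sketch of the translation `_map_function` above performs: both OpenAI shapes collapse into a single Gemini `Tools` block. The expected dict is inferred from the `FunctionDeclaration` fields in the hunk, not captured output:

# Example (not part of the diff): the two input shapes _map_function accepts.
weather_fn = {
    "name": "get_current_weather",
    "description": "Get the current weather in a given location",
    "parameters": {"type": "object", "properties": {"location": {"type": "string"}}},
}
tools_style = [{"type": "function", "function": weather_fn}]  # 'tools' param entry
functions_style = [weather_fn]  # legacy 'functions' param entry

# Either shape should map to one Tools entry:
expected = [
    {
        "function_declarations": [
            {
                "name": "get_current_weather",
                "description": "Get the current weather in a given location",
                "parameters": weather_fn["parameters"],
            }
        ]
    }
]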
@@ -363,26 +403,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
if "json_schema" in value and "schema" in value["json_schema"]: # type: ignore
optional_params["response_mime_type"] = "application/json"
optional_params["response_schema"] = value["json_schema"]["schema"] # type: ignore
if param == "tools" and isinstance(value, list):
gtool_func_declarations = []
for tool in value:
_parameters = tool.get("function", {}).get("parameters", {})
_properties = _parameters.get("properties", {})
if isinstance(_properties, dict):
for _, _property in _properties.items():
if "enum" in _property and "format" not in _property:
_property["format"] = "enum"
gtool_func_declaration = FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
if (param == "tools" or param == "functions") and isinstance(value, list):
optional_params["tools"] = self._map_function(value=value)
optional_params["litellm_param_is_function_call"] = (
True if param == "functions" else False
)
if len(_parameters.keys()) > 0:
gtool_func_declaration["parameters"] = _parameters
gtool_func_declarations.append(gtool_func_declaration)
optional_params["tools"] = [
Tools(function_declarations=gtool_func_declarations)
]
if param == "tool_choice" and (
isinstance(value, str) or isinstance(value, dict)
):
@@ -506,6 +531,7 @@ class VertexGeminiConfig:
"max_tokens",
"stream",
"tools",
"functions",
"tool_choice",
"response_format",
"n",
@@ -541,6 +567,44 @@ class VertexGeminiConfig:
status_code=400,
)
def _map_function(self, value: List[dict]) -> List[Tools]:
gtool_func_declarations = []
googleSearchRetrieval: Optional[dict] = None
for tool in value:
openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
None
)
if "function" in tool: # tools list
openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore
**tool["function"]
)
elif "name" in tool: # functions list
openai_function_object = ChatCompletionToolParamFunctionChunk(**tool) # type: ignore
# check if grounding
if tool.get("googleSearchRetrieval", None) is not None:
googleSearchRetrieval = tool["googleSearchRetrieval"]
elif openai_function_object is not None:
gtool_func_declaration = FunctionDeclaration(
name=openai_function_object["name"],
description=openai_function_object.get("description", ""),
parameters=openai_function_object.get("parameters", {}),
)
gtool_func_declarations.append(gtool_func_declaration)
else:
# assume it's a provider-specific param
verbose_logger.warning(
    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request.".format(tool)
)
_tools = Tools(
function_declarations=gtool_func_declarations,
)
if googleSearchRetrieval is not None:
_tools["googleSearchRetrieval"] = googleSearchRetrieval
return [_tools]
def map_openai_params(
self,
model: str,
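
`VertexGeminiConfig._map_function` carries the same grounding branch, so a `googleSearchRetrieval` entry can ride alongside function declarations. A sketch of the expected fold (output shape inferred from the code above):

# Example (not part of the diff): grounding tool mixed with a function tool.
value = [
    {"googleSearchRetrieval": {}},  # copied through onto the merged Tools dict
    {"type": "function", "function": {"name": "get_current_weather", "parameters": {}}},
]
# Expected result of _map_function(value):
# [{"function_declarations": [{"name": "get_current_weather", ...}],
#   "googleSearchRetrieval": {}}]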
@@ -582,33 +646,11 @@ class VertexGeminiConfig:
optional_params["frequency_penalty"] = value
if param == "presence_penalty":
optional_params["presence_penalty"] = value
if param == "tools" and isinstance(value, list):
gtool_func_declarations = []
googleSearchRetrieval: Optional[dict] = None
provider_specific_tools: List[dict] = []
for tool in value:
# check if grounding
try:
gtool_func_declaration = FunctionDeclaration(
name=tool["function"]["name"],
description=tool["function"].get("description", ""),
parameters=tool["function"].get("parameters", {}),
if (param == "tools" or param == "functions") and isinstance(value, list):
optional_params["tools"] = self._map_function(value=value)
optional_params["litellm_param_is_function_call"] = (
True if param == "functions" else False
)
gtool_func_declarations.append(gtool_func_declaration)
except KeyError:
if tool.get("googleSearchRetrieval", None) is not None:
googleSearchRetrieval = tool["googleSearchRetrieval"]
else:
# assume it's a provider-specific param
verbose_logger.warning(
"Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
)
_tools = Tools(
function_declarations=gtool_func_declarations,
)
if googleSearchRetrieval is not None:
_tools["googleSearchRetrieval"] = googleSearchRetrieval
optional_params["tools"] = [_tools] + provider_specific_tools
if param == "tool_choice" and (
isinstance(value, str) or isinstance(value, dict)
):
@@ -780,6 +822,7 @@ class VertexLLM(BaseLLM):
model_response: ModelResponse,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
optional_params: dict,
litellm_params: dict,
api_key: str,
data: Union[dict, str],
messages: List,
@@ -796,7 +839,6 @@
)
print_verbose(f"raw model_response: {response.text}")
## RESPONSE OBJECT
try:
completion_response = GenerateContentResponseBody(**response.json()) # type: ignore
@@ -904,6 +946,7 @@
chat_completion_message = {"role": "assistant"}
content_str = ""
tools: List[ChatCompletionToolCallChunk] = []
functions: Optional[ChatCompletionToolCallFunctionChunk] = None
for idx, candidate in enumerate(completion_response["candidates"]):
if "content" not in candidate:
continue
@@ -926,6 +969,9 @@
candidate["content"]["parts"][0]["functionCall"]["args"]
),
)
if litellm_params.get("litellm_param_is_function_call") is True:
functions = _function_chunk
else:
_tool_response_chunk = ChatCompletionToolCallChunk(
id=f"call_{str(uuid.uuid4())}",
type="function",
@@ -937,8 +983,12 @@
chat_completion_message["content"] = (
content_str if len(content_str) > 0 else None
)
if len(tools) > 0:
chat_completion_message["tool_calls"] = tools
if functions is not None:
chat_completion_message["function_call"] = functions
choice = litellm.Choices(
finish_reason=candidate.get("finishReason", "stop"),
index=candidate.get("index", idx),
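
Net effect: the same Gemini functionCall part is routed to one of two OpenAI-compatible fields depending on the new flag. Rough shapes, inferred from the chunk types above (the call id is illustrative):

# Example (not part of the diff):
# request used 'functions' -> message.function_call
function_call = {
    "name": "get_current_weather",
    "arguments": '{"location": "Boston, MA"}',
}
# request used 'tools' -> message.tool_calls
tool_calls = [
    {
        "id": "call_<uuid4>",  # generated with uuid.uuid4() in the code above
        "type": "function",
        "function": function_call,
    }
]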
@@ -1235,7 +1285,7 @@
logging_obj,
stream,
optional_params: dict,
-        litellm_params=None,
+        litellm_params: dict,
logger_fn=None,
headers={},
client: Optional[AsyncHTTPHandler] = None,
@@ -1269,6 +1319,7 @@
messages=messages,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
encoding=encoding,
)
@@ -1290,7 +1341,7 @@
vertex_location: Optional[str],
vertex_credentials: Optional[str],
gemini_api_key: Optional[str],
-        litellm_params=None,
+        litellm_params: dict,
logger_fn=None,
extra_headers: Optional[dict] = None,
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1302,7 +1353,6 @@
optional_params=optional_params
)
print_verbose("Incoming Vertex Args - {}".format(locals()))
auth_header, url = self._get_token_and_url(
model=model,
gemini_api_key=gemini_api_key,
@@ -1314,7 +1364,6 @@
api_base=api_base,
should_use_v1beta1_features=should_use_v1beta1_features,
)
print_verbose("Updated URL - {}".format(url))
## TRANSFORMATION ##
try:
@@ -1358,6 +1407,18 @@
)
optional_params.pop("response_schema")
# Check for any 'litellm_param_*' set during optional param mapping
remove_keys = []
for k, v in optional_params.items():
if k.startswith("litellm_param_"):
litellm_params.update({k: v})
remove_keys.append(k)
optional_params = {
k: v for k, v in optional_params.items() if k not in remove_keys
}
try:
content = _gemini_convert_messages_with_history(messages=messages)
tools: Optional[List[Tools]] = optional_params.pop("tools", None)
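
The `litellm_param_*` scrub above keeps internal flags out of the provider payload while preserving them for response handling. A self-contained sketch of the same before/after:

# Example (not part of the diff):
optional_params = {"temperature": 0.7, "litellm_param_is_function_call": True}
litellm_params: dict = {}

remove_keys = [k for k in optional_params if k.startswith("litellm_param_")]
litellm_params.update({k: optional_params[k] for k in remove_keys})
optional_params = {k: v for k, v in optional_params.items() if k not in remove_keys}

assert optional_params == {"temperature": 0.7}  # only real request params remain
assert litellm_params == {"litellm_param_is_function_call": True}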
@@ -1491,6 +1552,7 @@
model_response=model_response,
logging_obj=logging_obj,
optional_params=optional_params,
litellm_params=litellm_params,
api_key="",
data=data, # type: ignore
messages=messages,


@@ -2009,7 +2009,7 @@ def completion(
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
-            litellm_params=litellm_params,
+            litellm_params=litellm_params,  # type: ignore
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
@@ -2096,7 +2096,7 @@ def completion(
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
-            litellm_params=litellm_params,
+            litellm_params=litellm_params,  # type: ignore
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,


@@ -2691,8 +2691,61 @@ def test_completion_hf_model_no_provider():
# test_completion_hf_model_no_provider()
-@pytest.mark.skip(reason="anyscale stopped serving public api endpoints")
-def test_completion_anyscale_with_functions():
+def gemini_mock_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json = MagicMock(
+        return_value={
+            "candidates": [
+                {
+                    "content": {
+                        "parts": [
+                            {
+                                "functionCall": {
+                                    "name": "get_current_weather",
+                                    "args": {"location": "Boston, MA"},
+                                }
+                            }
+                        ],
+                        "role": "model",
+                    },
+                    "finishReason": "STOP",
+                    "index": 0,
+                    "safetyRatings": [
+                        {
+                            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HARASSMENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HATE_SPEECH",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                    ],
+                }
+            ],
+            "usageMetadata": {
+                "promptTokenCount": 86,
+                "candidatesTokenCount": 19,
+                "totalTokenCount": 105,
+            },
+        }
+    )
+    return mock_response
+
+
+@pytest.mark.asyncio
+async def test_completion_functions_param():
litellm.set_verbose = True
function1 = [
{
"name": "get_current_weather",
@@ -2711,18 +2764,33 @@ def test_completion_anyscale_with_functions():
}
]
    try:
+       from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+
        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
-       response = completion(
-           model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
-           messages=messages,
-           functions=function1,
-       )
-       # Add any assertions here to check the response
-       print(response)
-       cost = litellm.completion_cost(completion_response=response)
-       print("cost to make anyscale completion=", cost)
-       assert cost > 0.0
+       client = AsyncHTTPHandler(concurrent_limit=1)
+       with patch.object(client, "post", side_effect=gemini_mock_post) as mock_client:
+           response: litellm.ModelResponse = await litellm.acompletion(
+               model="gemini/gemini-1.5-pro",
+               messages=messages,
+               functions=function1,
+               client=client,
+           )
+           # Add any assertions here to check the response
+           print(response)
+           mock_client.assert_called()
+           print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
+           assert "tools" in mock_client.call_args.kwargs["json"]
+           assert (
+               "litellm_param_is_function_call"
+               not in mock_client.call_args.kwargs["json"]
+           )
+           assert (
+               "litellm_param_is_function_call"
+               not in mock_client.call_args.kwargs["json"]["generationConfig"]
+           )
+           assert response.choices[0].message.function_call is not None
except Exception as e:
pytest.fail(f"Error occurred: {e}")


@@ -755,27 +755,40 @@ async def test_completion_gemini_stream(sync_mode):
try:
litellm.set_verbose = True
print("Streaming gemini response")
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
-            {
-                "role": "user",
-                "content": "Who was Alexander?",
-            },
-        ]
+        function1 = [
+            {
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
+        ]
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
print("testing gemini streaming")
complete_response = ""
# Add any assertions here to check the response
non_empty_chunks = 0
chunks = []
if sync_mode:
response = completion(
model="gemini/gemini-1.5-flash",
messages=messages,
stream=True,
functions=function1,
)
for idx, chunk in enumerate(response):
print(chunk)
chunks.append(chunk)
# print(chunk.choices[0].delta)
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
@@ -787,11 +800,13 @@ async def test_completion_gemini_stream(sync_mode):
model="gemini/gemini-1.5-flash",
messages=messages,
stream=True,
functions=function1,
)
idx = 0
async for chunk in response:
print(chunk)
chunks.append(chunk)
# print(chunk.choices[0].delta)
chunk, finished = streaming_format_tests(idx, chunk)
if finished:
@@ -800,10 +815,17 @@ async def test_completion_gemini_stream(sync_mode):
complete_response += chunk
idx += 1
-            if complete_response.strip() == "":
-                raise Exception("Empty response received")
+            # if complete_response.strip() == "":
+            #     raise Exception("Empty response received")
            print(f"completion_response: {complete_response}")
-            assert non_empty_chunks > 1
+            complete_response = litellm.stream_chunk_builder(
+                chunks=chunks, messages=messages
+            )
+            assert complete_response.choices[0].message.function_call is not None
+            # assert non_empty_chunks > 1
except litellm.InternalServerError as e:
pass
except litellm.RateLimitError as e:


@@ -449,6 +449,7 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
content: Optional[str]
tool_calls: List[ChatCompletionToolCallChunk]
role: Literal["assistant"]
function_call: ChatCompletionToolCallFunctionChunk
class ChatCompletionUsageBlock(TypedDict):


@@ -90,7 +90,7 @@ class Schema(TypedDict, total=False):
class FunctionDeclaration(TypedDict, total=False):
name: Required[str]
description: str
-    parameters: Schema
+    parameters: Union[Schema, dict]
response: Schema
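
Widening `parameters` means a raw JSON-schema dict now passes type-checking without being wrapped in `Schema`. A sketch, assuming the types module path used in the imports of the first file above:

# Example (not part of the diff):
from litellm.types.llms.vertex_ai import FunctionDeclaration

decl = FunctionDeclaration(
    name="get_current_weather",
    description="Get the current weather in a given location",
    parameters={"type": "object", "properties": {"location": {"type": "string"}}},
)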


@@ -8779,6 +8779,7 @@ class CustomStreamWrapper:
self.chunks: List = (
[]
) # keep track of the returned chunks - used for calculating the input/output tokens for stream options
self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)
def __iter__(self):
return self
@@ -8786,6 +8787,19 @@
def __aiter__(self):
return self
def check_is_function_call(self, logging_obj) -> bool:
if hasattr(logging_obj, "optional_params") and isinstance(
logging_obj.optional_params, dict
):
if (
"litellm_param_is_function_call" in logging_obj.optional_params
and logging_obj.optional_params["litellm_param_is_function_call"]
is True
):
return True
return False
def process_chunk(self, chunk: str):
"""
NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta.
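
Downstream effect of `check_is_function_call`: when the original request carried `functions`, streamed deltas surface `function_call` instead of `tool_calls`. A hedged sketch of consuming such a stream (model and schema are illustrative):

# Example (not part of the diff):
import litellm

weather_fn = [
    {
        "name": "get_current_weather",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    }
]
stream = litellm.completion(
    model="gemini/gemini-1.5-flash",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    functions=weather_fn,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if getattr(delta, "function_call", None) is not None:
        print(delta.function_call)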
@@ -10283,6 +10297,12 @@
## CHECK FOR TOOL USE
if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
if self.is_function_call is True: # user passed in 'functions' param
completion_obj["function_call"] = completion_obj["tool_calls"][0][
"function"
]
completion_obj["tool_calls"] = None
self.tool_call = True
## RETURN ARG
@@ -10294,8 +10314,13 @@
)
or (
"tool_calls" in completion_obj
and completion_obj["tool_calls"] is not None
and len(completion_obj["tool_calls"]) > 0
)
or (
"function_call" in completion_obj
and completion_obj["function_call"] is not None
)
): # cannot set content of an OpenAI Object to be an empty string
self.safety_checker()
hold, model_response_str = self.check_special_tokens(
@@ -10355,6 +10380,7 @@
if self.sent_first_chunk is False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
if completion_obj.get("index") is not None:
model_response.choices[0].index = completion_obj.get(