Merge pull request #5368 from BerriAI/litellm_vertex_function_support

feat(vertex_httpx.py): support 'functions' param for gemini google ai studio + vertex ai
Krish Dholakia 2024-08-26 22:11:42 -07:00 committed by GitHub
commit c503ff435e
7 changed files with 263 additions and 84 deletions
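
Usage sketch (not part of the diff): this change lets the legacy OpenAI-style `functions` parameter flow through to Gemini / Vertex AI models. The snippet below mirrors the new `test_completion_functions_param` test further down (model name and weather schema taken from that test) and assumes a configured Gemini API key.

import litellm

function1 = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {
                "location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"},
                "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
            },
            "required": ["location"],
        },
    }
]

response = litellm.completion(
    model="gemini/gemini-1.5-pro",
    messages=[{"role": "user", "content": "What is the weather like in Boston?"}],
    functions=function1,
)
# With the 'functions' param, the result is surfaced as a legacy function_call
# (not tool_calls), per the response-handling change in vertex_httpx.py below.
print(response.choices[0].message.function_call)

Internally the `functions` list is converted to a Gemini `tools` payload, and the `litellm_param_is_function_call` flag routes the response back out as `function_call` instead of `tool_calls`.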

View file

@@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionToolParamFunctionChunk,
     ChatCompletionUsageBlock,
 )
 from litellm.types.llms.vertex_ai import (
@@ -296,11 +297,50 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
             "stream",
             "tools",
             "tool_choice",
+            "functions",
             "response_format",
             "n",
             "stop",
         ]
 
+    def _map_function(self, value: List[dict]) -> List[Tools]:
+        gtool_func_declarations = []
+        googleSearchRetrieval: Optional[dict] = None
+
+        for tool in value:
+            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
+                None
+            )
+            if "function" in tool:  # tools list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
+                    **tool["function"]
+                )
+            elif "name" in tool:  # functions list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore
+
+            # check if grounding
+            if tool.get("googleSearchRetrieval", None) is not None:
+                googleSearchRetrieval = tool["googleSearchRetrieval"]
+            elif openai_function_object is not None:
+                gtool_func_declaration = FunctionDeclaration(
+                    name=openai_function_object["name"],
+                    description=openai_function_object.get("description", ""),
+                    parameters=openai_function_object.get("parameters", {}),
+                )
+                gtool_func_declarations.append(gtool_func_declaration)
+            else:
+                # assume it's a provider-specific param
+                verbose_logger.warning(
+                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
+                )
+
+        _tools = Tools(
+            function_declarations=gtool_func_declarations,
+        )
+        if googleSearchRetrieval is not None:
+            _tools["googleSearchRetrieval"] = googleSearchRetrieval
+        return [_tools]
+
     def map_tool_choice_values(
         self, model: str, tool_choice: Union[str, dict]
     ) -> Optional[ToolConfig]:
@@ -363,26 +403,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
                 if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
                     optional_params["response_mime_type"] = "application/json"
                     optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
-            if param == "tools" and isinstance(value, list):
-                gtool_func_declarations = []
-                for tool in value:
-                    _parameters = tool.get("function", {}).get("parameters", {})
-                    _properties = _parameters.get("properties", {})
-                    if isinstance(_properties, dict):
-                        for _, _property in _properties.items():
-                            if "enum" in _property and "format" not in _property:
-                                _property["format"] = "enum"
-                    gtool_func_declaration = FunctionDeclaration(
-                        name=tool["function"]["name"],
-                        description=tool["function"].get("description", ""),
-                    )
-                    if len(_parameters.keys()) > 0:
-                        gtool_func_declaration["parameters"] = _parameters
-                    gtool_func_declarations.append(gtool_func_declaration)
-                optional_params["tools"] = [
-                    Tools(function_declarations=gtool_func_declarations)
-                ]
+            if (param == "tools" or param == "functions") and isinstance(value, list):
+                optional_params["tools"] = self._map_function(value=value)
+                optional_params["litellm_param_is_function_call"] = (
+                    True if param == "functions" else False
+                )
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
@@ -506,6 +531,7 @@ class VertexGeminiConfig:
             "max_tokens",
             "stream",
             "tools",
+            "functions",
             "tool_choice",
             "response_format",
             "n",
@@ -541,6 +567,44 @@ class VertexGeminiConfig:
                 status_code=400,
             )
 
+    def _map_function(self, value: List[dict]) -> List[Tools]:
+        gtool_func_declarations = []
+        googleSearchRetrieval: Optional[dict] = None
+
+        for tool in value:
+            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
+                None
+            )
+            if "function" in tool:  # tools list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
+                    **tool["function"]
+                )
+            elif "name" in tool:  # functions list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore
+
+            # check if grounding
+            if tool.get("googleSearchRetrieval", None) is not None:
+                googleSearchRetrieval = tool["googleSearchRetrieval"]
+            elif openai_function_object is not None:
+                gtool_func_declaration = FunctionDeclaration(
+                    name=openai_function_object["name"],
+                    description=openai_function_object.get("description", ""),
+                    parameters=openai_function_object.get("parameters", {}),
+                )
+                gtool_func_declarations.append(gtool_func_declaration)
+            else:
+                # assume it's a provider-specific param
+                verbose_logger.warning(
+                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
+                )
+
+        _tools = Tools(
+            function_declarations=gtool_func_declarations,
+        )
+        if googleSearchRetrieval is not None:
+            _tools["googleSearchRetrieval"] = googleSearchRetrieval
+        return [_tools]
+
     def map_openai_params(
         self,
         model: str,
@@ -582,33 +646,11 @@ class VertexGeminiConfig:
                 optional_params["frequency_penalty"] = value
             if param == "presence_penalty":
                 optional_params["presence_penalty"] = value
-            if param == "tools" and isinstance(value, list):
-                gtool_func_declarations = []
-                googleSearchRetrieval: Optional[dict] = None
-                provider_specific_tools: List[dict] = []
-                for tool in value:
-                    # check if grounding
-                    try:
-                        gtool_func_declaration = FunctionDeclaration(
-                            name=tool["function"]["name"],
-                            description=tool["function"].get("description", ""),
-                            parameters=tool["function"].get("parameters", {}),
-                        )
-                        gtool_func_declarations.append(gtool_func_declaration)
-                    except KeyError:
-                        if tool.get("googleSearchRetrieval", None) is not None:
-                            googleSearchRetrieval = tool["googleSearchRetrieval"]
-                        else:
-                            # assume it's a provider-specific param
-                            verbose_logger.warning(
-                                "Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
-                            )
-                _tools = Tools(
-                    function_declarations=gtool_func_declarations,
-                )
-                if googleSearchRetrieval is not None:
-                    _tools["googleSearchRetrieval"] = googleSearchRetrieval
-                optional_params["tools"] = [_tools] + provider_specific_tools
+            if (param == "tools" or param == "functions") and isinstance(value, list):
+                optional_params["tools"] = self._map_function(value=value)
+                optional_params["litellm_param_is_function_call"] = (
+                    True if param == "functions" else False
+                )
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
@@ -780,6 +822,7 @@ class VertexLLM(BaseLLM):
         model_response: ModelResponse,
         logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
         optional_params: dict,
+        litellm_params: dict,
        api_key: str,
        data: Union[dict, str],
        messages: List,
@@ -796,7 +839,6 @@ class VertexLLM(BaseLLM):
         )
         print_verbose(f"raw model_response: {response.text}")
 
         ## RESPONSE OBJECT
-
         try:
             completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
@@ -904,6 +946,7 @@ class VertexLLM(BaseLLM):
             chat_completion_message = {"role": "assistant"}
             content_str = ""
             tools: List[ChatCompletionToolCallChunk] = []
+            functions: Optional[ChatCompletionToolCallFunctionChunk] = None
             for idx, candidate in enumerate(completion_response["candidates"]):
                 if "content" not in candidate:
                     continue
@@ -926,18 +969,25 @@ class VertexLLM(BaseLLM):
                             candidate["content"]["parts"][0]["functionCall"]["args"]
                         ),
                     )
-                    _tool_response_chunk = ChatCompletionToolCallChunk(
-                        id=f"call_{str(uuid.uuid4())}",
-                        type="function",
-                        function=_function_chunk,
-                        index=candidate.get("index", idx),
-                    )
-                    tools.append(_tool_response_chunk)
+                    if litellm_params.get("litellm_param_is_function_call") is True:
+                        functions = _function_chunk
+                    else:
+                        _tool_response_chunk = ChatCompletionToolCallChunk(
+                            id=f"call_{str(uuid.uuid4())}",
+                            type="function",
+                            function=_function_chunk,
+                            index=candidate.get("index", idx),
+                        )
+                        tools.append(_tool_response_chunk)
 
                 chat_completion_message["content"] = (
                     content_str if len(content_str) > 0 else None
                 )
-                chat_completion_message["tool_calls"] = tools
+                if len(tools) > 0:
+                    chat_completion_message["tool_calls"] = tools
+
+                if functions is not None:
+                    chat_completion_message["function_call"] = functions
 
                 choice = litellm.Choices(
                     finish_reason=candidate.get("finishReason", "stop"),
@@ -1235,7 +1285,7 @@ class VertexLLM(BaseLLM):
         logging_obj,
         stream,
         optional_params: dict,
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
@@ -1269,6 +1319,7 @@ class VertexLLM(BaseLLM):
             messages=messages,
             print_verbose=print_verbose,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             encoding=encoding,
         )
 
@@ -1290,7 +1341,7 @@ class VertexLLM(BaseLLM):
         vertex_location: Optional[str],
         vertex_credentials: Optional[str],
         gemini_api_key: Optional[str],
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1302,7 +1353,6 @@ class VertexLLM(BaseLLM):
             optional_params=optional_params
         )
 
-        print_verbose("Incoming Vertex Args - {}".format(locals()))
         auth_header, url = self._get_token_and_url(
             model=model,
             gemini_api_key=gemini_api_key,
@@ -1314,7 +1364,6 @@ class VertexLLM(BaseLLM):
             api_base=api_base,
             should_use_v1beta1_features=should_use_v1beta1_features,
         )
-        print_verbose("Updated URL - {}".format(url))
 
         ## TRANSFORMATION ##
         try:
@@ -1358,6 +1407,18 @@ class VertexLLM(BaseLLM):
             )
             optional_params.pop("response_schema")
 
+        # Check for any 'litellm_param_*' set during optional param mapping
+
+        remove_keys = []
+        for k, v in optional_params.items():
+            if k.startswith("litellm_param_"):
+                litellm_params.update({k: v})
+                remove_keys.append(k)
+
+        optional_params = {
+            k: v for k, v in optional_params.items() if k not in remove_keys
+        }
+
         try:
             content = _gemini_convert_messages_with_history(messages=messages)
             tools: Optional[Tools] = optional_params.pop("tools", None)
@@ -1491,6 +1552,7 @@ class VertexLLM(BaseLLM):
             model_response=model_response,
             logging_obj=logging_obj,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             api_key="",
             data=data,  # type: ignore
             messages=messages,
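
For orientation (not part of the commit): a rough, standalone sketch of the request-side mapping that the new _map_function performs when a `functions`-style list is supplied. Plain dicts stand in for the FunctionDeclaration / Tools TypedDicts, so treat it as an illustration of the payload shape rather than the exact wire format.

# Illustration only: approximates _map_function for a 'functions'-style list
# (entries carry "name"/"parameters" at the top level, i.e. the
# `elif "name" in tool` branch in the diff above).
functions = [
    {
        "name": "get_current_weather",
        "description": "Get the current weather in a given location",
        "parameters": {
            "type": "object",
            "properties": {"location": {"type": "string"}},
            "required": ["location"],
        },
    }
]

gtool_func_declarations = []
for tool in functions:
    gtool_func_declarations.append(
        {
            "name": tool["name"],
            "description": tool.get("description", ""),
            "parameters": tool.get("parameters", {}),
        }
    )

# Equivalent of Tools(function_declarations=...); this list ends up under the
# "tools" key of the Gemini generateContent request body.
gemini_tools = [{"function_declarations": gtool_func_declarations}]
print(gemini_tools)

The `litellm_param_is_function_call` flag set during param mapping is stripped from optional_params before the request is built (see the 'litellm_param_*' loop above), so it never reaches the provider.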

View file

@@ -2009,7 +2009,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=new_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 vertex_location=vertex_ai_location,
@@ -2096,7 +2096,7 @@ def completion(
                 model_response=model_response,
                 print_verbose=print_verbose,
                 optional_params=new_params,
-                litellm_params=litellm_params,
+                litellm_params=litellm_params,  # type: ignore
                 logger_fn=logger_fn,
                 encoding=encoding,
                 vertex_location=vertex_ai_location,

View file

@@ -2691,8 +2691,61 @@ def test_completion_hf_model_no_provider():
 # test_completion_hf_model_no_provider()
 
-@pytest.mark.skip(reason="anyscale stopped serving public api endpoints")
-def test_completion_anyscale_with_functions():
+def gemini_mock_post(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json = MagicMock(
+        return_value={
+            "candidates": [
+                {
+                    "content": {
+                        "parts": [
+                            {
+                                "functionCall": {
+                                    "name": "get_current_weather",
+                                    "args": {"location": "Boston, MA"},
+                                }
+                            }
+                        ],
+                        "role": "model",
+                    },
+                    "finishReason": "STOP",
+                    "index": 0,
+                    "safetyRatings": [
+                        {
+                            "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HARASSMENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_HATE_SPEECH",
+                            "probability": "NEGLIGIBLE",
+                        },
+                        {
+                            "category": "HARM_CATEGORY_DANGEROUS_CONTENT",
+                            "probability": "NEGLIGIBLE",
+                        },
+                    ],
+                }
+            ],
+            "usageMetadata": {
+                "promptTokenCount": 86,
+                "candidatesTokenCount": 19,
+                "totalTokenCount": 105,
+            },
+        }
+    )
+
+    return mock_response
+
+
+@pytest.mark.asyncio
+async def test_completion_functions_param():
+    litellm.set_verbose = True
     function1 = [
         {
             "name": "get_current_weather",
@@ -2711,18 +2764,33 @@ def test_completion_anyscale_with_functions():
         }
     ]
     try:
-        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
-        response = completion(
-            model="anyscale/mistralai/Mistral-7B-Instruct-v0.1",
-            messages=messages,
-            functions=function1,
-        )
-        # Add any assertions here to check the response
-        print(response)
+        from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
 
-        cost = litellm.completion_cost(completion_response=response)
-        print("cost to make anyscale completion=", cost)
-        assert cost > 0.0
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
+
+        client = AsyncHTTPHandler(concurrent_limit=1)
+        with patch.object(client, "post", side_effect=gemini_mock_post) as mock_client:
+            response: litellm.ModelResponse = await litellm.acompletion(
+                model="gemini/gemini-1.5-pro",
+                messages=messages,
+                functions=function1,
+                client=client,
+            )
+            print(response)
+            # Add any assertions here to check the response
+            mock_client.assert_called()
+            print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}")
+            assert "tools" in mock_client.call_args.kwargs["json"]
+            assert (
+                "litellm_param_is_function_call"
+                not in mock_client.call_args.kwargs["json"]
+            )
+            assert (
+                "litellm_param_is_function_call"
+                not in mock_client.call_args.kwargs["json"]["generationConfig"]
+            )
+            assert response.choices[0].message.function_call is not None
     except Exception as e:
         pytest.fail(f"Error occurred: {e}")

View file

@@ -755,27 +755,40 @@ async def test_completion_gemini_stream(sync_mode):
     try:
         litellm.set_verbose = True
         print("Streaming gemini response")
-        messages = [
-            {"role": "system", "content": "You are a helpful assistant."},
+        function1 = [
             {
-                "role": "user",
-                "content": "Who was Alexander?",
-            },
+                "name": "get_current_weather",
+                "description": "Get the current weather in a given location",
+                "parameters": {
+                    "type": "object",
+                    "properties": {
+                        "location": {
+                            "type": "string",
+                            "description": "The city and state, e.g. San Francisco, CA",
+                        },
+                        "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]},
+                    },
+                    "required": ["location"],
+                },
+            }
         ]
+        messages = [{"role": "user", "content": "What is the weather like in Boston?"}]
         print("testing gemini streaming")
         complete_response = ""
         # Add any assertions here to check the response
         non_empty_chunks = 0
+        chunks = []
         if sync_mode:
             response = completion(
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )
             for idx, chunk in enumerate(response):
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -787,11 +800,13 @@ async def test_completion_gemini_stream(sync_mode):
                 model="gemini/gemini-1.5-flash",
                 messages=messages,
                 stream=True,
+                functions=function1,
             )
             idx = 0
             async for chunk in response:
                 print(chunk)
+                chunks.append(chunk)
                 # print(chunk.choices[0].delta)
                 chunk, finished = streaming_format_tests(idx, chunk)
                 if finished:
@@ -800,10 +815,17 @@ async def test_completion_gemini_stream(sync_mode):
                 complete_response += chunk
                 idx += 1
-        if complete_response.strip() == "":
-            raise Exception("Empty response received")
+        # if complete_response.strip() == "":
+        #     raise Exception("Empty response received")
         print(f"completion_response: {complete_response}")
-        assert non_empty_chunks > 1
+
+        complete_response = litellm.stream_chunk_builder(
+            chunks=chunks, messages=messages
+        )
+
+        assert complete_response.choices[0].message.function_call is not None
+        # assert non_empty_chunks > 1
     except litellm.InternalServerError as e:
         pass
     except litellm.RateLimitError as e:

View file

@@ -449,6 +449,7 @@ class ChatCompletionResponseMessage(TypedDict, total=False):
     content: Optional[str]
     tool_calls: List[ChatCompletionToolCallChunk]
     role: Literal["assistant"]
+    function_call: ChatCompletionToolCallFunctionChunk
 
 
 class ChatCompletionUsageBlock(TypedDict):

View file

@@ -90,7 +90,7 @@ class Schema(TypedDict, total=False):
 class FunctionDeclaration(TypedDict, total=False):
     name: Required[str]
     description: str
-    parameters: Schema
+    parameters: Union[Schema, dict]
     response: Schema

View file

@@ -8779,6 +8779,7 @@ class CustomStreamWrapper:
         self.chunks: List = (
             []
         )  # keep track of the returned chunks - used for calculating the input/output tokens for stream options
+        self.is_function_call = self.check_is_function_call(logging_obj=logging_obj)
 
     def __iter__(self):
         return self
@@ -8786,6 +8787,19 @@ class CustomStreamWrapper:
     def __aiter__(self):
         return self
 
+    def check_is_function_call(self, logging_obj) -> bool:
+        if hasattr(logging_obj, "optional_params") and isinstance(
+            logging_obj.optional_params, dict
+        ):
+            if (
+                "litellm_param_is_function_call" in logging_obj.optional_params
+                and logging_obj.optional_params["litellm_param_is_function_call"]
+                is True
+            ):
+                return True
+
+        return False
+
     def process_chunk(self, chunk: str):
         """
         NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta.
@@ -10283,6 +10297,12 @@ class CustomStreamWrapper:
 
         ## CHECK FOR TOOL USE
         if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0:
+            if self.is_function_call is True:  # user passed in 'functions' param
+                completion_obj["function_call"] = completion_obj["tool_calls"][0][
+                    "function"
+                ]
+                completion_obj["tool_calls"] = None
+
             self.tool_call = True
 
         ## RETURN ARG
@@ -10294,8 +10314,13 @@ class CustomStreamWrapper:
                 )
                 or (
                     "tool_calls" in completion_obj
+                    and completion_obj["tool_calls"] is not None
                     and len(completion_obj["tool_calls"]) > 0
                 )
+                or (
+                    "function_call" in completion_obj
+                    and completion_obj["function_call"] is not None
+                )
             ):  # cannot set content of an OpenAI Object to be an empty string
                 self.safety_checker()
                 hold, model_response_str = self.check_special_tokens(
@@ -10355,6 +10380,7 @@ class CustomStreamWrapper:
                     if self.sent_first_chunk is False:
                         completion_obj["role"] = "assistant"
                         self.sent_first_chunk = True
+
                     model_response.choices[0].delta = Delta(**completion_obj)
                     if completion_obj.get("index") is not None:
                         model_response.choices[0].index = completion_obj.get(