diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 5042d0f77e..2960dd82f9 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -30,6 +30,7 @@ from litellm.types.llms.openai import (
     ChatCompletionResponseMessage,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionToolParamFunctionChunk,
     ChatCompletionUsageBlock,
 )
 from litellm.types.llms.vertex_ai import (
@@ -296,11 +297,50 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
             "stream",
             "tools",
             "tool_choice",
+            "functions",
             "response_format",
             "n",
             "stop",
         ]

+    def _map_function(self, value: List[dict]) -> List[Tools]:
+        gtool_func_declarations = []
+        googleSearchRetrieval: Optional[dict] = None
+
+        for tool in value:
+            openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = (
+                None
+            )
+            if "function" in tool:  # tools list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(  # type: ignore
+                    **tool["function"]
+                )
+            elif "name" in tool:  # functions list
+                openai_function_object = ChatCompletionToolParamFunctionChunk(**tool)  # type: ignore
+
+            # check if grounding
+            if tool.get("googleSearchRetrieval", None) is not None:
+                googleSearchRetrieval = tool["googleSearchRetrieval"]
+            elif openai_function_object is not None:
+                gtool_func_declaration = FunctionDeclaration(
+                    name=openai_function_object["name"],
+                    description=openai_function_object.get("description", ""),
+                    parameters=openai_function_object.get("parameters", {}),
+                )
+                gtool_func_declarations.append(gtool_func_declaration)
+            else:
+                # assume it's a provider-specific param
+                verbose_logger.warning(
+                    "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request."
+                )
+
+        _tools = Tools(
+            function_declarations=gtool_func_declarations,
+        )
+        if googleSearchRetrieval is not None:
+            _tools["googleSearchRetrieval"] = googleSearchRetrieval
+        return [_tools]
+
     def map_tool_choice_values(
         self, model: str, tool_choice: Union[str, dict]
     ) -> Optional[ToolConfig]:
@@ -363,26 +403,11 @@ class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty
                     if "json_schema" in value and "schema" in value["json_schema"]:  # type: ignore
                         optional_params["response_mime_type"] = "application/json"
                         optional_params["response_schema"] = value["json_schema"]["schema"]  # type: ignore
-            if param == "tools" and isinstance(value, list):
-                gtool_func_declarations = []
-                for tool in value:
-                    _parameters = tool.get("function", {}).get("parameters", {})
-                    _properties = _parameters.get("properties", {})
-                    if isinstance(_properties, dict):
-                        for _, _property in _properties.items():
-                            if "enum" in _property and "format" not in _property:
-                                _property["format"] = "enum"
-
-                    gtool_func_declaration = FunctionDeclaration(
-                        name=tool["function"]["name"],
-                        description=tool["function"].get("description", ""),
-                    )
-                    if len(_parameters.keys()) > 0:
-                        gtool_func_declaration["parameters"] = _parameters
-                    gtool_func_declarations.append(gtool_func_declaration)
-                optional_params["tools"] = [
-                    Tools(function_declarations=gtool_func_declarations)
-                ]
+            if (param == "tools" or param == "functions") and isinstance(value, list):
+                optional_params["tools"] = self._map_function(value=value)
+                optional_params["litellm_param_is_function_call"] = (
+                    True if param == "functions" else False
+                )
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
@@ -506,6 +531,7 @@ class VertexGeminiConfig:
             "max_tokens",
             "stream",
"tools", + "functions", "tool_choice", "response_format", "n", @@ -541,6 +567,44 @@ class VertexGeminiConfig: status_code=400, ) + def _map_function(self, value: List[dict]) -> List[Tools]: + gtool_func_declarations = [] + googleSearchRetrieval: Optional[dict] = None + + for tool in value: + openai_function_object: Optional[ChatCompletionToolParamFunctionChunk] = ( + None + ) + if "function" in tool: # tools list + openai_function_object = ChatCompletionToolParamFunctionChunk( # type: ignore + **tool["function"] + ) + elif "name" in tool: # functions list + openai_function_object = ChatCompletionToolParamFunctionChunk(**tool) # type: ignore + + # check if grounding + if tool.get("googleSearchRetrieval", None) is not None: + googleSearchRetrieval = tool["googleSearchRetrieval"] + elif openai_function_object is not None: + gtool_func_declaration = FunctionDeclaration( + name=openai_function_object["name"], + description=openai_function_object.get("description", ""), + parameters=openai_function_object.get("parameters", {}), + ) + gtool_func_declarations.append(gtool_func_declaration) + else: + # assume it's a provider-specific param + verbose_logger.warning( + "Invalid tool={}. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request." + ) + + _tools = Tools( + function_declarations=gtool_func_declarations, + ) + if googleSearchRetrieval is not None: + _tools["googleSearchRetrieval"] = googleSearchRetrieval + return [_tools] + def map_openai_params( self, model: str, @@ -582,33 +646,11 @@ class VertexGeminiConfig: optional_params["frequency_penalty"] = value if param == "presence_penalty": optional_params["presence_penalty"] = value - if param == "tools" and isinstance(value, list): - gtool_func_declarations = [] - googleSearchRetrieval: Optional[dict] = None - provider_specific_tools: List[dict] = [] - for tool in value: - # check if grounding - try: - gtool_func_declaration = FunctionDeclaration( - name=tool["function"]["name"], - description=tool["function"].get("description", ""), - parameters=tool["function"].get("parameters", {}), - ) - gtool_func_declarations.append(gtool_func_declaration) - except KeyError: - if tool.get("googleSearchRetrieval", None) is not None: - googleSearchRetrieval = tool["googleSearchRetrieval"] - else: - # assume it's a provider-specific param - verbose_logger.warning( - "Got KeyError parsing tool={}. Assuming it's a provider-specific param. Use `litellm.set_verbose` or `litellm --detailed_debug` to see raw request." 
-                            )
-                _tools = Tools(
-                    function_declarations=gtool_func_declarations,
+            if (param == "tools" or param == "functions") and isinstance(value, list):
+                optional_params["tools"] = self._map_function(value=value)
+                optional_params["litellm_param_is_function_call"] = (
+                    True if param == "functions" else False
                 )
-                if googleSearchRetrieval is not None:
-                    _tools["googleSearchRetrieval"] = googleSearchRetrieval
-                optional_params["tools"] = [_tools] + provider_specific_tools
             if param == "tool_choice" and (
                 isinstance(value, str) or isinstance(value, dict)
             ):
@@ -780,6 +822,7 @@ class VertexLLM(BaseLLM):
         model_response: ModelResponse,
         logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
         optional_params: dict,
+        litellm_params: dict,
         api_key: str,
         data: Union[dict, str],
         messages: List,
@@ -796,7 +839,6 @@ class VertexLLM(BaseLLM):
         )

         print_verbose(f"raw model_response: {response.text}")
-
         ## RESPONSE OBJECT
         try:
             completion_response = GenerateContentResponseBody(**response.json())  # type: ignore
@@ -904,6 +946,7 @@ class VertexLLM(BaseLLM):
         chat_completion_message = {"role": "assistant"}
         content_str = ""
         tools: List[ChatCompletionToolCallChunk] = []
+        functions: Optional[ChatCompletionToolCallFunctionChunk] = None
         for idx, candidate in enumerate(completion_response["candidates"]):
             if "content" not in candidate:
                 continue
@@ -926,18 +969,25 @@ class VertexLLM(BaseLLM):
                         candidate["content"]["parts"][0]["functionCall"]["args"]
                     ),
                 )
-                _tool_response_chunk = ChatCompletionToolCallChunk(
-                    id=f"call_{str(uuid.uuid4())}",
-                    type="function",
-                    function=_function_chunk,
-                    index=candidate.get("index", idx),
-                )
-                tools.append(_tool_response_chunk)
+                if litellm_params.get("litellm_param_is_function_call") is True:
+                    functions = _function_chunk
+                else:
+                    _tool_response_chunk = ChatCompletionToolCallChunk(
+                        id=f"call_{str(uuid.uuid4())}",
+                        type="function",
+                        function=_function_chunk,
+                        index=candidate.get("index", idx),
+                    )
+                    tools.append(_tool_response_chunk)

             chat_completion_message["content"] = (
                 content_str if len(content_str) > 0 else None
             )
-            chat_completion_message["tool_calls"] = tools
+            if len(tools) > 0:
+                chat_completion_message["tool_calls"] = tools
+
+            if functions is not None:
+                chat_completion_message["function_call"] = functions

             choice = litellm.Choices(
                 finish_reason=candidate.get("finishReason", "stop"),
@@ -1235,7 +1285,7 @@ class VertexLLM(BaseLLM):
         logging_obj,
         stream,
         optional_params: dict,
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         headers={},
         client: Optional[AsyncHTTPHandler] = None,
@@ -1269,6 +1319,7 @@ class VertexLLM(BaseLLM):
             messages=messages,
             print_verbose=print_verbose,
             optional_params=optional_params,
+            litellm_params=litellm_params,
             encoding=encoding,
         )

@@ -1290,7 +1341,7 @@ class VertexLLM(BaseLLM):
         vertex_location: Optional[str],
         vertex_credentials: Optional[str],
         gemini_api_key: Optional[str],
-        litellm_params=None,
+        litellm_params: dict,
         logger_fn=None,
         extra_headers: Optional[dict] = None,
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
@@ -1302,7 +1353,6 @@ class VertexLLM(BaseLLM):
             optional_params=optional_params
         )

-        print_verbose("Incoming Vertex Args - {}".format(locals()))
         auth_header, url = self._get_token_and_url(
             model=model,
             gemini_api_key=gemini_api_key,
@@ -1314,7 +1364,6 @@ class VertexLLM(BaseLLM):
             api_base=api_base,
             should_use_v1beta1_features=should_use_v1beta1_features,
         )
-        print_verbose("Updated URL - {}".format(url))

         ## TRANSFORMATION ##
         try:
@@ -1358,6 +1407,18 @@ class VertexLLM(BaseLLM):
             )
optional_params.pop("response_schema") + # Check for any 'litellm_param_*' set during optional param mapping + + remove_keys = [] + for k, v in optional_params.items(): + if k.startswith("litellm_param_"): + litellm_params.update({k: v}) + remove_keys.append(k) + + optional_params = { + k: v for k, v in optional_params.items() if k not in remove_keys + } + try: content = _gemini_convert_messages_with_history(messages=messages) tools: Optional[Tools] = optional_params.pop("tools", None) @@ -1491,6 +1552,7 @@ class VertexLLM(BaseLLM): model_response=model_response, logging_obj=logging_obj, optional_params=optional_params, + litellm_params=litellm_params, api_key="", data=data, # type: ignore messages=messages, diff --git a/litellm/main.py b/litellm/main.py index 753366754e..86ca1ce8a9 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -2009,7 +2009,7 @@ def completion( model_response=model_response, print_verbose=print_verbose, optional_params=new_params, - litellm_params=litellm_params, + litellm_params=litellm_params, # type: ignore logger_fn=logger_fn, encoding=encoding, vertex_location=vertex_ai_location, @@ -2096,7 +2096,7 @@ def completion( model_response=model_response, print_verbose=print_verbose, optional_params=new_params, - litellm_params=litellm_params, + litellm_params=litellm_params, # type: ignore logger_fn=logger_fn, encoding=encoding, vertex_location=vertex_ai_location, diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 0e96f056bd..d5179da7a9 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -2691,8 +2691,61 @@ def test_completion_hf_model_no_provider(): # test_completion_hf_model_no_provider() -@pytest.mark.skip(reason="anyscale stopped serving public api endpoints") -def test_completion_anyscale_with_functions(): +def gemini_mock_post(*args, **kwargs): + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.headers = {"Content-Type": "application/json"} + mock_response.json = MagicMock( + return_value={ + "candidates": [ + { + "content": { + "parts": [ + { + "functionCall": { + "name": "get_current_weather", + "args": {"location": "Boston, MA"}, + } + } + ], + "role": "model", + }, + "finishReason": "STOP", + "index": 0, + "safetyRatings": [ + { + "category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", + "probability": "NEGLIGIBLE", + }, + { + "category": "HARM_CATEGORY_HARASSMENT", + "probability": "NEGLIGIBLE", + }, + { + "category": "HARM_CATEGORY_HATE_SPEECH", + "probability": "NEGLIGIBLE", + }, + { + "category": "HARM_CATEGORY_DANGEROUS_CONTENT", + "probability": "NEGLIGIBLE", + }, + ], + } + ], + "usageMetadata": { + "promptTokenCount": 86, + "candidatesTokenCount": 19, + "totalTokenCount": 105, + }, + } + ) + + return mock_response + + +@pytest.mark.asyncio +async def test_completion_functions_param(): + litellm.set_verbose = True function1 = [ { "name": "get_current_weather", @@ -2711,18 +2764,33 @@ def test_completion_anyscale_with_functions(): } ] try: - messages = [{"role": "user", "content": "What is the weather like in Boston?"}] - response = completion( - model="anyscale/mistralai/Mistral-7B-Instruct-v0.1", - messages=messages, - functions=function1, - ) - # Add any assertions here to check the response - print(response) + from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler - cost = litellm.completion_cost(completion_response=response) - print("cost to make anyscale completion=", cost) - assert cost > 0.0 + messages = [{"role": "user", "content": "What 
is the weather like in Boston?"}] + + client = AsyncHTTPHandler(concurrent_limit=1) + + with patch.object(client, "post", side_effect=gemini_mock_post) as mock_client: + response: litellm.ModelResponse = await litellm.acompletion( + model="gemini/gemini-1.5-pro", + messages=messages, + functions=function1, + client=client, + ) + print(response) + # Add any assertions here to check the response + mock_client.assert_called() + print(f"mock_client.call_args.kwargs: {mock_client.call_args.kwargs}") + assert "tools" in mock_client.call_args.kwargs["json"] + assert ( + "litellm_param_is_function_call" + not in mock_client.call_args.kwargs["json"] + ) + assert ( + "litellm_param_is_function_call" + not in mock_client.call_args.kwargs["json"]["generationConfig"] + ) + assert response.choices[0].message.function_call is not None except Exception as e: pytest.fail(f"Error occurred: {e}") diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index e310092bd9..7d099afc67 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -755,27 +755,40 @@ async def test_completion_gemini_stream(sync_mode): try: litellm.set_verbose = True print("Streaming gemini response") - messages = [ - {"role": "system", "content": "You are a helpful assistant."}, + function1 = [ { - "role": "user", - "content": "Who was Alexander?", - }, + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } ] + messages = [{"role": "user", "content": "What is the weather like in Boston?"}] print("testing gemini streaming") complete_response = "" # Add any assertions here to check the response non_empty_chunks = 0 - + chunks = [] if sync_mode: response = completion( model="gemini/gemini-1.5-flash", messages=messages, stream=True, + functions=function1, ) for idx, chunk in enumerate(response): print(chunk) + chunks.append(chunk) # print(chunk.choices[0].delta) chunk, finished = streaming_format_tests(idx, chunk) if finished: @@ -787,11 +800,13 @@ async def test_completion_gemini_stream(sync_mode): model="gemini/gemini-1.5-flash", messages=messages, stream=True, + functions=function1, ) idx = 0 async for chunk in response: print(chunk) + chunks.append(chunk) # print(chunk.choices[0].delta) chunk, finished = streaming_format_tests(idx, chunk) if finished: @@ -800,10 +815,17 @@ async def test_completion_gemini_stream(sync_mode): complete_response += chunk idx += 1 - if complete_response.strip() == "": - raise Exception("Empty response received") + # if complete_response.strip() == "": + # raise Exception("Empty response received") print(f"completion_response: {complete_response}") - assert non_empty_chunks > 1 + + complete_response = litellm.stream_chunk_builder( + chunks=chunks, messages=messages + ) + + assert complete_response.choices[0].message.function_call is not None + + # assert non_empty_chunks > 1 except litellm.InternalServerError as e: pass except litellm.RateLimitError as e: diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 8d7520f25c..ce1bd64fa8 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -449,6 +449,7 @@ class ChatCompletionResponseMessage(TypedDict, total=False): content: Optional[str] tool_calls: 
List[ChatCompletionToolCallChunk] role: Literal["assistant"] + function_call: ChatCompletionToolCallFunctionChunk class ChatCompletionUsageBlock(TypedDict): diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py index 5586d4861c..74acd4fec4 100644 --- a/litellm/types/llms/vertex_ai.py +++ b/litellm/types/llms/vertex_ai.py @@ -90,7 +90,7 @@ class Schema(TypedDict, total=False): class FunctionDeclaration(TypedDict, total=False): name: Required[str] description: str - parameters: Schema + parameters: Union[Schema, dict] response: Schema diff --git a/litellm/utils.py b/litellm/utils.py index 81ba950870..3187bb9cb1 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -8779,6 +8779,7 @@ class CustomStreamWrapper: self.chunks: List = ( [] ) # keep track of the returned chunks - used for calculating the input/output tokens for stream options + self.is_function_call = self.check_is_function_call(logging_obj=logging_obj) def __iter__(self): return self @@ -8786,6 +8787,19 @@ class CustomStreamWrapper: def __aiter__(self): return self + def check_is_function_call(self, logging_obj) -> bool: + if hasattr(logging_obj, "optional_params") and isinstance( + logging_obj.optional_params, dict + ): + if ( + "litellm_param_is_function_call" in logging_obj.optional_params + and logging_obj.optional_params["litellm_param_is_function_call"] + is True + ): + return True + + return False + def process_chunk(self, chunk: str): """ NLP Cloud streaming returns the entire response, for each chunk. Process this, to only return the delta. @@ -10283,6 +10297,12 @@ class CustomStreamWrapper: ## CHECK FOR TOOL USE if "tool_calls" in completion_obj and len(completion_obj["tool_calls"]) > 0: + if self.is_function_call is True: # user passed in 'functions' param + completion_obj["function_call"] = completion_obj["tool_calls"][0][ + "function" + ] + completion_obj["tool_calls"] = None + self.tool_call = True ## RETURN ARG @@ -10294,8 +10314,13 @@ class CustomStreamWrapper: ) or ( "tool_calls" in completion_obj + and completion_obj["tool_calls"] is not None and len(completion_obj["tool_calls"]) > 0 ) + or ( + "function_call" in completion_obj + and completion_obj["function_call"] is not None + ) ): # cannot set content of an OpenAI Object to be an empty string self.safety_checker() hold, model_response_str = self.check_special_tokens( @@ -10355,6 +10380,7 @@ class CustomStreamWrapper: if self.sent_first_chunk is False: completion_obj["role"] = "assistant" self.sent_first_chunk = True + model_response.choices[0].delta = Delta(**completion_obj) if completion_obj.get("index") is not None: model_response.choices[0].index = completion_obj.get(