From c426d75e91699d4f242996c38a7a2c02a8fc4bc5 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Wed, 12 Jun 2024 21:11:00 -0700
Subject: [PATCH] fix(vertex_httpx.py): add function calling support to httpx route

---
 litellm/__init__.py                          |   1 +
 litellm/llms/vertex_httpx.py                 | 239 +++++++++++++++++-
 .../tests/test_amazing_vertex_completion.py  |  73 +++++-
 litellm/types/llms/vertex_ai.py              |  32 ++-
 litellm/utils.py                             |  10 +
 log.txt                                      |  10 +
 6 files changed, 345 insertions(+), 20 deletions(-)
 create mode 100644 log.txt

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 19c3bcca6..523ce4684 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -767,6 +767,7 @@ from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
 from .llms.vertex_ai import VertexAIConfig
+from .llms.vertex_httpx import VertexGeminiConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index 550fffe4a..e660e4d72 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -19,6 +19,10 @@ from litellm.types.llms.vertex_ai import (
     PartType,
     RequestBody,
     GenerateContentResponseBody,
+    FunctionCallingConfig,
+    FunctionDeclaration,
+    Tools,
+    ToolConfig,
 )
 from litellm.llms.vertex_ai import _gemini_convert_messages_with_history
 from litellm.types.utils import GenericStreamingChunk
@@ -26,18 +30,203 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
     ChatCompletionToolCallChunk,
     ChatCompletionToolCallFunctionChunk,
+    ChatCompletionResponseMessage,
 )
 
 
 class VertexGeminiConfig:
-    def __init__(self) -> None:
-        pass
+    """
+    Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
 
-    def supports_system_message(self) -> bool:
+    The class `VertexGeminiConfig` provides configuration for VertexAI's Gemini API interface. Below are the parameters:
+
+    - `temperature` (float): This controls the degree of randomness in token selection.
+
+    - `max_output_tokens` (integer): This sets the limit on the maximum number of tokens in the text output. In this case, the default value is 256.
+
+    - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
+
+    - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
+
+    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+
+    - `candidate_count` (int): Number of generated responses to return.
+
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
+    - `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
+
+    - `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    temperature: Optional[float] = None
+    max_output_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    response_mime_type: Optional[str] = None
+    candidate_count: Optional[int] = None
+    stop_sequences: Optional[list] = None
+    frequency_penalty: Optional[float] = None
+    presence_penalty: Optional[float] = None
+
+    def __init__(
+        self,
+        temperature: Optional[float] = None,
+        max_output_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        response_mime_type: Optional[str] = None,
+        candidate_count: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
+        frequency_penalty: Optional[float] = None,
+        presence_penalty: Optional[float] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
+            "tools",
+            "tool_choice",
+            "response_format",
+            "n",
+            "stop",
+        ]
+
+    def map_tool_choice_values(
+        self, model: str, tool_choice: Union[str, dict]
+    ) -> Optional[ToolConfig]:
+        if tool_choice == "none":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
+        elif tool_choice == "required":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
+        elif tool_choice == "auto":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
+        elif isinstance(tool_choice, dict):
+            # specific function requested - restrict the model to calling that function
+            name = tool_choice.get("function", {}).get("name", "")
+            return ToolConfig(
+                functionCallingConfig=FunctionCallingConfig(
+                    mode="ANY", allowed_function_names=[name]
+                )
+            )
+        else:
+            raise litellm.utils.UnsupportedParamsError(
+                message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True`.".format(
+                    tool_choice
+                ),
+                status_code=400,
+            )
+
+    def map_openai_params(
+        self,
+        model: str,
+        non_default_params: dict,
+        optional_params: dict,
+    ):
+        for param, value in non_default_params.items():
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if (
+                param == "stream" and value is True
+            ):  # sending stream = False, can cause it to get passed unchecked and raise issues
+                optional_params["stream"] = value
+            if param == "n":
+                optional_params["candidate_count"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    optional_params["stop_sequences"] = [value]
+                elif isinstance(value, list):
+                    optional_params["stop_sequences"] = value
+            if param == "max_tokens":
+                optional_params["max_output_tokens"] = value
+            if param == "response_format" and value["type"] == "json_object":  # type: ignore
+                optional_params["response_mime_type"] = "application/json"
+            if param == "frequency_penalty":
+                optional_params["frequency_penalty"] = value
+            if param == "presence_penalty":
+                optional_params["presence_penalty"] = value
+            if param == "tools" and isinstance(value, list):
+                gtool_func_declarations = []
+                for tool in value:
+                    gtool_func_declaration = FunctionDeclaration(
+                        name=tool["function"]["name"],
+                        description=tool["function"].get("description", ""),
+                        parameters=tool["function"].get("parameters", {}),
+                    )
+                    gtool_func_declarations.append(gtool_func_declaration)
+                optional_params["tools"] = [
+                    Tools(function_declarations=gtool_func_declarations)
+                ]
+            if param == "tool_choice" and (
+                isinstance(value, str) or isinstance(value, dict)
+            ):
+                _tool_choice_value = self.map_tool_choice_values(
+                    model=model, tool_choice=value  # type: ignore
+                )
+                if _tool_choice_value is not None:
+                    optional_params["tool_choice"] = _tool_choice_value
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
         """
-        Not all gemini models support system instructions
+        Common auth params across bedrock/vertex_ai/azure/watsonx
         """
-        return True
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+    def get_eu_regions(self) -> List[str]:
+        """
+        Source: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#available-regions
+        """
+        return [
+            "europe-central2",
+            "europe-north1",
+            "europe-southwest1",
+            "europe-west1",
+            "europe-west2",
+            "europe-west3",
+            "europe-west4",
+            "europe-west6",
+            "europe-west8",
+            "europe-west9",
+        ]
 
 
 async def make_call(
@@ -165,21 +354,37 @@ class VertexLLM(BaseLLM):
         ## GET MODEL ##
         model_response.model = model
         ## GET TEXT ##
+        chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"}
+        content_str = ""
+        tools: List[ChatCompletionToolCallChunk] = []
         for idx, candidate in enumerate(completion_response["candidates"]):
-            if candidate.get("content", None) is None:
+            if "content" not in candidate:
                 continue
-            message = litellm.Message(
-                content=candidate["content"]["parts"][0]["text"],
-                role="assistant",
-                logprobs=None,
-                function_call=None,
-                tool_calls=None,
-            )
+            if "text" in candidate["content"]["parts"][0]:
+                content_str = candidate["content"]["parts"][0]["text"]
+
+            if "functionCall" in candidate["content"]["parts"][0]:
+                _function_chunk = ChatCompletionToolCallFunctionChunk(
+                    name=candidate["content"]["parts"][0]["functionCall"]["name"],
+                    arguments=json.dumps(
+                        candidate["content"]["parts"][0]["functionCall"]["args"]
+                    ),
+                )
+                _tool_response_chunk = ChatCompletionToolCallChunk(
+                    id=f"call_{str(uuid.uuid4())}",
+                    type="function",
+                    function=_function_chunk,
+                )
+                tools.append(_tool_response_chunk)
+
+            chat_completion_message["content"] = content_str
+            chat_completion_message["tool_calls"] = tools
+
             choice = litellm.Choices(
                 finish_reason=candidate.get("finishReason", "stop"),
                 index=candidate.get("index", idx),
-                message=message,
+                message=chat_completion_message,  # type: ignore
                 logprobs=None,
                 enhancements=None,
             )
 
@@ -402,8 +607,14 @@ class VertexLLM(BaseLLM):
                 messages.pop(idx)
         system_instructions = SystemInstructions(parts=system_content_blocks)
         content = _gemini_convert_messages_with_history(messages=messages)
+        tools: Optional[Tools] = optional_params.pop("tools", None)
+        tool_choice: Optional[ToolConfig] = optional_params.pop("tool_choice", None)
 
         data = RequestBody(system_instruction=system_instructions, contents=content)
+        if tools is not None:
+            data["tools"] = tools
+        if tool_choice is not None:
+            data["toolConfig"] = tool_choice
         headers = {
             "Content-Type": "application/json; charset=utf-8",
             "Authorization": f"Bearer {auth_header}",
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 7f0b49808..3037f51e6 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -615,9 +615,76 @@ def test_gemini_pro_vision_base64():
         pytest.fail(f"An exception occurred - {str(e)}")
 
 
-@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("provider", ["vertex_ai_beta"])  # "vertex_ai",
+@pytest.mark.parametrize("sync_mode", [True, False])
 @pytest.mark.asyncio
-async def test_gemini_pro_function_calling(sync_mode):
+async def test_gemini_pro_function_calling_httpx(provider, sync_mode):
+    try:
+        load_vertex_ai_credentials()
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+        ]
+
+        tools = [
+            {
+                "type": "function",
+                "function": {
+                    "name": "get_weather",
+                    "description": "Get the current weather in a given location",
+                    "parameters": {
+                        "type": "object",
+                        "properties": {
+                            "location": {
+                                "type": "string",
+                                "description": "The city and state, e.g. San Francisco, CA",
+                            }
+                        },
+                        "required": ["location"],
+                    },
+                },
+            }
+        ]
+
+        data = {
+            "model": "{}/gemini-1.5-pro-preview-0514".format(provider),
+            "messages": messages,
+            "tools": tools,
+            "tool_choice": "required",
+        }
+        if sync_mode:
+            response = litellm.completion(**data)
+        else:
+            response = await litellm.acompletion(**data)
+
+        print(f"response: {response}")
+
+        assert response.choices[0].message.tool_calls[0].function.arguments is not None
+        assert isinstance(
+            response.choices[0].message.tool_calls[0].function.arguments, str
+        )
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        if "429 Quota exceeded" in str(e):
+            pass
+        else:
+            pytest.fail("An unexpected exception occurred - {}".format(str(e)))
+
+
+@pytest.mark.parametrize("sync_mode", [True, False])
+@pytest.mark.parametrize("provider", ["vertex_ai"])
+@pytest.mark.asyncio
+async def test_gemini_pro_function_calling(provider, sync_mode):
     try:
         load_vertex_ai_credentials()
         litellm.set_verbose = True
@@ -679,7 +746,7 @@ async def test_gemini_pro_function_calling(sync_mode):
         ]
 
         data = {
-            "model": "vertex_ai/gemini-1.5-pro-preview-0514",
+            "model": "{}/gemini-1.5-pro-preview-0514".format(provider),
             "messages": messages,
             "tools": tools,
         }
diff --git a/litellm/types/llms/vertex_ai.py b/litellm/types/llms/vertex_ai.py
index fe903841e..18207b88e 100644
--- a/litellm/types/llms/vertex_ai.py
+++ b/litellm/types/llms/vertex_ai.py
@@ -49,6 +49,24 @@ class PartType(TypedDict, total=False):
     function_response: FunctionResponse
 
 
+class HttpxFunctionCall(TypedDict):
+    name: str
+    args: dict
+
+
+class HttpxPartType(TypedDict, total=False):
+    text: str
+    inline_data: BlobType
+    file_data: FileDataType
+    functionCall: HttpxFunctionCall
+    function_response: FunctionResponse
+
+
+class HttpxContentType(TypedDict, total=False):
+    role: Literal["user", "model"]
+    parts: Required[List[HttpxPartType]]
+
+
 class ContentType(TypedDict, total=False):
     role: Literal["user", "model"]
     parts: Required[List[PartType]]
@@ -128,11 +146,19 @@ class GenerationConfig(TypedDict, total=False):
     response_mime_type: Literal["text/plain", "application/json"]
 
 
+class Tools(TypedDict):
+    function_declarations: List[FunctionDeclaration]
+
+
+class ToolConfig(TypedDict):
+    functionCallingConfig: FunctionCallingConfig
+
+
 class RequestBody(TypedDict, total=False):
     contents: Required[List[ContentType]]
     system_instruction: SystemInstructions
-    tools: FunctionDeclaration
-    tool_config: FunctionCallingConfig
+    tools: Tools
+    toolConfig: ToolConfig
     safety_settings: SafetSettingsConfig
     generation_config: GenerationConfig
 
@@ -176,7 +202,7 @@ class GroundingMetadata(TypedDict, total=False):
 
 class Candidates(TypedDict, total=False):
     index: int
-    content: ContentType
+    content: HttpxContentType
     finishReason: Literal[
         "FINISH_REASON_UNSPECIFIED",
         "STOP",
diff --git a/litellm/utils.py b/litellm/utils.py
index f132e3202..cfec3fd4a 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5386,6 +5386,16 @@ def get_optional_params(
         print_verbose(
             f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
         )
+    elif custom_llm_provider == "vertex_ai_beta":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexGeminiConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
     elif (
         custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
     ):
diff --git a/log.txt b/log.txt
new file mode 100644
index 000000000..9f7660563
--- /dev/null
+++ b/log.txt
@@ -0,0 +1,10 @@
+============================= test session starts ==============================
+platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0 -- /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/bin/python3.11
+cachedir: .pytest_cache
+rootdir: /Users/krrishdholakia/Documents/litellm
+configfile: pyproject.toml
+plugins: logfire-0.35.0, asyncio-0.23.6, mock-3.14.0, anyio-4.2.0
+asyncio: mode=Mode.STRICT
+collecting ... collected 0 items
+
+============================ no tests ran in 0.00s =============================
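
Usage sketch for the change above: a minimal example of the new `vertex_ai_beta` function calling path, mirroring test_gemini_pro_function_calling_httpx. It assumes Vertex credentials are already configured (e.g. via GOOGLE_APPLICATION_CREDENTIALS); the model name and tool schema are illustrative, not prescribed by the patch.

    import json
    import litellm

    tools = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",
                "description": "Get the current weather in a given location",
                "parameters": {
                    "type": "object",
                    "properties": {"location": {"type": "string"}},
                    "required": ["location"],
                },
            },
        }
    ]

    # A "vertex_ai_beta/..." model routes through the httpx handler; VertexGeminiConfig maps
    # tools -> Tools(function_declarations=...) and tool_choice="required" -> mode="ANY".
    response = litellm.completion(
        model="vertex_ai_beta/gemini-1.5-pro-preview-0514",
        messages=[{"role": "user", "content": "What is the weather in San Francisco?"}],
        tools=tools,
        tool_choice="required",
    )

    # The Gemini functionCall part comes back as an OpenAI-style tool call; `arguments`
    # is a JSON string (json.dumps of the functionCall args), so decode it before use.
    tool_call = response.choices[0].message.tool_calls[0]
    print(tool_call.function.name)
    print(json.loads(tool_call.function.arguments))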