From 16941eee43f36f678fcbaff8fcbd224b76df1422 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 21 Jun 2024 09:01:32 -0700
Subject: [PATCH] fix(utils.py): re-integrate separate gemini optional param mapping (google ai studio)

Fixes https://github.com/BerriAI/litellm/issues/4333
---
Reviewer note: this patch was hand-edited after export, so the post-image
blob ids on the `index` lines of vertex_httpx.py, test_optional_params.py
and utils.py predate these edits.

 litellm/__init__.py                   |   2 +-
 litellm/llms/vertex_httpx.py          | 211 ++++++++++++++++++++++++++
 litellm/tests/test_optional_params.py |  15 ++
 litellm/utils.py                      |  14 +-
 4 files changed, 239 insertions(+), 3 deletions(-)

diff --git a/litellm/__init__.py b/litellm/__init__.py
index 43ca23948..a191d46bf 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -783,7 +783,7 @@ from .llms.gemini import GeminiConfig
 from .llms.nlp_cloud import NLPCloudConfig
 from .llms.aleph_alpha import AlephAlphaConfig
 from .llms.petals import PetalsConfig
-from .llms.vertex_httpx import VertexGeminiConfig
+from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
 from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
 from .llms.sagemaker import SagemakerConfig
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index ad7cffd60..67f407b2a 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -48,6 +48,215 @@ from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
 from .base import BaseLLM
 
 
+class GoogleAIStudioGeminiConfig:  # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
+    """
+    Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
+
+    The class `GoogleAIStudioGeminiConfig` provides configuration for the Google AI Studio's Gemini API interface. Below are the parameters:
+
+    - `temperature` (float): This controls the degree of randomness in token selection.
+
+    - `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
+
+    - `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
+
+    - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
+
+    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. Other values - `application/json`.
+
+    - `response_schema` (dict): Optional. Output response schema of the generated candidate text when response mime type can have schema. Schema can be objects, primitives or arrays and is a subset of OpenAPI schema. If set, a compatible response_mime_type must also be set. Compatible mimetypes: application/json: Schema for JSON response.
+
+    - `candidate_count` (int): Number of generated responses to return.
+
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
+    Not supported (key difference from `VertexGeminiConfig`): `frequency_penalty` and `presence_penalty`.
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    temperature: Optional[float] = None
+    max_output_tokens: Optional[int] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    response_mime_type: Optional[str] = None
+    response_schema: Optional[dict] = None
+    candidate_count: Optional[int] = None
+    stop_sequences: Optional[list] = None
+
+    def __init__(
+        self,
+        temperature: Optional[float] = None,
+        max_output_tokens: Optional[int] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        response_mime_type: Optional[str] = None,
+        response_schema: Optional[dict] = None,
+        candidate_count: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key != "self" and value is not None:
+                # NOTE: set on the class, not the instance - shared by all
+                # instances and read back by get_config() via cls.__dict__
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        """Return the class-level config attributes that are not None, as a dict."""
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        """Return the OpenAI param names this config declares as supported."""
+        return [
+            "temperature",
+            "top_p",
+            "max_tokens",
+            "stream",
+            "tools",
+            "tool_choice",
+            "response_format",
+            "n",
+            "stop",
+        ]
+
+    def map_tool_choice_values(
+        self, model: str, tool_choice: Union[str, dict]
+    ) -> Optional[ToolConfig]:
+        """
+        Map an OpenAI `tool_choice` value to a Gemini `ToolConfig`.
+
+        "none" -> mode NONE, "required" -> mode ANY, "auto" -> mode AUTO.
+        A dict -> mode ANY, restricted to the function named in the dict.
+
+        Raises `litellm.utils.UnsupportedParamsError` for any other value.
+        """
+        if tool_choice == "none":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
+        elif tool_choice == "required":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
+        elif tool_choice == "auto":
+            return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
+        elif isinstance(tool_choice, dict):
+            # dict form - restrict calling to tool_choice["function"]["name"]
+            name = tool_choice.get("function", {}).get("name", "")
+            return ToolConfig(
+                functionCallingConfig=FunctionCallingConfig(
+                    mode="ANY", allowed_function_names=[name]
+                )
+            )
+        else:
+            raise litellm.utils.UnsupportedParamsError(
+                message="Gemini (Google AI Studio) doesn't support tool_choice={}. Supported tool_choice values=['none', 'auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True`.".format(
+                    tool_choice
+                ),
+                status_code=400,
+            )
+
+    def map_openai_params(
+        self,
+        model: str,
+        non_default_params: dict,
+        optional_params: dict,
+    ):
+        """
+        Translate OpenAI-style params into their Gemini equivalents.
+
+        Mapped values are written into `optional_params`, which is mutated in
+        place and returned. Params with no mapping below are skipped.
+        """
+        for param, value in non_default_params.items():
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if (
+                param == "stream" and value is True
+            ):  # sending stream = False, can cause it to get passed unchecked and raise issues
+                optional_params["stream"] = value
+            if param == "n":
+                optional_params["candidate_count"] = value
+            if param == "stop":
+                if isinstance(value, str):
+                    optional_params["stop_sequences"] = [value]
+                elif isinstance(value, list):
+                    optional_params["stop_sequences"] = value
+            if param == "max_tokens":
+                optional_params["max_output_tokens"] = value
+            # a non-dict value, or a dict without "type", is skipped instead of raising
+            if (
+                param == "response_format"
+                and isinstance(value, dict)
+                and value.get("type") == "json_object"
+            ):
+                optional_params["response_mime_type"] = "application/json"
+            if param == "tools" and isinstance(value, list):
+                gtool_func_declarations = []
+                for tool in value:
+                    gtool_func_declaration = FunctionDeclaration(
+                        name=tool["function"]["name"],
+                        description=tool["function"].get("description", ""),
+                        parameters=tool["function"].get("parameters", {}),
+                    )
+                    gtool_func_declarations.append(gtool_func_declaration)
+                optional_params["tools"] = [
+                    Tools(function_declarations=gtool_func_declarations)
+                ]
+            if param == "tool_choice" and (
+                isinstance(value, str) or isinstance(value, dict)
+            ):
+                _tool_choice_value = self.map_tool_choice_values(
+                    model=model, tool_choice=value  # type: ignore
+                )
+                if _tool_choice_value is not None:
+                    optional_params["tool_choice"] = _tool_choice_value
+        return optional_params
+
+    def get_mapped_special_auth_params(self) -> dict:
+        """
+        Common auth params across bedrock/vertex_ai/azure/watsonx
+        """
+        return {"project": "vertex_project", "region_name": "vertex_location"}
+
+    def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
+        """Copy mapped auth params into `optional_params` under their target names."""
+        mapped_params = self.get_mapped_special_auth_params()
+
+        for param, value in non_default_params.items():
+            if param in mapped_params:
+                optional_params[mapped_params[param]] = value
+        return optional_params
+
+    def get_flagged_finish_reasons(self) -> Dict[str, str]:
+        """
+        Return Dictionary of finish reasons which indicate response was flagged
+
+        and what it means
+        """
+        return {
+            "SAFETY": "The token generation was stopped as the response was flagged for safety reasons. NOTE: When streaming the Candidate.content will be empty if content filters blocked the output.",
+            "RECITATION": "The token generation was stopped as the response was flagged for unauthorized citations.",
+            "BLOCKLIST": "The token generation was stopped as the response was flagged for the terms which are included from the terminology blocklist.",
+            "PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for the prohibited contents.",
+            "SPII": "The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) contents.",
+        }
+
+
 class VertexGeminiConfig:
     """
     Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
@@ -132,6 +341,8 @@ class VertexGeminiConfig:
             "response_format",
             "n",
             "stop",
+            "frequency_penalty",
+            "presence_penalty",
         ]
 
     def map_tool_choice_values(
diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py
index be14e4feb..a6fa6334b 100644
--- a/litellm/tests/test_optional_params.py
+++ b/litellm/tests/test_optional_params.py
@@ -103,6 +103,21 @@ def test_databricks_optional_params():
     assert "user" not in optional_params
 
 
+def test_gemini_optional_params():
+    """Unsupported OpenAI params are dropped for Google AI Studio (`gemini`)."""
+    litellm.drop_params = True
+    optional_params = get_optional_params(
+        model="",
+        custom_llm_provider="gemini",
+        max_tokens=10,
+        frequency_penalty=10,
+    )
+    print(f"optional_params: {optional_params}")
+    assert len(optional_params) == 1
+    assert optional_params["max_output_tokens"] == 10
+    assert "frequency_penalty" not in optional_params
+
+
 def test_azure_ai_mistral_optional_params():
     litellm.drop_params = True
     optional_params = get_optional_params(
diff --git a/litellm/utils.py b/litellm/utils.py
index eb62204f5..c0529133d 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -2710,6 +2710,16 @@ def get_optional_params(
         print_verbose(
             f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
         )
-    elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
+    elif custom_llm_provider == "gemini":
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.GoogleAIStudioGeminiConfig().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+            model=model,
+        )
+    elif custom_llm_provider == "vertex_ai_beta":
         supported_params = get_supported_openai_params(
             model=model, custom_llm_provider=custom_llm_provider
@@ -3746,7 +3756,7 @@ def get_supported_openai_params(
         elif request_type == "embeddings":
             return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
     elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
-        return litellm.VertexAIConfig().get_supported_openai_params()
+        return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
     elif custom_llm_provider == "vertex_ai":
         if request_type == "chat_completion":
             return litellm.VertexAIConfig().get_supported_openai_params()