forked from phoenix/litellm-mirror
fix(utils.py): re-integrate separate gemini optional param mapping (google ai studio)
Fixes https://github.com/BerriAI/litellm/issues/4333
This commit is contained in:
parent
d2b4d1f7ed
commit
16941eee43
4 changed files with 214 additions and 2 deletions
|
@ -783,7 +783,7 @@ from .llms.gemini import GeminiConfig
|
|||
from .llms.nlp_cloud import NLPCloudConfig
|
||||
from .llms.aleph_alpha import AlephAlphaConfig
|
||||
from .llms.petals import PetalsConfig
|
||||
from .llms.vertex_httpx import VertexGeminiConfig
|
||||
from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
|
||||
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
|
||||
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
|
||||
from .llms.sagemaker import SagemakerConfig
|
||||
|
|
|
@ -48,6 +48,193 @@ from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
|||
from .base import BaseLLM
|
||||
|
||||
|
||||
class GoogleAIStudioGeminiConfig: # key diff from VertexAI - 'frequency_penalty' and 'presence_penalty' not supported
|
||||
"""
|
||||
Reference: https://ai.google.dev/api/rest/v1beta/GenerationConfig
|
||||
|
||||
The class `GoogleAIStudioGeminiConfig` provides configuration for the Google AI Studio's Gemini API interface. Below are the parameters:
|
||||
|
||||
- `temperature` (float): This controls the degree of randomness in token selection.
|
||||
|
||||
- `max_output_tokens` (integer): This sets the limitation for the maximum amount of token in the text output. In this case, the default value is 256.
|
||||
|
||||
- `top_p` (float): The tokens are selected from the most probable to the least probable until the sum of their probabilities equals the `top_p` value. Default is 0.95.
|
||||
|
||||
- `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
|
||||
|
||||
- `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'. Other values - `application/json`.
|
||||
|
||||
- `response_schema` (dict): Optional. Output response schema of the generated candidate text when response mime type can have schema. Schema can be objects, primitives or arrays and is a subset of OpenAPI schema. If set, a compatible response_mime_type must also be set. Compatible mimetypes: application/json: Schema for JSON response.
|
||||
|
||||
- `candidate_count` (int): Number of generated responses to return.
|
||||
|
||||
- `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
|
||||
|
||||
- `frequency_penalty` (float): This parameter is used to penalize the model from repeating the same output. The default value is 0.0.
|
||||
|
||||
- `presence_penalty` (float): This parameter is used to penalize the model from generating the same output as the input. The default value is 0.0.
|
||||
|
||||
Note: Please make sure to modify the default parameters as required for your use case.
|
||||
"""
|
||||
|
||||
temperature: Optional[float] = None
|
||||
max_output_tokens: Optional[int] = None
|
||||
top_p: Optional[float] = None
|
||||
top_k: Optional[int] = None
|
||||
response_mime_type: Optional[str] = None
|
||||
response_schema: Optional[dict] = None
|
||||
candidate_count: Optional[int] = None
|
||||
stop_sequences: Optional[list] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
temperature: Optional[float] = None,
|
||||
max_output_tokens: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
top_k: Optional[int] = None,
|
||||
response_mime_type: Optional[str] = None,
|
||||
response_schema: Optional[dict] = None,
|
||||
candidate_count: Optional[int] = None,
|
||||
stop_sequences: Optional[list] = None,
|
||||
) -> None:
|
||||
locals_ = locals()
|
||||
for key, value in locals_.items():
|
||||
if key != "self" and value is not None:
|
||||
setattr(self.__class__, key, value)
|
||||
|
||||
@classmethod
|
||||
def get_config(cls):
|
||||
return {
|
||||
k: v
|
||||
for k, v in cls.__dict__.items()
|
||||
if not k.startswith("__")
|
||||
and not isinstance(
|
||||
v,
|
||||
(
|
||||
types.FunctionType,
|
||||
types.BuiltinFunctionType,
|
||||
classmethod,
|
||||
staticmethod,
|
||||
),
|
||||
)
|
||||
and v is not None
|
||||
}
|
||||
|
||||
def get_supported_openai_params(self):
|
||||
return [
|
||||
"temperature",
|
||||
"top_p",
|
||||
"max_tokens",
|
||||
"stream",
|
||||
"tools",
|
||||
"tool_choice",
|
||||
"response_format",
|
||||
"n",
|
||||
"stop",
|
||||
]
|
||||
|
||||
def map_tool_choice_values(
|
||||
self, model: str, tool_choice: Union[str, dict]
|
||||
) -> Optional[ToolConfig]:
|
||||
if tool_choice == "none":
|
||||
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="NONE"))
|
||||
elif tool_choice == "required":
|
||||
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="ANY"))
|
||||
elif tool_choice == "auto":
|
||||
return ToolConfig(functionCallingConfig=FunctionCallingConfig(mode="AUTO"))
|
||||
elif isinstance(tool_choice, dict):
|
||||
# only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html
|
||||
name = tool_choice.get("function", {}).get("name", "")
|
||||
return ToolConfig(
|
||||
functionCallingConfig=FunctionCallingConfig(
|
||||
mode="ANY", allowed_function_names=[name]
|
||||
)
|
||||
)
|
||||
else:
|
||||
raise litellm.utils.UnsupportedParamsError(
|
||||
message="VertexAI doesn't support tool_choice={}. Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format(
|
||||
tool_choice
|
||||
),
|
||||
status_code=400,
|
||||
)
|
||||
|
||||
def map_openai_params(
|
||||
self,
|
||||
model: str,
|
||||
non_default_params: dict,
|
||||
optional_params: dict,
|
||||
):
|
||||
for param, value in non_default_params.items():
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if (
|
||||
param == "stream" and value is True
|
||||
): # sending stream = False, can cause it to get passed unchecked and raise issues
|
||||
optional_params["stream"] = value
|
||||
if param == "n":
|
||||
optional_params["candidate_count"] = value
|
||||
if param == "stop":
|
||||
if isinstance(value, str):
|
||||
optional_params["stop_sequences"] = [value]
|
||||
elif isinstance(value, list):
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "max_tokens":
|
||||
optional_params["max_output_tokens"] = value
|
||||
if param == "response_format" and value["type"] == "json_object": # type: ignore
|
||||
optional_params["response_mime_type"] = "application/json"
|
||||
if param == "tools" and isinstance(value, list):
|
||||
gtool_func_declarations = []
|
||||
for tool in value:
|
||||
gtool_func_declaration = FunctionDeclaration(
|
||||
name=tool["function"]["name"],
|
||||
description=tool["function"].get("description", ""),
|
||||
parameters=tool["function"].get("parameters", {}),
|
||||
)
|
||||
gtool_func_declarations.append(gtool_func_declaration)
|
||||
optional_params["tools"] = [
|
||||
Tools(function_declarations=gtool_func_declarations)
|
||||
]
|
||||
if param == "tool_choice" and (
|
||||
isinstance(value, str) or isinstance(value, dict)
|
||||
):
|
||||
_tool_choice_value = self.map_tool_choice_values(
|
||||
model=model, tool_choice=value # type: ignore
|
||||
)
|
||||
if _tool_choice_value is not None:
|
||||
optional_params["tool_choice"] = _tool_choice_value
|
||||
return optional_params
|
||||
|
||||
def get_mapped_special_auth_params(self) -> dict:
|
||||
"""
|
||||
Common auth params across bedrock/vertex_ai/azure/watsonx
|
||||
"""
|
||||
return {"project": "vertex_project", "region_name": "vertex_location"}
|
||||
|
||||
def map_special_auth_params(self, non_default_params: dict, optional_params: dict):
|
||||
mapped_params = self.get_mapped_special_auth_params()
|
||||
|
||||
for param, value in non_default_params.items():
|
||||
if param in mapped_params:
|
||||
optional_params[mapped_params[param]] = value
|
||||
return optional_params
|
||||
|
||||
def get_flagged_finish_reasons(self) -> Dict[str, str]:
|
||||
"""
|
||||
Return Dictionary of finish reasons which indicate response was flagged
|
||||
|
||||
and what it means
|
||||
"""
|
||||
return {
|
||||
"SAFETY": "The token generation was stopped as the response was flagged for safety reasons. NOTE: When streaming the Candidate.content will be empty if content filters blocked the output.",
|
||||
"RECITATION": "The token generation was stopped as the response was flagged for unauthorized citations.",
|
||||
"BLOCKLIST": "The token generation was stopped as the response was flagged for the terms which are included from the terminology blocklist.",
|
||||
"PROHIBITED_CONTENT": "The token generation was stopped as the response was flagged for the prohibited contents.",
|
||||
"SPII": "The token generation was stopped as the response was flagged for Sensitive Personally Identifiable Information (SPII) contents.",
|
||||
}
|
||||
|
||||
|
||||
class VertexGeminiConfig:
|
||||
"""
|
||||
Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
|
||||
|
@ -132,6 +319,8 @@ class VertexGeminiConfig:
|
|||
"response_format",
|
||||
"n",
|
||||
"stop",
|
||||
"frequency_penalty",
|
||||
"presence_penalty",
|
||||
]
|
||||
|
||||
def map_tool_choice_values(
|
||||
|
|
|
@ -103,6 +103,19 @@ def test_databricks_optional_params():
|
|||
assert "user" not in optional_params
|
||||
|
||||
|
||||
def test_gemini_optional_params():
|
||||
litellm.drop_params = True
|
||||
optional_params = get_optional_params(
|
||||
model="",
|
||||
custom_llm_provider="gemini",
|
||||
max_tokens=10,
|
||||
frequency_penalty=10,
|
||||
)
|
||||
print(f"optional_params: {optional_params}")
|
||||
assert len(optional_params) == 1
|
||||
assert "frequency_penalty" not in optional_params
|
||||
|
||||
|
||||
def test_azure_ai_mistral_optional_params():
|
||||
litellm.drop_params = True
|
||||
optional_params = get_optional_params(
|
||||
|
|
|
@ -2710,6 +2710,16 @@ def get_optional_params(
|
|||
print_verbose(
|
||||
f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
|
||||
)
|
||||
elif custom_llm_provider == "gemini":
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
_check_valid_arg(supported_params=supported_params)
|
||||
optional_params = litellm.GoogleAIStudioGeminiConfig().map_openai_params(
|
||||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
)
|
||||
elif custom_llm_provider == "vertex_ai_beta" or custom_llm_provider == "gemini":
|
||||
supported_params = get_supported_openai_params(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
|
@ -3746,7 +3756,7 @@ def get_supported_openai_params(
|
|||
elif request_type == "embeddings":
|
||||
return litellm.DatabricksEmbeddingConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "palm" or custom_llm_provider == "gemini":
|
||||
return litellm.VertexAIConfig().get_supported_openai_params()
|
||||
return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
|
||||
elif custom_llm_provider == "vertex_ai":
|
||||
if request_type == "chat_completion":
|
||||
return litellm.VertexAIConfig().get_supported_openai_params()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue