Merge pull request #2964 from Manouchehri/gemini-json-mode-2962

Add JSON mode to Gemini (Vertex AI)
Krish Dholakia 2024-04-11 17:51:27 -07:00 committed by GitHub
commit cd834e9d52
3 changed files with 105 additions and 6 deletions
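
For orientation before the diffs: after this change, JSON mode on Vertex AI Gemini is driven through litellm's OpenAI-style response_format argument. A minimal usage sketch, assuming Vertex AI credentials are already configured (e.g. via GOOGLE_APPLICATION_CREDENTIALS) and that gemini-pro is enabled in your project:

import litellm

response = litellm.completion(
    model="vertex_ai/gemini-pro",  # any Gemini model served through Vertex AI
    messages=[{"role": "user", "content": "Return three colors as a JSON object."}],
    response_format={"type": "json_object"},  # mapped below to response_mime_type="application/json"
)
print(response.choices[0].message.content)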

View file

@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx
@@ -25,6 +25,7 @@ class VertexAIError(Exception):
 class VertexAIConfig:
     """
     Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference

     The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
@@ -36,6 +37,12 @@ class VertexAIConfig:
     - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
+    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+    - `candidate_count` (int): Number of generated responses to return.
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.

     Note: Please make sure to modify the default parameters as required for your use case.
     """
@@ -43,6 +50,9 @@ class VertexAIConfig:
     max_output_tokens: Optional[int] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    response_mime_type: Optional[str] = None
+    candidate_count: Optional[int] = None
+    stop_sequences: Optional[list] = None

     def __init__(
         self,
@@ -50,6 +60,9 @@ class VertexAIConfig:
         max_output_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
+        response_mime_type: Optional[str] = None,
+        candidate_count: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
@@ -295,6 +308,30 @@ def completion(
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
         import google.auth  # type: ignore

+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
+
         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
@@ -436,7 +473,7 @@ def completion(
             model_response = llm_model.generate_content(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 safety_settings=safety_settings,
                 stream=True,
                 tools=tools,
@@ -458,7 +495,7 @@ def completion(
             ## LLM Call
             response = llm_model.generate_content(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 safety_settings=safety_settings,
                 tools=tools,
             )
@@ -698,6 +735,31 @@ async def async_completion(
     """
     try:
         from vertexai.preview.generative_models import GenerationConfig
+        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
+
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
+
         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
@@ -721,7 +783,7 @@ async def async_completion(
             ## LLM Call
             response = await llm_model._generate_content_async(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 tools=tools,
             )
@@ -906,6 +968,31 @@ async def async_streaming(
     Add support for async streaming calls for gemini-pro
     """
     from vertexai.preview.generative_models import GenerationConfig
+    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
+
+    class ExtendedGenerationConfig(GenerationConfig):
+        """Extended parameters for the generation."""
+
+        def __init__(
+            self,
+            *,
+            temperature: Optional[float] = None,
+            top_p: Optional[float] = None,
+            top_k: Optional[int] = None,
+            candidate_count: Optional[int] = None,
+            max_output_tokens: Optional[int] = None,
+            stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
+        ):
+            self._raw_generation_config = gapic_content_types.GenerationConfig(
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                candidate_count=candidate_count,
+                max_output_tokens=max_output_tokens,
+                stop_sequences=stop_sequences,
+                response_mime_type=response_mime_type,
+            )
+
     if mode == "vision":
         stream = optional_params.pop("stream")
@@ -927,7 +1014,7 @@ async def async_streaming(
         response = await llm_model._generate_content_streaming_async(
             contents=content,
-            generation_config=GenerationConfig(**optional_params),
+            generation_config=ExtendedGenerationConfig(**optional_params),
             tools=tools,
         )
         optional_params["stream"] = True

View file

@@ -4840,8 +4840,17 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stream:
             optional_params["stream"] = stream
+        if n is not None:
+            optional_params["candidate_count"] = n
+        if stop is not None:
+            if isinstance(stop, str):
+                optional_params["stop_sequences"] = [stop]
+            elif isinstance(stop, list):
+                optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
+        if response_format is not None and response_format["type"] == "json_object":
+            optional_params["response_mime_type"] = "application/json"
         if tools is not None and isinstance(tools, list):
             from vertexai.preview import generative_models
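
The hunk above translates OpenAI-style arguments into Vertex AI generation parameters. A standalone, hypothetical restatement of that mapping for clarity (the helper name is illustrative and not part of litellm):

def translate_vertex_params(n=None, stop=None, max_tokens=None, response_format=None):
    params = {}
    if n is not None:
        params["candidate_count"] = n
    if stop is not None:
        # A bare string is wrapped in a list; a list passes through unchanged.
        params["stop_sequences"] = [stop] if isinstance(stop, str) else stop
    if max_tokens is not None:
        params["max_output_tokens"] = max_tokens
    if response_format is not None and response_format["type"] == "json_object":
        params["response_mime_type"] = "application/json"
    return params

assert translate_vertex_params(n=2, stop="END", response_format={"type": "json_object"}) == {
    "candidate_count": 2,
    "stop_sequences": ["END"],
    "response_mime_type": "application/json",
}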
@@ -5528,6 +5537,9 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "stream",
             "tools",
             "tool_choice",
+            "response_format",
+            "n",
+            "stop",
         ]
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]

View file

@@ -14,7 +14,7 @@ pandas==2.1.1 # for viewing clickhouse spend analytics
 prisma==0.11.0 # for db
 mangum==0.17.0 # for aws lambda functions
 pynacl==1.5.0 # for encrypting keys
-google-cloud-aiplatform==1.43.0 # for vertex ai calls
+google-cloud-aiplatform==1.47.0 # for vertex ai calls
 anthropic[vertex]==0.21.3
 google-generativeai==0.3.2 # for vertex ai calls
 async_generator==1.10.0 # for async ollama calls