Merge pull request #2964 from Manouchehri/gemini-json-mode-2962
Add JSON mode to Gemini (Vertex AI)
commit cd834e9d52
3 changed files with 105 additions and 6 deletions
litellm/llms/vertex_ai.py

```diff
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx
```
```diff
@@ -25,6 +25,7 @@ class VertexAIError(Exception):
 class VertexAIConfig:
     """
     Reference: https://cloud.google.com/vertex-ai/docs/generative-ai/chat/test-chat-prompts
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference
 
     The class `VertexAIConfig` provides configuration for the VertexAI's API interface. Below are the parameters:
 
@@ -36,6 +37,12 @@ class VertexAIConfig:
 
     - `top_k` (integer): The value of `top_k` determines how many of the most probable tokens are considered in the selection. For example, a `top_k` of 1 means the selected token is the most probable among all tokens. The default value is 40.
 
+    - `response_mime_type` (str): The MIME type of the response. The default value is 'text/plain'.
+
+    - `candidate_count` (int): Number of generated responses to return.
+
+    - `stop_sequences` (List[str]): The set of character sequences (up to 5) that will stop output generation. If specified, the API will stop at the first appearance of a stop sequence. The stop sequence will not be included as part of the response.
+
     Note: Please make sure to modify the default parameters as required for your use case.
     """
 
@@ -43,6 +50,9 @@ class VertexAIConfig:
     max_output_tokens: Optional[int] = None
     top_p: Optional[float] = None
     top_k: Optional[int] = None
+    response_mime_type: Optional[str] = None
+    candidate_count: Optional[int] = None
+    stop_sequences: Optional[list] = None
 
     def __init__(
         self,
@@ -50,6 +60,9 @@ class VertexAIConfig:
         max_output_tokens: Optional[int] = None,
         top_p: Optional[float] = None,
         top_k: Optional[int] = None,
+        response_mime_type: Optional[str] = None,
+        candidate_count: Optional[int] = None,
+        stop_sequences: Optional[list] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
```
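For context, the three new fields follow litellm's usual config-class pattern, where attributes set on a provider config become default parameters for subsequent calls. A minimal sketch, assuming `litellm.VertexAIConfig` is exported at the package level like litellm's other provider configs (values are illustrative):

```python
import litellm

# Set the new fields from this PR as provider-wide defaults for Vertex AI calls.
litellm.VertexAIConfig(
    response_mime_type="application/json",  # request JSON output by default
    candidate_count=1,                      # return a single candidate
    stop_sequences=["STOP"],                # halt generation at "STOP"
)
```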
```diff
@@ -295,6 +308,30 @@ def completion(
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
         import google.auth  # type: ignore
 
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
+
         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
```
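Why the subclass: the pinned `vertexai` SDK's `GenerationConfig.__init__` apparently did not yet accept `response_mime_type`, so `ExtendedGenerationConfig` skips the wrapper's constructor and builds the raw gapic proto directly (this reading is inferred from the code, not stated in the PR). A quick sketch, assuming the class defined above is in scope:

```python
# Values are illustrative; the extra field rides along in the raw proto
# that the SDK later reads from _raw_generation_config.
config = ExtendedGenerationConfig(
    temperature=0.0,
    max_output_tokens=256,
    response_mime_type="application/json",
)
print(config._raw_generation_config.response_mime_type)  # application/json
```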
```diff
@@ -436,7 +473,7 @@ def completion(
 
             model_response = llm_model.generate_content(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 safety_settings=safety_settings,
                 stream=True,
                 tools=tools,
@@ -458,7 +495,7 @@ def completion(
             ## LLM Call
             response = llm_model.generate_content(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 safety_settings=safety_settings,
                 tools=tools,
             )
```
```diff
@@ -698,6 +735,31 @@ async def async_completion(
     """
     try:
         from vertexai.preview.generative_models import GenerationConfig
+        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
+
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
 
         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
```
|
@ -721,7 +783,7 @@ async def async_completion(
|
||||||
## LLM Call
|
## LLM Call
|
||||||
response = await llm_model._generate_content_async(
|
response = await llm_model._generate_content_async(
|
||||||
contents=content,
|
contents=content,
|
||||||
generation_config=GenerationConfig(**optional_params),
|
generation_config=ExtendedGenerationConfig(**optional_params),
|
||||||
tools=tools,
|
tools=tools,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
```diff
@@ -906,6 +968,31 @@ async def async_streaming(
     Add support for async streaming calls for gemini-pro
     """
     from vertexai.preview.generative_models import GenerationConfig
+    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
+
+    class ExtendedGenerationConfig(GenerationConfig):
+        """Extended parameters for the generation."""
+
+        def __init__(
+            self,
+            *,
+            temperature: Optional[float] = None,
+            top_p: Optional[float] = None,
+            top_k: Optional[int] = None,
+            candidate_count: Optional[int] = None,
+            max_output_tokens: Optional[int] = None,
+            stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
+        ):
+            self._raw_generation_config = gapic_content_types.GenerationConfig(
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                candidate_count=candidate_count,
+                max_output_tokens=max_output_tokens,
+                stop_sequences=stop_sequences,
+                response_mime_type=response_mime_type,
+            )
 
     if mode == "vision":
         stream = optional_params.pop("stream")
```
```diff
@@ -927,7 +1014,7 @@ async def async_streaming(
 
         response = await llm_model._generate_content_streaming_async(
             contents=content,
-            generation_config=GenerationConfig(**optional_params),
+            generation_config=ExtendedGenerationConfig(**optional_params),
             tools=tools,
         )
         optional_params["stream"] = True
```
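With the vertex_ai.py changes in place, JSON mode is requested through the OpenAI-style `response_format` parameter. A hedged end-to-end sketch (model name and prompt are illustrative, and valid Vertex AI credentials are assumed):

```python
import litellm

# response_format={"type": "json_object"} is translated to
# response_mime_type="application/json" for Gemini on Vertex AI.
response = litellm.completion(
    model="vertex_ai/gemini-pro",
    messages=[{"role": "user", "content": "List three primes as a JSON array."}],
    response_format={"type": "json_object"},
)
print(response.choices[0].message.content)
```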
litellm/utils.py

```diff
@@ -4840,8 +4840,17 @@ def get_optional_params(
             optional_params["top_p"] = top_p
         if stream:
             optional_params["stream"] = stream
+        if n is not None:
+            optional_params["candidate_count"] = n
+        if stop is not None:
+            if isinstance(stop, str):
+                optional_params["stop_sequences"] = [stop]
+            elif isinstance(stop, list):
+                optional_params["stop_sequences"] = stop
         if max_tokens is not None:
             optional_params["max_output_tokens"] = max_tokens
+        if response_format is not None and response_format["type"] == "json_object":
+            optional_params["response_mime_type"] = "application/json"
         if tools is not None and isinstance(tools, list):
             from vertexai.preview import generative_models
 
```
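The net effect of this hunk is a small OpenAI-to-Gemini parameter translation. A self-contained sketch of the same logic with hypothetical inputs:

```python
# Mirror of the mapping above: OpenAI-style kwargs on the left,
# Gemini GenerationConfig keys on the right.
openai_style = {
    "n": 2,
    "stop": "END",
    "max_tokens": 256,
    "response_format": {"type": "json_object"},
}

optional_params = {}
if openai_style.get("n") is not None:
    optional_params["candidate_count"] = openai_style["n"]
stop = openai_style.get("stop")
if isinstance(stop, str):
    optional_params["stop_sequences"] = [stop]
elif isinstance(stop, list):
    optional_params["stop_sequences"] = stop
if openai_style.get("max_tokens") is not None:
    optional_params["max_output_tokens"] = openai_style["max_tokens"]
response_format = openai_style.get("response_format")
if response_format is not None and response_format["type"] == "json_object":
    optional_params["response_mime_type"] = "application/json"

print(optional_params)
# {'candidate_count': 2, 'stop_sequences': ['END'],
#  'max_output_tokens': 256, 'response_mime_type': 'application/json'}
```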
```diff
@@ -5528,6 +5537,9 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
             "stream",
             "tools",
             "tool_choice",
+            "response_format",
+            "n",
+            "stop",
         ]
     elif custom_llm_provider == "sagemaker":
         return ["stream", "temperature", "max_tokens", "top_p", "stop", "n"]
```
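A quick check that the provider now advertises the new parameters (a sketch; assumes `get_supported_openai_params` is importable from `litellm.utils`, where this hunk lives):

```python
from litellm.utils import get_supported_openai_params

params = get_supported_openai_params(
    model="gemini-pro", custom_llm_provider="vertex_ai"
)
# "response_format", "n", and "stop" join the previously supported params.
assert {"response_format", "n", "stop"} <= set(params)
```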
requirements.txt

```diff
@@ -14,7 +14,7 @@ pandas==2.1.1 # for viewing clickhouse spend analytics
 prisma==0.11.0 # for db
 mangum==0.17.0 # for aws lambda functions
 pynacl==1.5.0 # for encrypting keys
-google-cloud-aiplatform==1.43.0 # for vertex ai calls
+google-cloud-aiplatform==1.47.0 # for vertex ai calls
 anthropic[vertex]==0.21.3
 google-generativeai==0.3.2 # for vertex ai calls
 async_generator==1.10.0 # for async ollama calls
```
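The pin bump is presumably needed because `response_mime_type` only exists on the raw `GenerationConfig` proto in newer SDK releases (an assumption; the PR does not say). A sketch of how to verify that the installed SDK carries the field:

```python
from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types

# On SDK versions whose proto lacks the field, proto-plus raises ValueError
# ("unknown field") at construction time.
config = gapic_content_types.GenerationConfig(response_mime_type="application/json")
print(config.response_mime_type)  # application/json
```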