mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-27 19:54:13 +00:00
improve(vertex_ai.py): Switch to simpler dict type.
This commit is contained in:
parent 6b730214dd
commit aa5ee6a626

1 changed file with 23 additions and 114 deletions
vertex_ai.py

@@ -21,6 +21,29 @@ class VertexAIError(Exception):
             self.message
         )  # Call the base class constructor with the parameters it needs
 
 
+class ExtendedGenerationConfig(dict):
+    """Extended parameters for the generation."""
+
+    def __init__(
+        self,
+        *,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        top_k: Optional[int] = None,
+        candidate_count: Optional[int] = None,
+        max_output_tokens: Optional[int] = None,
+        stop_sequences: Optional[List[str]] = None,
+        response_mime_type: Optional[str] = None,
+    ):
+        super().__init__(
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            candidate_count=candidate_count,
+            max_output_tokens=max_output_tokens,
+            stop_sequences=stop_sequences,
+            response_mime_type=response_mime_type,
+        )
+
+
 class VertexAIConfig:
     """
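For reference, a quick self-contained check of how the new class behaves (the class body below is copied from the hunk above; the assertions are illustrative, not part of the diff). Because it subclasses `dict`, every keyword argument lands in the mapping, including parameters left at `None`:

```python
from typing import List, Optional


class ExtendedGenerationConfig(dict):
    """Extended parameters for the generation."""

    def __init__(
        self,
        *,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
        top_k: Optional[int] = None,
        candidate_count: Optional[int] = None,
        max_output_tokens: Optional[int] = None,
        stop_sequences: Optional[List[str]] = None,
        response_mime_type: Optional[str] = None,
    ):
        # dict.__init__ accepts keyword arguments, so all parameters
        # become plain mapping entries.
        super().__init__(
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            candidate_count=candidate_count,
            max_output_tokens=max_output_tokens,
            stop_sequences=stop_sequences,
            response_mime_type=response_mime_type,
        )


config = ExtendedGenerationConfig(temperature=0.2, max_output_tokens=256)
assert isinstance(config, dict)   # plain mapping, no SDK base class
assert config["temperature"] == 0.2
assert config["top_p"] is None    # unset parameters are present as None
```

Since no vertexai or gapic types are involved, the class can live at module level and be shared by the sync, async, and streaming paths, which is what makes the three identical nested copies deleted below redundant.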
@@ -363,42 +386,6 @@ def completion(
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
         import google.auth  # type: ignore
 
-        class ExtendedGenerationConfig(GenerationConfig):
-            """Extended parameters for the generation."""
-
-            def __init__(
-                self,
-                *,
-                temperature: Optional[float] = None,
-                top_p: Optional[float] = None,
-                top_k: Optional[int] = None,
-                candidate_count: Optional[int] = None,
-                max_output_tokens: Optional[int] = None,
-                stop_sequences: Optional[List[str]] = None,
-                response_mime_type: Optional[str] = None,
-            ):
-                args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
-
-                if "response_mime_type" in args_spec.args:
-                    self._raw_generation_config = gapic_content_types.GenerationConfig(
-                        temperature=temperature,
-                        top_p=top_p,
-                        top_k=top_k,
-                        candidate_count=candidate_count,
-                        max_output_tokens=max_output_tokens,
-                        stop_sequences=stop_sequences,
-                        response_mime_type=response_mime_type,
-                    )
-                else:
-                    self._raw_generation_config = gapic_content_types.GenerationConfig(
-                        temperature=temperature,
-                        top_p=top_p,
-                        top_k=top_k,
-                        candidate_count=candidate_count,
-                        max_output_tokens=max_output_tokens,
-                        stop_sequences=stop_sequences,
-                    )
-
         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
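The deleted block feature-detected `response_mime_type` support in the installed SDK via `inspect.getfullargspec` before constructing the gapic config. A minimal sketch of that version-gating pattern, using a hypothetical stand-in constructor rather than the real `gapic_content_types.GenerationConfig`:

```python
import inspect


def build_config(ctor, **kwargs):
    # Mirror the deleted args_spec check: only forward response_mime_type
    # if the installed constructor actually accepts it.
    args_spec = inspect.getfullargspec(ctor)
    if "response_mime_type" not in args_spec.args:
        kwargs.pop("response_mime_type", None)
    return ctor(**kwargs)


class OldGenerationConfig:  # stand-in for an SDK build without response_mime_type
    def __init__(self, temperature=None, top_p=None):
        self.temperature = temperature
        self.top_p = top_p


cfg = build_config(
    OldGenerationConfig, temperature=0.1, response_mime_type="application/json"
)
assert not hasattr(cfg, "response_mime_type")  # silently dropped on old SDKs
```

The dict-based replacement avoids the runtime introspection entirely; how unsupported keys get filtered is then left to the call sites.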
@@ -811,45 +798,6 @@ async def async_completion(
     Add support for acompletion calls for gemini-pro
     """
     try:
-        from vertexai.preview.generative_models import GenerationConfig
-        from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
-
-        class ExtendedGenerationConfig(GenerationConfig):
-            """Extended parameters for the generation."""
-
-            def __init__(
-                self,
-                *,
-                temperature: Optional[float] = None,
-                top_p: Optional[float] = None,
-                top_k: Optional[int] = None,
-                candidate_count: Optional[int] = None,
-                max_output_tokens: Optional[int] = None,
-                stop_sequences: Optional[List[str]] = None,
-                response_mime_type: Optional[str] = None,
-            ):
-                args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
-
-                if "response_mime_type" in args_spec.args:
-                    self._raw_generation_config = gapic_content_types.GenerationConfig(
-                        temperature=temperature,
-                        top_p=top_p,
-                        top_k=top_k,
-                        candidate_count=candidate_count,
-                        max_output_tokens=max_output_tokens,
-                        stop_sequences=stop_sequences,
-                        response_mime_type=response_mime_type,
-                    )
-                else:
-                    self._raw_generation_config = gapic_content_types.GenerationConfig(
-                        temperature=temperature,
-                        top_p=top_p,
-                        top_k=top_k,
-                        candidate_count=candidate_count,
-                        max_output_tokens=max_output_tokens,
-                        stop_sequences=stop_sequences,
-                    )
-
         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
@@ -1056,45 +1004,6 @@ async def async_streaming(
     """
     Add support for async streaming calls for gemini-pro
     """
-    from vertexai.preview.generative_models import GenerationConfig
-    from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
-
-    class ExtendedGenerationConfig(GenerationConfig):
-        """Extended parameters for the generation."""
-
-        def __init__(
-            self,
-            *,
-            temperature: Optional[float] = None,
-            top_p: Optional[float] = None,
-            top_k: Optional[int] = None,
-            candidate_count: Optional[int] = None,
-            max_output_tokens: Optional[int] = None,
-            stop_sequences: Optional[List[str]] = None,
-            response_mime_type: Optional[str] = None,
-        ):
-            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
-
-            if "response_mime_type" in args_spec.args:
-                self._raw_generation_config = gapic_content_types.GenerationConfig(
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    candidate_count=candidate_count,
-                    max_output_tokens=max_output_tokens,
-                    stop_sequences=stop_sequences,
-                    response_mime_type=response_mime_type,
-                )
-            else:
-                self._raw_generation_config = gapic_content_types.GenerationConfig(
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    candidate_count=candidate_count,
-                    max_output_tokens=max_output_tokens,
-                    stop_sequences=stop_sequences,
-                )
-
     if mode == "vision":
         stream = optional_params.pop("stream")
         tools = optional_params.pop("tools", None)
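With all three nested copies removed, call sites can construct the single module-level dict per request. A hedged usage sketch, assuming the installed `vertexai` preview SDK accepts a plain dict for `generation_config` (recent versions type it as `Union[GenerationConfig, Dict[str, Any]]`); the import path, model name, and prompt are illustrative assumptions, not taken from this diff:

```python
from vertexai.preview.generative_models import GenerativeModel

from litellm.llms.vertex_ai import ExtendedGenerationConfig  # module path assumed

model = GenerativeModel("gemini-pro")
response = model.generate_content(
    "Say hello.",
    # The dict subclass added in this commit stands in for a GenerationConfig.
    generation_config=ExtendedGenerationConfig(
        temperature=0.2,
        max_output_tokens=256,
    ),
)
print(response.text)
```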