From aa5ee6a626374eb8cc2b3419d3a76b10c09715e2 Mon Sep 17 00:00:00 2001 From: David Manouchehri Date: Mon, 22 Apr 2024 17:00:37 +0000 Subject: [PATCH] improve(vertex_ai.py): Switch to simpler dict type. --- litellm/llms/vertex_ai.py | 137 +++++++------------------------------- 1 file changed, 23 insertions(+), 114 deletions(-) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index fc2d882afd..d5836d52c3 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -21,6 +21,29 @@ class VertexAIError(Exception): self.message ) # Call the base class constructor with the parameters it needs +class ExtendedGenerationConfig(dict): + """Extended parameters for the generation.""" + + def __init__( + self, + *, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + top_k: Optional[int] = None, + candidate_count: Optional[int] = None, + max_output_tokens: Optional[int] = None, + stop_sequences: Optional[List[str]] = None, + response_mime_type: Optional[str] = None, + ): + super().__init__( + temperature=temperature, + top_p=top_p, + top_k=top_k, + candidate_count=candidate_count, + max_output_tokens=max_output_tokens, + stop_sequences=stop_sequences, + response_mime_type=response_mime_type, + ) class VertexAIConfig: """ @@ -363,42 +386,6 @@ def completion( from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore import google.auth # type: ignore - class ExtendedGenerationConfig(GenerationConfig): - """Extended parameters for the generation.""" - - def __init__( - self, - *, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - candidate_count: Optional[int] = None, - max_output_tokens: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, - response_mime_type: Optional[str] = None, - ): - args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig) - - if "response_mime_type" in args_spec.args: - self._raw_generation_config = gapic_content_types.GenerationConfig( - temperature=temperature, - top_p=top_p, - top_k=top_k, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - stop_sequences=stop_sequences, - response_mime_type=response_mime_type, - ) - else: - self._raw_generation_config = gapic_content_types.GenerationConfig( - temperature=temperature, - top_p=top_p, - top_k=top_k, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - stop_sequences=stop_sequences, - ) - ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 print_verbose( f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" @@ -811,45 +798,6 @@ async def async_completion( Add support for acompletion calls for gemini-pro """ try: - from vertexai.preview.generative_models import GenerationConfig - from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore - - class ExtendedGenerationConfig(GenerationConfig): - """Extended parameters for the generation.""" - - def __init__( - self, - *, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - candidate_count: Optional[int] = None, - max_output_tokens: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, - response_mime_type: Optional[str] = None, - ): - args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig) - - if "response_mime_type" in args_spec.args: - self._raw_generation_config = gapic_content_types.GenerationConfig( - temperature=temperature, - top_p=top_p, - top_k=top_k, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - stop_sequences=stop_sequences, - response_mime_type=response_mime_type, - ) - else: - self._raw_generation_config = gapic_content_types.GenerationConfig( - temperature=temperature, - top_p=top_p, - top_k=top_k, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - stop_sequences=stop_sequences, - ) - if mode == "vision": print_verbose("\nMaking VertexAI Gemini Pro Vision Call") print_verbose(f"\nProcessing input messages = {messages}") @@ -1056,45 +1004,6 @@ async def async_streaming( """ Add support for async streaming calls for gemini-pro """ - from vertexai.preview.generative_models import GenerationConfig - from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore - - class ExtendedGenerationConfig(GenerationConfig): - """Extended parameters for the generation.""" - - def __init__( - self, - *, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - top_k: Optional[int] = None, - candidate_count: Optional[int] = None, - max_output_tokens: Optional[int] = None, - stop_sequences: Optional[List[str]] = None, - response_mime_type: Optional[str] = None, - ): - args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig) - - if "response_mime_type" in args_spec.args: - self._raw_generation_config = gapic_content_types.GenerationConfig( - temperature=temperature, - top_p=top_p, - top_k=top_k, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - stop_sequences=stop_sequences, - response_mime_type=response_mime_type, - ) - else: - self._raw_generation_config = gapic_content_types.GenerationConfig( - temperature=temperature, - top_p=top_p, - top_k=top_k, - candidate_count=candidate_count, - max_output_tokens=max_output_tokens, - stop_sequences=stop_sequences, - ) - if mode == "vision": stream = optional_params.pop("stream") tools = optional_params.pop("tools", None)