diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py
index 9a5e92828..3bd4579e4 100644
--- a/litellm/llms/vertex_ai.py
+++ b/litellm/llms/vertex_ai.py
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx
@@ -308,6 +308,30 @@ def completion(
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
         import google.auth  # type: ignore
 
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                super().__init__(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )
+                self.response_mime_type = response_mime_type
+
         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
@@ -449,7 +473,7 @@ def completion(
 
                 model_response = llm_model.generate_content(
                     contents=content,
-                    generation_config=GenerationConfig(**optional_params),
+                    generation_config=ExtendedGenerationConfig(**optional_params),
                     safety_settings=safety_settings,
                     stream=True,
                     tools=tools,
@@ -471,7 +495,7 @@ def completion(
             ## LLM Call
             response = llm_model.generate_content(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 safety_settings=safety_settings,
                 tools=tools,
             )
@@ -712,6 +736,30 @@ async def async_completion(
     try:
         from vertexai.preview.generative_models import GenerationConfig
 
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                super().__init__(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )
+                self.response_mime_type = response_mime_type
+
         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
@@ -734,7 +782,7 @@ async def async_completion(
             ## LLM Call
             response = await llm_model._generate_content_async(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 tools=tools,
             )
 
@@ -920,6 +968,30 @@ async def async_streaming(
     """
    from vertexai.preview.generative_models import GenerationConfig
 
+    class ExtendedGenerationConfig(GenerationConfig):
+        """Extended parameters for the generation."""
+
+        def __init__(
+            self,
+            *,
+            temperature: Optional[float] = None,
+            top_p: Optional[float] = None,
+            top_k: Optional[int] = None,
+            candidate_count: Optional[int] = None,
+            max_output_tokens: Optional[int] = None,
+            stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
+        ):
+            super().__init__(
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                candidate_count=candidate_count,
+                max_output_tokens=max_output_tokens,
+                stop_sequences=stop_sequences,
+            )
+            self.response_mime_type = response_mime_type
+
     if mode == "vision":
         stream = optional_params.pop("stream")
         tools = optional_params.pop("tools", None)
@@ -940,7 +1012,7 @@ async def async_streaming(
 
         response = await llm_model._generate_content_streaming_async(
             contents=content,
-            generation_config=GenerationConfig(**optional_params),
+            generation_config=ExtendedGenerationConfig(**optional_params),
             tools=tools,
         )
         optional_params["stream"] = True
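
A minimal usage sketch (not part of the diff): assuming the ExtendedGenerationConfig subclass defined in the patch is in scope and that optional_params has already been mapped to Gemini-style keys, the patched call sites build the config as shown below. The parameter values are illustrative only; just the key names come from the patch.

# Hypothetical example values; only the key names appear in the patch.
optional_params = {
    "temperature": 0.2,
    "top_p": 0.95,
    "max_output_tokens": 256,
    "response_mime_type": "application/json",  # the new field carried by the subclass
}

# Keys outside the subclass signature would raise a TypeError, since __init__
# accepts no **kwargs; callers are limited to the parameters listed above.
generation_config = ExtendedGenerationConfig(**optional_params)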