(feat) - Extreme dirty hack for response_mime_type in Vertex AI.

This commit is contained in:
David Manouchehri 2024-04-11 23:45:41 +00:00
parent d08674bf2f
commit 05350037be
No known key found for this signature in database

View file

@@ -322,15 +322,15 @@ def completion(
                stop_sequences: Optional[List[str]] = None,
                response_mime_type: Optional[str] = None,
            ):
-               super().__init__(
-                   temperature=temperature,
-                   top_p=top_p,
-                   top_k=top_k,
-                   candidate_count=candidate_count,
-                   max_output_tokens=max_output_tokens,
-                   stop_sequences=stop_sequences,
-               )
+               self._raw_generation_config = gapic_content_types.GenerationConfig(
+                   temperature=temperature,
+                   top_p=top_p,
+                   top_k=top_k,
+                   candidate_count=candidate_count,
+                   max_output_tokens=max_output_tokens,
+                   stop_sequences=stop_sequences,
+                   response_mime_type=response_mime_type,
+               )
+               self.response_mime_type = response_mime_type

        ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
        print_verbose(
@@ -735,6 +735,7 @@ async def async_completion(
    """
    try:
        from vertexai.preview.generative_models import GenerationConfig
+       from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore

        class ExtendedGenerationConfig(GenerationConfig):
            """Extended parameters for the generation."""
@@ -750,15 +751,15 @@ async def async_completion(
                stop_sequences: Optional[List[str]] = None,
                response_mime_type: Optional[str] = None,
            ):
-               super().__init__(
-                   temperature=temperature,
-                   top_p=top_p,
-                   top_k=top_k,
-                   candidate_count=candidate_count,
-                   max_output_tokens=max_output_tokens,
-                   stop_sequences=stop_sequences,
-               )
+               self._raw_generation_config = gapic_content_types.GenerationConfig(
+                   temperature=temperature,
+                   top_p=top_p,
+                   top_k=top_k,
+                   candidate_count=candidate_count,
+                   max_output_tokens=max_output_tokens,
+                   stop_sequences=stop_sequences,
+                   response_mime_type=response_mime_type,
+               )
+               self.response_mime_type = response_mime_type

        if mode == "vision":
            print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
@@ -967,6 +968,7 @@ async def async_streaming(
    Add support for async streaming calls for gemini-pro
    """
    from vertexai.preview.generative_models import GenerationConfig
+   from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore

    class ExtendedGenerationConfig(GenerationConfig):
        """Extended parameters for the generation."""
@@ -982,15 +984,15 @@ async def async_streaming(
                stop_sequences: Optional[List[str]] = None,
                response_mime_type: Optional[str] = None,
            ):
-               super().__init__(
-                   temperature=temperature,
-                   top_p=top_p,
-                   top_k=top_k,
-                   candidate_count=candidate_count,
-                   max_output_tokens=max_output_tokens,
-                   stop_sequences=stop_sequences,
-               )
+               self._raw_generation_config = gapic_content_types.GenerationConfig(
+                   temperature=temperature,
+                   top_p=top_p,
+                   top_k=top_k,
+                   candidate_count=candidate_count,
+                   max_output_tokens=max_output_tokens,
+                   stop_sequences=stop_sequences,
+                   response_mime_type=response_mime_type,
+               )
+               self.response_mime_type = response_mime_type

        if mode == "vision":
            stream = optional_params.pop("stream")