fix(vertex_ai.py): check if 'response_mime_type' in generation config before passing it in

Krrish Dholakia 2024-04-11 23:10:59 -07:00
parent c377ba0755
commit 77d6b882b8

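For context: the new guard introspects the installed SDK's GenerationConfig constructor at runtime, so older google-cloud-aiplatform releases that predate response_mime_type keep working instead of failing on an unexpected keyword argument. Below is a minimal, self-contained sketch of that feature-detection pattern, generalized to filter out any unsupported keyword; LegacyConfig and build_config are hypothetical stand-ins for illustration, not real Vertex AI or litellm code.

import inspect


class LegacyConfig:
    """Hypothetical stand-in for an older gapic_content_types.GenerationConfig
    that predates the response_mime_type parameter."""

    def __init__(self, temperature=None, max_output_tokens=None):
        self.temperature = temperature
        self.max_output_tokens = max_output_tokens


def build_config(cls, **params):
    # getfullargspec on a class reports the named parameters of its
    # __init__, so membership in .args tells us whether the installed
    # version accepts a given keyword. (Caveat: parameters accepted
    # only through **kwargs never appear in .args.)
    args_spec = inspect.getfullargspec(cls)
    supported = {k: v for k, v in params.items() if k in args_spec.args}
    return cls(**supported)


config = build_config(
    LegacyConfig,
    temperature=0.2,
    max_output_tokens=256,
    response_mime_type="application/json",  # dropped: LegacyConfig predates it
)
print(vars(config))  # {'temperature': 0.2, 'max_output_tokens': 256}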

@@ -6,7 +6,7 @@ import time
 from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
-import httpx
+import httpx, inspect


 class VertexAIError(Exception):
@@ -322,15 +322,27 @@ def completion(
             stop_sequences: Optional[List[str]] = None,
             response_mime_type: Optional[str] = None,
         ):
-            self._raw_generation_config = gapic_content_types.GenerationConfig(
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                candidate_count=candidate_count,
-                max_output_tokens=max_output_tokens,
-                stop_sequences=stop_sequences,
-                response_mime_type=response_mime_type,
-            )
+            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+            if "response_mime_type" in args_spec.args:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
+            else:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )

     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
     print_verbose(
@@ -751,15 +763,27 @@ async def async_completion(
             stop_sequences: Optional[List[str]] = None,
             response_mime_type: Optional[str] = None,
         ):
-            self._raw_generation_config = gapic_content_types.GenerationConfig(
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                candidate_count=candidate_count,
-                max_output_tokens=max_output_tokens,
-                stop_sequences=stop_sequences,
-                response_mime_type=response_mime_type,
-            )
+            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+            if "response_mime_type" in args_spec.args:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
+            else:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )

     if mode == "vision":
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
@@ -984,15 +1008,27 @@ async def async_streaming(
             stop_sequences: Optional[List[str]] = None,
             response_mime_type: Optional[str] = None,
         ):
-            self._raw_generation_config = gapic_content_types.GenerationConfig(
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
-                candidate_count=candidate_count,
-                max_output_tokens=max_output_tokens,
-                stop_sequences=stop_sequences,
-                response_mime_type=response_mime_type,
-            )
+            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+            if "response_mime_type" in args_spec.args:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
+                )
+            else:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )

     if mode == "vision":
         stream = optional_params.pop("stream")