Mirror of https://github.com/BerriAI/litellm.git
fix(vertex_ai.py): check if 'response_mime_type' in generation config before passing it in
This commit is contained in:
parent c377ba0755
commit 77d6b882b8

1 changed file with 64 additions and 28 deletions
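The fix guards against older google-cloud-aiplatform releases whose GenerationConfig does not yet accept a response_mime_type keyword: the constructor is introspected with inspect.getfullargspec and the parameter is only forwarded when it is actually present. A minimal, self-contained sketch of that feature-detection pattern (the two stand-in functions below are hypothetical, not part of the SDK; inspect.signature is shown instead of getfullargspec because it also covers keyword-only parameters):

    import inspect

    def supports_param(func, param_name: str) -> bool:
        # True if `func` accepts a parameter with this name.
        try:
            return param_name in inspect.signature(func).parameters
        except (TypeError, ValueError):
            # Some builtins/extension types are not introspectable.
            return False

    # Hypothetical stand-ins for an old and a new SDK surface:
    def old_config(temperature=None, top_p=None):
        return {"temperature": temperature, "top_p": top_p}

    def new_config(temperature=None, top_p=None, response_mime_type=None):
        return {"temperature": temperature, "response_mime_type": response_mime_type}

    assert not supports_param(old_config, "response_mime_type")
    assert supports_param(new_config, "response_mime_type")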
litellm/llms/vertex_ai.py

@@ -6,7 +6,7 @@ import time
 from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
-import httpx
+import httpx, inspect


 class VertexAIError(Exception):
@@ -322,6 +322,9 @@ def completion(
             stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
         ):
+            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+            if "response_mime_type" in args_spec.args:
                 self._raw_generation_config = gapic_content_types.GenerationConfig(
                     temperature=temperature,
                     top_p=top_p,
@@ -331,6 +334,15 @@ def completion(
                     stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
                 )
+            else:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )

     ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
     print_verbose(
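The guarded call above duplicates the full keyword list for the supported and unsupported cases. A sketch of a shorter equivalent (not what the commit ships) that builds the keyword set once and adds the optional field conditionally; it assumes the same surrounding scope as the diff, i.e. args_spec, gapic_content_types, and the sampling parameters are already in scope:

    config_kwargs = dict(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        candidate_count=candidate_count,
        max_output_tokens=max_output_tokens,
        stop_sequences=stop_sequences,
    )
    # Forward the optional field only when this SDK version accepts it.
    if "response_mime_type" in args_spec.args:
        config_kwargs["response_mime_type"] = response_mime_type
    self._raw_generation_config = gapic_content_types.GenerationConfig(**config_kwargs)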
@@ -751,6 +763,9 @@ async def async_completion(
             stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
         ):
+            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+            if "response_mime_type" in args_spec.args:
                 self._raw_generation_config = gapic_content_types.GenerationConfig(
                     temperature=temperature,
                     top_p=top_p,
@@ -760,6 +775,15 @@ async def async_completion(
                     stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
                 )
+            else:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )

     if mode == "vision":
         print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
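The same argspec probe is repeated inside completion, async_completion, and async_streaming (below), since each function carries its own copy of the wrapper. Because an installed SDK's constructor signature cannot change at runtime, the check could instead run once at import time; a sketch, assuming the same gapic_content_types import used in the diff:

    import inspect

    # Evaluate once at module import; reuse everywhere the wrapper is built.
    _SUPPORTS_RESPONSE_MIME_TYPE = (
        "response_mime_type"
        in inspect.getfullargspec(gapic_content_types.GenerationConfig).args
    )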
@@ -984,6 +1008,9 @@ async def async_streaming(
             stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
         ):
+            args_spec = inspect.getfullargspec(gapic_content_types.GenerationConfig)
+
+            if "response_mime_type" in args_spec.args:
                 self._raw_generation_config = gapic_content_types.GenerationConfig(
                     temperature=temperature,
                     top_p=top_p,
@@ -993,6 +1020,15 @@ async def async_streaming(
                     stop_sequences=stop_sequences,
+                    response_mime_type=response_mime_type,
                 )
+            else:
+                self._raw_generation_config = gapic_content_types.GenerationConfig(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )

     if mode == "vision":
         stream = optional_params.pop("stream")
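Caller-facing effect, as a hedged sketch: with a new enough google-cloud-aiplatform, response_mime_type reaches Vertex AI and can force JSON output, while on older SDKs the guard above silently omits the field instead of raising TypeError. The response_format-to-response_mime_type translation shown here is an assumption about litellm's parameter mapping for Gemini models, not something this diff itself shows:

    import litellm

    # Hypothetical usage: request JSON output from a Vertex AI Gemini model.
    # Assumes litellm maps OpenAI-style response_format to Vertex's
    # response_mime_type="application/json".
    response = litellm.completion(
        model="vertex_ai/gemini-pro",
        messages=[{"role": "user", "content": "Return three colors as a JSON list."}],
        response_format={"type": "json_object"},
    )
    print(response.choices[0].message.content)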