forked from phoenix/litellm-mirror
(feat) - Dirty hack to get response_mime_type working before it's released in the Python SDK.
This commit is contained in:
parent 649c3bb0dd
commit d08674bf2f

1 changed file with 77 additions and 5 deletions
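For reference, a minimal caller-side sketch of what this hack is meant to enable. It assumes litellm forwards the extra response_mime_type keyword into the vertex_ai provider's optional_params (which the patch below unpacks into ExtendedGenerationConfig); the model name and prompt are illustrative only, not taken from the commit.

    import litellm

    # Ask Gemini on Vertex AI to return JSON. response_mime_type is the field the
    # SDK's GenerationConfig does not yet accept, hence the subclass in this patch.
    response = litellm.completion(
        model="vertex_ai/gemini-pro",
        messages=[{"role": "user", "content": "Reply with a JSON object describing today's weather."}],
        response_mime_type="application/json",  # assumption: forwarded into optional_params
    )
    print(response.choices[0].message.content)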
@@ -3,7 +3,7 @@ import json
 from enum import Enum
 import requests
 import time
-from typing import Callable, Optional, Union
+from typing import Callable, Optional, Union, List
 from litellm.utils import ModelResponse, Usage, CustomStreamWrapper, map_finish_reason
 import litellm, uuid
 import httpx
@@ -308,6 +308,30 @@ def completion(
         from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types  # type: ignore
         import google.auth  # type: ignore
 
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                super().__init__(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )
+                self.response_mime_type = response_mime_type
+
         ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744
         print_verbose(
             f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
@@ -449,7 +473,7 @@ def completion(
 
                 model_response = llm_model.generate_content(
                     contents=content,
-                    generation_config=GenerationConfig(**optional_params),
+                    generation_config=ExtendedGenerationConfig(**optional_params),
                     safety_settings=safety_settings,
                     stream=True,
                     tools=tools,
@@ -471,7 +495,7 @@ def completion(
             ## LLM Call
             response = llm_model.generate_content(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 safety_settings=safety_settings,
                 tools=tools,
             )
@@ -712,6 +736,30 @@ async def async_completion(
     try:
         from vertexai.preview.generative_models import GenerationConfig
 
+        class ExtendedGenerationConfig(GenerationConfig):
+            """Extended parameters for the generation."""
+
+            def __init__(
+                self,
+                *,
+                temperature: Optional[float] = None,
+                top_p: Optional[float] = None,
+                top_k: Optional[int] = None,
+                candidate_count: Optional[int] = None,
+                max_output_tokens: Optional[int] = None,
+                stop_sequences: Optional[List[str]] = None,
+                response_mime_type: Optional[str] = None,
+            ):
+                super().__init__(
+                    temperature=temperature,
+                    top_p=top_p,
+                    top_k=top_k,
+                    candidate_count=candidate_count,
+                    max_output_tokens=max_output_tokens,
+                    stop_sequences=stop_sequences,
+                )
+                self.response_mime_type = response_mime_type
+
         if mode == "vision":
             print_verbose("\nMaking VertexAI Gemini Pro Vision Call")
             print_verbose(f"\nProcessing input messages = {messages}")
@@ -734,7 +782,7 @@ async def async_completion(
             ## LLM Call
             response = await llm_model._generate_content_async(
                 contents=content,
-                generation_config=GenerationConfig(**optional_params),
+                generation_config=ExtendedGenerationConfig(**optional_params),
                 tools=tools,
             )
 
@@ -920,6 +968,30 @@ async def async_streaming(
     """
     from vertexai.preview.generative_models import GenerationConfig
 
+    class ExtendedGenerationConfig(GenerationConfig):
+        """Extended parameters for the generation."""
+
+        def __init__(
+            self,
+            *,
+            temperature: Optional[float] = None,
+            top_p: Optional[float] = None,
+            top_k: Optional[int] = None,
+            candidate_count: Optional[int] = None,
+            max_output_tokens: Optional[int] = None,
+            stop_sequences: Optional[List[str]] = None,
+            response_mime_type: Optional[str] = None,
+        ):
+            super().__init__(
+                temperature=temperature,
+                top_p=top_p,
+                top_k=top_k,
+                candidate_count=candidate_count,
+                max_output_tokens=max_output_tokens,
+                stop_sequences=stop_sequences,
+            )
+            self.response_mime_type = response_mime_type
+
     if mode == "vision":
         stream = optional_params.pop("stream")
         tools = optional_params.pop("tools", None)
@@ -940,7 +1012,7 @@ async def async_streaming(
 
         response = await llm_model._generate_content_streaming_async(
            contents=content,
-           generation_config=GenerationConfig(**optional_params),
+           generation_config=ExtendedGenerationConfig(**optional_params),
            tools=tools,
        )
        optional_params["stream"] = True