mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
vertex ai anthropic thinking param support (#8853)
* fix(vertex_llm_base.py): handle credentials passed in as dictionary
* fix(router.py): support vertex credentials as json dict
* test(test_vertex.py): allows easier testing mock anthropic thinking response for vertex ai
* test(vertex_ai_partner_models/): don't remove "@" from model breaks anthropic cost calculation
* test: move testing
* fix: fix linting error
* fix: fix linting error
* fix(vertex_ai_partner_models/main.py): split @ for codestral model
* test: fix test
* fix: fix stripping "@" on mistral models
* fix: fix test
* test: fix test
This commit is contained in:
parent
992e78dfd8
commit
88eedb22b9
15 changed files with 135 additions and 45 deletions
2
.gitignore
vendored
2
.gitignore
vendored
|
@ -77,3 +77,5 @@ litellm/proxy/_experimental/out/404.html
|
||||||
litellm/proxy/_experimental/out/model_hub.html
|
litellm/proxy/_experimental/out/model_hub.html
|
||||||
.mypy_cache/*
|
.mypy_cache/*
|
||||||
litellm/proxy/application.log
|
litellm/proxy/application.log
|
||||||
|
tests/llm_translation/vertex_test_account.json
|
||||||
|
tests/llm_translation/test_vertex_key.json
|
||||||
|
|
|
@ -10,7 +10,10 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
)
|
)
|
||||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||||
from litellm.types.llms.openai import Batch, CreateBatchRequest
|
from litellm.types.llms.openai import Batch, CreateBatchRequest
|
||||||
from litellm.types.llms.vertex_ai import VertexAIBatchPredictionJob
|
from litellm.types.llms.vertex_ai import (
|
||||||
|
VERTEX_CREDENTIALS_TYPES,
|
||||||
|
VertexAIBatchPredictionJob,
|
||||||
|
)
|
||||||
|
|
||||||
from .transformation import VertexAIBatchTransformation
|
from .transformation import VertexAIBatchTransformation
|
||||||
|
|
||||||
|
@ -25,7 +28,7 @@ class VertexAIBatchPrediction(VertexLLM):
|
||||||
_is_async: bool,
|
_is_async: bool,
|
||||||
create_batch_data: CreateBatchRequest,
|
create_batch_data: CreateBatchRequest,
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
@ -130,7 +133,7 @@ class VertexAIBatchPrediction(VertexLLM):
|
||||||
_is_async: bool,
|
_is_async: bool,
|
||||||
batch_id: str,
|
batch_id: str,
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
|
|
@ -9,6 +9,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket_base import (
|
||||||
)
|
)
|
||||||
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
|
||||||
from litellm.types.llms.openai import CreateFileRequest, FileObject
|
from litellm.types.llms.openai import CreateFileRequest, FileObject
|
||||||
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
from .transformation import VertexAIFilesTransformation
|
from .transformation import VertexAIFilesTransformation
|
||||||
|
|
||||||
|
@ -34,7 +35,7 @@ class VertexAIFilesHandler(GCSBucketBase):
|
||||||
self,
|
self,
|
||||||
create_file_data: CreateFileRequest,
|
create_file_data: CreateFileRequest,
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
@ -70,7 +71,7 @@ class VertexAIFilesHandler(GCSBucketBase):
|
||||||
_is_async: bool,
|
_is_async: bool,
|
||||||
create_file_data: CreateFileRequest,
|
create_file_data: CreateFileRequest,
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
|
|
|
@ -13,6 +13,7 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import Ver
|
||||||
from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters
|
from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters
|
||||||
from litellm.types.llms.openai import FineTuningJobCreate
|
from litellm.types.llms.openai import FineTuningJobCreate
|
||||||
from litellm.types.llms.vertex_ai import (
|
from litellm.types.llms.vertex_ai import (
|
||||||
|
VERTEX_CREDENTIALS_TYPES,
|
||||||
FineTuneHyperparameters,
|
FineTuneHyperparameters,
|
||||||
FineTuneJobCreate,
|
FineTuneJobCreate,
|
||||||
FineTunesupervisedTuningSpec,
|
FineTunesupervisedTuningSpec,
|
||||||
|
@ -222,7 +223,7 @@ class VertexFineTuningAPI(VertexLLM):
|
||||||
create_fine_tuning_job_data: FineTuningJobCreate,
|
create_fine_tuning_job_data: FineTuningJobCreate,
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
kwargs: Optional[dict] = None,
|
kwargs: Optional[dict] = None,
|
||||||
|
|
|
@ -40,6 +40,7 @@ from litellm.types.llms.openai import (
|
||||||
ChatCompletionUsageBlock,
|
ChatCompletionUsageBlock,
|
||||||
)
|
)
|
||||||
from litellm.types.llms.vertex_ai import (
|
from litellm.types.llms.vertex_ai import (
|
||||||
|
VERTEX_CREDENTIALS_TYPES,
|
||||||
Candidates,
|
Candidates,
|
||||||
ContentType,
|
ContentType,
|
||||||
FunctionCallingConfig,
|
FunctionCallingConfig,
|
||||||
|
@ -930,7 +931,7 @@ class VertexLLM(VertexBase):
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
vertex_project: Optional[str] = None,
|
vertex_project: Optional[str] = None,
|
||||||
vertex_location: Optional[str] = None,
|
vertex_location: Optional[str] = None,
|
||||||
vertex_credentials: Optional[str] = None,
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
|
||||||
gemini_api_key: Optional[str] = None,
|
gemini_api_key: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
) -> CustomStreamWrapper:
|
) -> CustomStreamWrapper:
|
||||||
|
@ -1018,7 +1019,7 @@ class VertexLLM(VertexBase):
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
vertex_project: Optional[str] = None,
|
vertex_project: Optional[str] = None,
|
||||||
vertex_location: Optional[str] = None,
|
vertex_location: Optional[str] = None,
|
||||||
vertex_credentials: Optional[str] = None,
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
|
||||||
gemini_api_key: Optional[str] = None,
|
gemini_api_key: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||||
|
@ -1123,7 +1124,7 @@ class VertexLLM(VertexBase):
|
||||||
timeout: Optional[Union[float, httpx.Timeout]],
|
timeout: Optional[Union[float, httpx.Timeout]],
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
gemini_api_key: Optional[str],
|
gemini_api_key: Optional[str],
|
||||||
litellm_params: dict,
|
litellm_params: dict,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
|
|
@ -11,6 +11,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
get_async_httpx_client,
|
get_async_httpx_client,
|
||||||
)
|
)
|
||||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||||
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
from litellm.types.utils import ImageResponse
|
from litellm.types.utils import ImageResponse
|
||||||
|
|
||||||
|
|
||||||
|
@ -44,7 +45,7 @@ class VertexImageGeneration(VertexLLM):
|
||||||
prompt: str,
|
prompt: str,
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
model_response: ImageResponse,
|
model_response: ImageResponse,
|
||||||
logging_obj: Any,
|
logging_obj: Any,
|
||||||
model: Optional[
|
model: Optional[
|
||||||
|
@ -139,7 +140,7 @@ class VertexImageGeneration(VertexLLM):
|
||||||
prompt: str,
|
prompt: str,
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
model_response: litellm.ImageResponse,
|
model_response: litellm.ImageResponse,
|
||||||
logging_obj: Any,
|
logging_obj: Any,
|
||||||
model: Optional[
|
model: Optional[
|
||||||
|
|
|
@ -9,6 +9,7 @@ from litellm.llms.custom_httpx.http_handler import (
|
||||||
)
|
)
|
||||||
from litellm.llms.openai.openai import HttpxBinaryResponseContent
|
from litellm.llms.openai.openai import HttpxBinaryResponseContent
|
||||||
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
|
||||||
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
|
|
||||||
class VertexInput(TypedDict, total=False):
|
class VertexInput(TypedDict, total=False):
|
||||||
|
@ -45,7 +46,7 @@ class VertexTextToSpeechAPI(VertexLLM):
|
||||||
logging_obj,
|
logging_obj,
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
timeout: Union[float, httpx.Timeout],
|
timeout: Union[float, httpx.Timeout],
|
||||||
model: str,
|
model: str,
|
||||||
|
|
|
@ -160,7 +160,8 @@ class VertexAIPartnerModels(VertexBase):
|
||||||
url=default_api_base,
|
url=default_api_base,
|
||||||
)
|
)
|
||||||
|
|
||||||
model = model.split("@")[0]
|
if "codestral" in model or "mistral" in model:
|
||||||
|
model = model.split("@")[0]
|
||||||
|
|
||||||
if "codestral" in model and litellm_params.get("text_completion") is True:
|
if "codestral" in model and litellm_params.get("text_completion") is True:
|
||||||
optional_params["model"] = model
|
optional_params["model"] = model
|
||||||
|
|
|
@ -41,7 +41,7 @@ class VertexEmbedding(VertexBase):
|
||||||
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
|
||||||
vertex_project: Optional[str] = None,
|
vertex_project: Optional[str] = None,
|
||||||
vertex_location: Optional[str] = None,
|
vertex_location: Optional[str] = None,
|
||||||
vertex_credentials: Optional[str] = None,
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
|
||||||
gemini_api_key: Optional[str] = None,
|
gemini_api_key: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
) -> EmbeddingResponse:
|
) -> EmbeddingResponse:
|
||||||
|
@ -148,7 +148,7 @@ class VertexEmbedding(VertexBase):
|
||||||
client: Optional[AsyncHTTPHandler] = None,
|
client: Optional[AsyncHTTPHandler] = None,
|
||||||
vertex_project: Optional[str] = None,
|
vertex_project: Optional[str] = None,
|
||||||
vertex_location: Optional[str] = None,
|
vertex_location: Optional[str] = None,
|
||||||
vertex_credentials: Optional[str] = None,
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
|
||||||
gemini_api_key: Optional[str] = None,
|
gemini_api_key: Optional[str] = None,
|
||||||
extra_headers: Optional[dict] = None,
|
extra_headers: Optional[dict] = None,
|
||||||
encoding=None,
|
encoding=None,
|
||||||
|
|
|
@ -12,6 +12,7 @@ from litellm._logging import verbose_logger
|
||||||
from litellm.litellm_core_utils.asyncify import asyncify
|
from litellm.litellm_core_utils.asyncify import asyncify
|
||||||
from litellm.llms.base import BaseLLM
|
from litellm.llms.base import BaseLLM
|
||||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||||
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
|
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
|
||||||
|
|
||||||
|
@ -34,7 +35,7 @@ class VertexBase(BaseLLM):
|
||||||
return vertex_region or "us-central1"
|
return vertex_region or "us-central1"
|
||||||
|
|
||||||
def load_auth(
|
def load_auth(
|
||||||
self, credentials: Optional[str], project_id: Optional[str]
|
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
|
||||||
) -> Tuple[Any, str]:
|
) -> Tuple[Any, str]:
|
||||||
import google.auth as google_auth
|
import google.auth as google_auth
|
||||||
from google.auth import identity_pool
|
from google.auth import identity_pool
|
||||||
|
@ -42,29 +43,36 @@ class VertexBase(BaseLLM):
|
||||||
Request, # type: ignore[import-untyped]
|
Request, # type: ignore[import-untyped]
|
||||||
)
|
)
|
||||||
|
|
||||||
if credentials is not None and isinstance(credentials, str):
|
if credentials is not None:
|
||||||
import google.oauth2.service_account
|
import google.oauth2.service_account
|
||||||
|
|
||||||
verbose_logger.debug(
|
if isinstance(credentials, str):
|
||||||
"Vertex: Loading vertex credentials from %s", credentials
|
verbose_logger.debug(
|
||||||
)
|
"Vertex: Loading vertex credentials from %s", credentials
|
||||||
verbose_logger.debug(
|
)
|
||||||
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
|
verbose_logger.debug(
|
||||||
credentials,
|
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
|
||||||
os.path.exists(credentials),
|
credentials,
|
||||||
os.getcwd(),
|
os.path.exists(credentials),
|
||||||
)
|
os.getcwd(),
|
||||||
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if os.path.exists(credentials):
|
if os.path.exists(credentials):
|
||||||
json_obj = json.load(open(credentials))
|
json_obj = json.load(open(credentials))
|
||||||
else:
|
else:
|
||||||
json_obj = json.loads(credentials)
|
json_obj = json.loads(credentials)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Unable to load vertex credentials from environment. Got={}".format(
|
"Unable to load vertex credentials from environment. Got={}".format(
|
||||||
credentials
|
credentials
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
elif isinstance(credentials, dict):
|
||||||
|
json_obj = credentials
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid credentials type: {}".format(type(credentials))
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the JSON object contains Workload Identity Federation configuration
|
# Check if the JSON object contains Workload Identity Federation configuration
|
||||||
|
@ -109,7 +117,7 @@ class VertexBase(BaseLLM):
|
||||||
|
|
||||||
def _ensure_access_token(
|
def _ensure_access_token(
|
||||||
self,
|
self,
|
||||||
credentials: Optional[str],
|
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
project_id: Optional[str],
|
project_id: Optional[str],
|
||||||
custom_llm_provider: Literal[
|
custom_llm_provider: Literal[
|
||||||
"vertex_ai", "vertex_ai_beta", "gemini"
|
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||||
|
@ -202,7 +210,7 @@ class VertexBase(BaseLLM):
|
||||||
gemini_api_key: Optional[str],
|
gemini_api_key: Optional[str],
|
||||||
vertex_project: Optional[str],
|
vertex_project: Optional[str],
|
||||||
vertex_location: Optional[str],
|
vertex_location: Optional[str],
|
||||||
vertex_credentials: Optional[str],
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
stream: Optional[bool],
|
stream: Optional[bool],
|
||||||
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
|
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
|
||||||
api_base: Optional[str],
|
api_base: Optional[str],
|
||||||
|
@ -253,7 +261,7 @@ class VertexBase(BaseLLM):
|
||||||
|
|
||||||
async def _ensure_access_token_async(
|
async def _ensure_access_token_async(
|
||||||
self,
|
self,
|
||||||
credentials: Optional[str],
|
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||||
project_id: Optional[str],
|
project_id: Optional[str],
|
||||||
custom_llm_provider: Literal[
|
custom_llm_provider: Literal[
|
||||||
"vertex_ai", "vertex_ai_beta", "gemini"
|
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||||
|
|
|
@ -6,6 +6,7 @@ from litellm._logging import verbose_proxy_logger
|
||||||
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
|
||||||
VertexPassThroughCredentials,
|
VertexPassThroughCredentials,
|
||||||
)
|
)
|
||||||
|
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
|
|
||||||
class VertexPassThroughRouter:
|
class VertexPassThroughRouter:
|
||||||
|
@ -58,7 +59,7 @@ class VertexPassThroughRouter:
|
||||||
self,
|
self,
|
||||||
project_id: str,
|
project_id: str,
|
||||||
location: str,
|
location: str,
|
||||||
vertex_credentials: str,
|
vertex_credentials: VERTEX_CREDENTIALS_TYPES,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add the vertex credentials for the given project-id, location
|
Add the vertex credentials for the given project-id, location
|
||||||
|
|
|
@ -481,3 +481,6 @@ class VertexBatchPredictionResponse(TypedDict, total=False):
|
||||||
createTime: str
|
createTime: str
|
||||||
updateTime: str
|
updateTime: str
|
||||||
modelVersionId: str
|
modelVersionId: str
|
||||||
|
|
||||||
|
|
||||||
|
VERTEX_CREDENTIALS_TYPES = Union[str, Dict[str, str]]
|
||||||
|
|
|
@ -6,6 +6,8 @@ from typing import Optional
|
||||||
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from ..llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
|
|
||||||
|
|
||||||
class VertexPassThroughCredentials(BaseModel):
|
class VertexPassThroughCredentials(BaseModel):
|
||||||
# Example: vertex_project = "my-project-123"
|
# Example: vertex_project = "my-project-123"
|
||||||
|
@ -15,4 +17,4 @@ class VertexPassThroughCredentials(BaseModel):
|
||||||
vertex_location: Optional[str] = None
|
vertex_location: Optional[str] = None
|
||||||
|
|
||||||
# Example: vertex_credentials = "/path/to/credentials.json" or "os.environ/GOOGLE_CREDS"
|
# Example: vertex_credentials = "/path/to/credentials.json" or "os.environ/GOOGLE_CREDS"
|
||||||
vertex_credentials: Optional[str] = None
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None
|
||||||
|
|
|
@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||||
from ..exceptions import RateLimitError
|
from ..exceptions import RateLimitError
|
||||||
from .completion import CompletionRequest
|
from .completion import CompletionRequest
|
||||||
from .embedding import EmbeddingRequest
|
from .embedding import EmbeddingRequest
|
||||||
|
from .llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||||
from .utils import ModelResponse, ProviderSpecificModelInfo
|
from .utils import ModelResponse, ProviderSpecificModelInfo
|
||||||
|
|
||||||
|
|
||||||
|
@ -171,7 +172,7 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
## VERTEX AI ##
|
## VERTEX AI ##
|
||||||
vertex_project: Optional[str] = None
|
vertex_project: Optional[str] = None
|
||||||
vertex_location: Optional[str] = None
|
vertex_location: Optional[str] = None
|
||||||
vertex_credentials: Optional[str] = None
|
vertex_credentials: Optional[Union[str, dict]] = None
|
||||||
## AWS BEDROCK / SAGEMAKER ##
|
## AWS BEDROCK / SAGEMAKER ##
|
||||||
aws_access_key_id: Optional[str] = None
|
aws_access_key_id: Optional[str] = None
|
||||||
aws_secret_access_key: Optional[str] = None
|
aws_secret_access_key: Optional[str] = None
|
||||||
|
@ -213,7 +214,7 @@ class GenericLiteLLMParams(BaseModel):
|
||||||
## VERTEX AI ##
|
## VERTEX AI ##
|
||||||
vertex_project: Optional[str] = None,
|
vertex_project: Optional[str] = None,
|
||||||
vertex_location: Optional[str] = None,
|
vertex_location: Optional[str] = None,
|
||||||
vertex_credentials: Optional[str] = None,
|
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
|
||||||
## AWS BEDROCK / SAGEMAKER ##
|
## AWS BEDROCK / SAGEMAKER ##
|
||||||
aws_access_key_id: Optional[str] = None,
|
aws_access_key_id: Optional[str] = None,
|
||||||
aws_secret_access_key: Optional[str] = None,
|
aws_secret_access_key: Optional[str] = None,
|
||||||
|
|
|
@ -1518,7 +1518,7 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
|
||||||
)
|
)
|
||||||
elif resp is not None:
|
elif resp is not None:
|
||||||
|
|
||||||
assert resp.model == model.split("/")[1].split("@")[0]
|
assert resp.model == model.split("/")[1]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -2740,7 +2740,7 @@ async def test_partner_models_httpx_ai21():
|
||||||
"total_tokens": 194,
|
"total_tokens": 194,
|
||||||
},
|
},
|
||||||
"meta": {"requestDurationMillis": 501},
|
"meta": {"requestDurationMillis": 501},
|
||||||
"model": "jamba-1.5",
|
"model": "jamba-1.5-mini@001",
|
||||||
}
|
}
|
||||||
|
|
||||||
mock_response.json = return_val
|
mock_response.json = return_val
|
||||||
|
@ -2769,7 +2769,7 @@ async def test_partner_models_httpx_ai21():
|
||||||
kwargs["data"] = json.loads(kwargs["data"])
|
kwargs["data"] = json.loads(kwargs["data"])
|
||||||
|
|
||||||
assert kwargs["data"] == {
|
assert kwargs["data"] == {
|
||||||
"model": "jamba-1.5-mini",
|
"model": "jamba-1.5-mini@001",
|
||||||
"messages": [
|
"messages": [
|
||||||
{
|
{
|
||||||
"role": "system",
|
"role": "system",
|
||||||
|
@ -3222,3 +3222,67 @@ def test_vertexai_code_gecko():
|
||||||
|
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
|
||||||
|
|
||||||
|
def vertex_ai_anthropic_thinking_mock_response(*args, **kwargs):
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.headers = {"Content-Type": "application/json"}
|
||||||
|
mock_response.json.return_value = {
|
||||||
|
"id": "msg_vrtx_011pL6Np3MKxXL3R8theMRJW",
|
||||||
|
"type": "message",
|
||||||
|
"role": "assistant",
|
||||||
|
"model": "claude-3-7-sonnet-20250219",
|
||||||
|
"content": [
|
||||||
|
{
|
||||||
|
"type": "thinking",
|
||||||
|
"thinking": 'This is a very simple and common greeting in programming and computing. "Hello, world!" is often the first program people write when learning a new programming language, where they create a program that outputs this phrase.\n\nI should respond in a friendly way and acknowledge this greeting. I can keep it simple and welcoming.',
|
||||||
|
"signature": "EugBCkYQAhgCIkAqCkezmsp8DG9Jjoc/CD7yXavPXVvP4TAuwjc/ZgHRIgroz5FzAYxic3CnNiW5w2fx/4+1f4ZYVxWJVLmrEA46EgwFsxbpN2jxMxjIzy0aDIAbMy9rW6B5lGVETCIw4r2UW0A7m5Df991SMSMPvHU9VdL8p9S/F2wajLnLVpl5tH89csm4NqnMpxnou61yKlCLldFGIto1Kvit5W1jqn2gx2dGIOyR4YaJ0c8AIFfQa5TIXf+EChVDzhPKLWZ8D/Q3gCGxBx+m/4dLI8HMZA8Ob3iCMI23eBKmh62FCWJGuA==",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "text",
|
||||||
|
"text": "Hi there! 👋 \n\nIt's nice to meet you! \"Hello, world!\" is such a classic phrase in computing - it's often the first output from someone's very first program.\n\nHow are you doing today? Is there something specific I can help you with?",
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"stop_reason": "end_turn",
|
||||||
|
"stop_sequence": None,
|
||||||
|
"usage": {
|
||||||
|
"input_tokens": 39,
|
||||||
|
"cache_creation_input_tokens": 0,
|
||||||
|
"cache_read_input_tokens": 0,
|
||||||
|
"output_tokens": 134,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
return mock_response
|
||||||
|
|
||||||
|
|
||||||
|
def test_vertex_anthropic_completion():
|
||||||
|
from litellm import completion
|
||||||
|
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||||
|
|
||||||
|
client = HTTPHandler()
|
||||||
|
|
||||||
|
load_vertex_ai_credentials()
|
||||||
|
|
||||||
|
with patch.object(
|
||||||
|
client, "post", side_effect=vertex_ai_anthropic_thinking_mock_response
|
||||||
|
):
|
||||||
|
response = completion(
|
||||||
|
model="vertex_ai/claude-3-7-sonnet@20250219",
|
||||||
|
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||||
|
vertex_ai_location="us-east5",
|
||||||
|
vertex_ai_project="test-project",
|
||||||
|
thinking={"type": "enabled", "budget_tokens": 1024},
|
||||||
|
client=client,
|
||||||
|
)
|
||||||
|
print(response)
|
||||||
|
assert response.model == "claude-3-7-sonnet@20250219"
|
||||||
|
assert response._hidden_params["response_cost"] is not None
|
||||||
|
assert response._hidden_params["response_cost"] > 0
|
||||||
|
|
||||||
|
assert response.choices[0].message.reasoning_content is not None
|
||||||
|
assert isinstance(response.choices[0].message.reasoning_content, str)
|
||||||
|
assert response.choices[0].message.thinking_blocks is not None
|
||||||
|
assert isinstance(response.choices[0].message.thinking_blocks, list)
|
||||||
|
assert len(response.choices[0].message.thinking_blocks) > 0
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue