vertex ai anthropic thinking param support (#8853)

* fix(vertex_llm_base.py): handle credentials passed in as dictionary

* fix(router.py): support vertex credentials as json dict

* test(test_vertex.py): allows easier testing

mock anthropic thinking response for vertex ai

* test(vertex_ai_partner_models/): don't remove "@" from model

removing it breaks anthropic cost calculation

* test: move testing

* fix: fix linting error

* fix: fix linting error

* fix(vertex_ai_partner_models/main.py): split @ for codestral model

* test: fix test

* fix: fix stripping "@" on mistral models

* fix: fix test

* test: fix test

This commit is contained in:
Krish Dholakia 2025-02-26 21:37:18 -08:00 committed by GitHub
parent 6fd8ce0df5
commit ca6902e191
15 changed files with 135 additions and 45 deletions

View file

@ -12,6 +12,7 @@ from litellm._logging import verbose_logger
from litellm.litellm_core_utils.asyncify import asyncify
from litellm.llms.base import BaseLLM
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
@ -34,7 +35,7 @@ class VertexBase(BaseLLM):
return vertex_region or "us-central1"
def load_auth(
self, credentials: Optional[str], project_id: Optional[str]
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
) -> Tuple[Any, str]:
import google.auth as google_auth
from google.auth import identity_pool
@ -42,29 +43,36 @@ class VertexBase(BaseLLM):
Request, # type: ignore[import-untyped]
)
if credentials is not None and isinstance(credentials, str):
if credentials is not None:
import google.oauth2.service_account
verbose_logger.debug(
"Vertex: Loading vertex credentials from %s", credentials
)
verbose_logger.debug(
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
credentials,
os.path.exists(credentials),
os.getcwd(),
)
if isinstance(credentials, str):
verbose_logger.debug(
"Vertex: Loading vertex credentials from %s", credentials
)
verbose_logger.debug(
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
credentials,
os.path.exists(credentials),
os.getcwd(),
)
try:
if os.path.exists(credentials):
json_obj = json.load(open(credentials))
else:
json_obj = json.loads(credentials)
except Exception:
raise Exception(
"Unable to load vertex credentials from environment. Got={}".format(
credentials
try:
if os.path.exists(credentials):
json_obj = json.load(open(credentials))
else:
json_obj = json.loads(credentials)
except Exception:
raise Exception(
"Unable to load vertex credentials from environment. Got={}".format(
credentials
)
)
elif isinstance(credentials, dict):
json_obj = credentials
else:
raise ValueError(
"Invalid credentials type: {}".format(type(credentials))
)
# Check if the JSON object contains Workload Identity Federation configuration
@ -109,7 +117,7 @@ class VertexBase(BaseLLM):
def _ensure_access_token(
self,
credentials: Optional[str],
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
project_id: Optional[str],
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"
@ -202,7 +210,7 @@ class VertexBase(BaseLLM):
gemini_api_key: Optional[str],
vertex_project: Optional[str],
vertex_location: Optional[str],
vertex_credentials: Optional[str],
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
stream: Optional[bool],
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
api_base: Optional[str],
@ -253,7 +261,7 @@ class VertexBase(BaseLLM):
async def _ensure_access_token_async(
self,
credentials: Optional[str],
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
project_id: Optional[str],
custom_llm_provider: Literal[
"vertex_ai", "vertex_ai_beta", "gemini"