mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 03:34:10 +00:00
vertex ai anthropic thinking param support (#8853)
* fix(vertex_llm_base.py): handle credentials passed in as dictionary
* fix(router.py): support vertex credentials as json dict
* test(test_vertex.py): allows easier testing; mock anthropic thinking response for vertex ai
* test(vertex_ai_partner_models/): don't remove "@" from model, as this breaks anthropic cost calculation
* test: move testing
* fix: fix linting error
* fix: fix linting error
* fix(vertex_ai_partner_models/main.py): split @ for codestral model
* test: fix test
* fix: fix stripping "@" on mistral models
* fix: fix test
* test: fix test
This commit is contained in:
parent
6fd8ce0df5
commit
ca6902e191
15 changed files with 135 additions and 45 deletions
|
@ -12,6 +12,7 @@ from litellm._logging import verbose_logger
|
|||
from litellm.litellm_core_utils.asyncify import asyncify
|
||||
from litellm.llms.base import BaseLLM
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
|
||||
|
||||
from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
|
||||
|
||||
|
@ -34,7 +35,7 @@ class VertexBase(BaseLLM):
|
|||
return vertex_region or "us-central1"
|
||||
|
||||
def load_auth(
|
||||
self, credentials: Optional[str], project_id: Optional[str]
|
||||
self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
|
||||
) -> Tuple[Any, str]:
|
||||
import google.auth as google_auth
|
||||
from google.auth import identity_pool
|
||||
|
@ -42,29 +43,36 @@ class VertexBase(BaseLLM):
|
|||
Request, # type: ignore[import-untyped]
|
||||
)
|
||||
|
||||
if credentials is not None and isinstance(credentials, str):
|
||||
if credentials is not None:
|
||||
import google.oauth2.service_account
|
||||
|
||||
verbose_logger.debug(
|
||||
"Vertex: Loading vertex credentials from %s", credentials
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
|
||||
credentials,
|
||||
os.path.exists(credentials),
|
||||
os.getcwd(),
|
||||
)
|
||||
if isinstance(credentials, str):
|
||||
verbose_logger.debug(
|
||||
"Vertex: Loading vertex credentials from %s", credentials
|
||||
)
|
||||
verbose_logger.debug(
|
||||
"Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
|
||||
credentials,
|
||||
os.path.exists(credentials),
|
||||
os.getcwd(),
|
||||
)
|
||||
|
||||
try:
|
||||
if os.path.exists(credentials):
|
||||
json_obj = json.load(open(credentials))
|
||||
else:
|
||||
json_obj = json.loads(credentials)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Unable to load vertex credentials from environment. Got={}".format(
|
||||
credentials
|
||||
try:
|
||||
if os.path.exists(credentials):
|
||||
json_obj = json.load(open(credentials))
|
||||
else:
|
||||
json_obj = json.loads(credentials)
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Unable to load vertex credentials from environment. Got={}".format(
|
||||
credentials
|
||||
)
|
||||
)
|
||||
elif isinstance(credentials, dict):
|
||||
json_obj = credentials
|
||||
else:
|
||||
raise ValueError(
|
||||
"Invalid credentials type: {}".format(type(credentials))
|
||||
)
|
||||
|
||||
# Check if the JSON object contains Workload Identity Federation configuration
|
||||
|
@ -109,7 +117,7 @@ class VertexBase(BaseLLM):
|
|||
|
||||
def _ensure_access_token(
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
project_id: Optional[str],
|
||||
custom_llm_provider: Literal[
|
||||
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||
|
@ -202,7 +210,7 @@ class VertexBase(BaseLLM):
|
|||
gemini_api_key: Optional[str],
|
||||
vertex_project: Optional[str],
|
||||
vertex_location: Optional[str],
|
||||
vertex_credentials: Optional[str],
|
||||
vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
stream: Optional[bool],
|
||||
custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
|
||||
api_base: Optional[str],
|
||||
|
@ -253,7 +261,7 @@ class VertexBase(BaseLLM):
|
|||
|
||||
async def _ensure_access_token_async(
|
||||
self,
|
||||
credentials: Optional[str],
|
||||
credentials: Optional[VERTEX_CREDENTIALS_TYPES],
|
||||
project_id: Optional[str],
|
||||
custom_llm_provider: Literal[
|
||||
"vertex_ai", "vertex_ai_beta", "gemini"
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue