vertex ai anthropic thinking param support (#8853)

* fix(vertex_llm_base.py): handle credentials passed in as dictionary

* fix(router.py): support vertex credentials as json dict

* test(test_vertex.py): allows easier testing

mock anthropic thinking response for vertex ai

* test(vertex_ai_partner_models/): don't remove "@" from model

breaks anthropic cost calculation

* test: move testing

* fix: fix linting error

* fix: fix linting error

* fix(vertex_ai_partner_models/main.py): split @ for codestral model

* test: fix test

* fix: fix stripping "@" on mistral models

* fix: fix test

* test: fix test
Krish Dholakia, 2025-02-26 21:37:18 -08:00, committed by GitHub
parent 992e78dfd8
commit 88eedb22b9
15 changed files with 135 additions and 45 deletions
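The headline credential change below (vertex_llm_base.py plus router.py) means vertex credentials no longer have to be a file path or a JSON string; an already-parsed dict is accepted too. A minimal usage sketch with placeholder service-account values (not from this commit):

import litellm

# Placeholder service-account dict -- normally json.load()-ed from a key file.
vertex_credentials = {
    "type": "service_account",
    "project_id": "test-project",
    "private_key_id": "...",
    "private_key": "...",
    "client_email": "sa@test-project.iam.gserviceaccount.com",
    "token_uri": "https://oauth2.googleapis.com/token",
}

response = litellm.completion(
    model="vertex_ai/gemini-1.5-pro",
    messages=[{"role": "user", "content": "Hello!"}],
    vertex_project="test-project",
    vertex_location="us-central1",
    vertex_credentials=vertex_credentials,  # dict accepted after this change
)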

.gitignore

@@ -77,3 +77,5 @@ litellm/proxy/_experimental/out/404.html
 litellm/proxy/_experimental/out/model_hub.html
 .mypy_cache/*
 litellm/proxy/application.log
+tests/llm_translation/vertex_test_account.json
+tests/llm_translation/test_vertex_key.json

@@ -10,7 +10,10 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
 from litellm.types.llms.openai import Batch, CreateBatchRequest
-from litellm.types.llms.vertex_ai import VertexAIBatchPredictionJob
+from litellm.types.llms.vertex_ai import (
+    VERTEX_CREDENTIALS_TYPES,
+    VertexAIBatchPredictionJob,
+)
 
 from .transformation import VertexAIBatchTransformation
@@ -25,7 +28,7 @@ class VertexAIBatchPrediction(VertexLLM):
         _is_async: bool,
         create_batch_data: CreateBatchRequest,
         api_base: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
@@ -130,7 +133,7 @@ class VertexAIBatchPrediction(VertexLLM):
         _is_async: bool,
         batch_id: str,
         api_base: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
        vertex_project: Optional[str],
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],

@@ -9,6 +9,7 @@ from litellm.integrations.gcs_bucket.gcs_bucket_base import (
 )
 from litellm.llms.custom_httpx.http_handler import get_async_httpx_client
 from litellm.types.llms.openai import CreateFileRequest, FileObject
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 
 from .transformation import VertexAIFilesTransformation
@@ -34,7 +35,7 @@ class VertexAIFilesHandler(GCSBucketBase):
         self,
         create_file_data: CreateFileRequest,
         api_base: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],
@@ -70,7 +71,7 @@ class VertexAIFilesHandler(GCSBucketBase):
         _is_async: bool,
         create_file_data: CreateFileRequest,
         api_base: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
         timeout: Union[float, httpx.Timeout],

@@ -13,6 +13,7 @@ from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
 from litellm.types.fine_tuning import OpenAIFineTuningHyperparameters
 from litellm.types.llms.openai import FineTuningJobCreate
 from litellm.types.llms.vertex_ai import (
+    VERTEX_CREDENTIALS_TYPES,
     FineTuneHyperparameters,
     FineTuneJobCreate,
     FineTunesupervisedTuningSpec,
@@ -222,7 +223,7 @@ class VertexFineTuningAPI(VertexLLM):
         create_fine_tuning_job_data: FineTuningJobCreate,
         vertex_project: Optional[str],
         vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
         kwargs: Optional[dict] = None,

@@ -40,6 +40,7 @@ from litellm.types.llms.openai import (
     ChatCompletionUsageBlock,
 )
 from litellm.types.llms.vertex_ai import (
+    VERTEX_CREDENTIALS_TYPES,
     Candidates,
     ContentType,
     FunctionCallingConfig,
@@ -930,7 +931,7 @@ class VertexLLM(VertexBase):
         client: Optional[AsyncHTTPHandler] = None,
         vertex_project: Optional[str] = None,
         vertex_location: Optional[str] = None,
-        vertex_credentials: Optional[str] = None,
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
         gemini_api_key: Optional[str] = None,
         extra_headers: Optional[dict] = None,
     ) -> CustomStreamWrapper:
@@ -1018,7 +1019,7 @@ class VertexLLM(VertexBase):
         client: Optional[AsyncHTTPHandler] = None,
         vertex_project: Optional[str] = None,
         vertex_location: Optional[str] = None,
-        vertex_credentials: Optional[str] = None,
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
         gemini_api_key: Optional[str] = None,
         extra_headers: Optional[dict] = None,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
@@ -1123,7 +1124,7 @@ class VertexLLM(VertexBase):
         timeout: Optional[Union[float, httpx.Timeout]],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         gemini_api_key: Optional[str],
         litellm_params: dict,
         logger_fn=None,

@@ -11,6 +11,7 @@ from litellm.llms.custom_httpx.http_handler import (
     get_async_httpx_client,
 )
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 from litellm.types.utils import ImageResponse
@@ -44,7 +45,7 @@ class VertexImageGeneration(VertexLLM):
         prompt: str,
         vertex_project: Optional[str],
         vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: ImageResponse,
         logging_obj: Any,
         model: Optional[
@@ -139,7 +140,7 @@ class VertexImageGeneration(VertexLLM):
         prompt: str,
         vertex_project: Optional[str],
         vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         model_response: litellm.ImageResponse,
         logging_obj: Any,
         model: Optional[

@@ -9,6 +9,7 @@ from litellm.llms.custom_httpx.http_handler import (
 )
 from litellm.llms.openai.openai import HttpxBinaryResponseContent
 from litellm.llms.vertex_ai.gemini.vertex_and_google_ai_studio_gemini import VertexLLM
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 
 
 class VertexInput(TypedDict, total=False):
@@ -45,7 +46,7 @@ class VertexTextToSpeechAPI(VertexLLM):
         logging_obj,
         vertex_project: Optional[str],
         vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         api_base: Optional[str],
         timeout: Union[float, httpx.Timeout],
         model: str,

@@ -160,7 +160,8 @@ class VertexAIPartnerModels(VertexBase):
                 url=default_api_base,
             )
 
-            model = model.split("@")[0]
+            if "codestral" in model or "mistral" in model:
+                model = model.split("@")[0]
 
             if "codestral" in model and litellm_params.get("text_completion") is True:
                 optional_params["model"] = model
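Per the commit message, the point of this hunk is that only codestral/mistral models get the "@version" suffix stripped, while anthropic models keep it so cost calculation still matches the priced model name. A rough illustration of the resulting behavior (example model names, not code from this commit):

# Rough illustration of the new split behavior:
for model in ["codestral@2405", "mistral-large@2407", "claude-3-7-sonnet@20250219"]:
    if "codestral" in model or "mistral" in model:
        model = model.split("@")[0]  # -> "codestral", "mistral-large"
    print(model)  # claude keeps "@20250219", so cost lookup still works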

@@ -41,7 +41,7 @@ class VertexEmbedding(VertexBase):
         client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None,
         vertex_project: Optional[str] = None,
         vertex_location: Optional[str] = None,
-        vertex_credentials: Optional[str] = None,
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
         gemini_api_key: Optional[str] = None,
         extra_headers: Optional[dict] = None,
     ) -> EmbeddingResponse:
@@ -148,7 +148,7 @@ class VertexEmbedding(VertexBase):
         client: Optional[AsyncHTTPHandler] = None,
         vertex_project: Optional[str] = None,
         vertex_location: Optional[str] = None,
-        vertex_credentials: Optional[str] = None,
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
         gemini_api_key: Optional[str] = None,
         extra_headers: Optional[dict] = None,
         encoding=None,

@@ -12,6 +12,7 @@ from litellm._logging import verbose_logger
 from litellm.litellm_core_utils.asyncify import asyncify
 from litellm.llms.base import BaseLLM
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 
 from .common_utils import _get_gemini_url, _get_vertex_url, all_gemini_url_modes
@@ -34,7 +35,7 @@ class VertexBase(BaseLLM):
         return vertex_region or "us-central1"
 
     def load_auth(
-        self, credentials: Optional[str], project_id: Optional[str]
+        self, credentials: Optional[VERTEX_CREDENTIALS_TYPES], project_id: Optional[str]
     ) -> Tuple[Any, str]:
         import google.auth as google_auth
         from google.auth import identity_pool
@@ -42,29 +43,36 @@ class VertexBase(BaseLLM):
             Request,  # type: ignore[import-untyped]
         )
 
-        if credentials is not None and isinstance(credentials, str):
+        if credentials is not None:
             import google.oauth2.service_account
 
-            verbose_logger.debug(
-                "Vertex: Loading vertex credentials from %s", credentials
-            )
-            verbose_logger.debug(
-                "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
-                credentials,
-                os.path.exists(credentials),
-                os.getcwd(),
-            )
-
-            try:
-                if os.path.exists(credentials):
-                    json_obj = json.load(open(credentials))
-                else:
-                    json_obj = json.loads(credentials)
-            except Exception:
-                raise Exception(
-                    "Unable to load vertex credentials from environment. Got={}".format(
-                        credentials
-                    )
-                )
+            if isinstance(credentials, str):
+                verbose_logger.debug(
+                    "Vertex: Loading vertex credentials from %s", credentials
+                )
+                verbose_logger.debug(
+                    "Vertex: checking if credentials is a valid path, os.path.exists(%s)=%s, current dir %s",
+                    credentials,
+                    os.path.exists(credentials),
+                    os.getcwd(),
+                )
+
+                try:
+                    if os.path.exists(credentials):
+                        json_obj = json.load(open(credentials))
+                    else:
+                        json_obj = json.loads(credentials)
+                except Exception:
+                    raise Exception(
+                        "Unable to load vertex credentials from environment. Got={}".format(
+                            credentials
+                        )
+                    )
+            elif isinstance(credentials, dict):
+                json_obj = credentials
+            else:
+                raise ValueError(
+                    "Invalid credentials type: {}".format(type(credentials))
+                )
 
             # Check if the JSON object contains Workload Identity Federation configuration
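In short, load_auth now normalizes three credential shapes (key-file path, raw JSON string, parsed dict) before the Workload Identity Federation check. A standalone sketch of just that normalization, with a hypothetical helper name mirroring the hunk above:

import json
import os

def _to_credentials_json(credentials):
    # Hypothetical helper mirroring the branch logic in load_auth above.
    if isinstance(credentials, str):
        # Either a path to a service-account key file, or raw JSON text.
        if os.path.exists(credentials):
            with open(credentials) as f:
                return json.load(f)
        return json.loads(credentials)
    elif isinstance(credentials, dict):
        return credentials  # already-parsed service-account JSON
    raise ValueError("Invalid credentials type: {}".format(type(credentials)))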
@@ -109,7 +117,7 @@ class VertexBase(BaseLLM):
 
     def _ensure_access_token(
         self,
-        credentials: Optional[str],
+        credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         project_id: Optional[str],
         custom_llm_provider: Literal[
             "vertex_ai", "vertex_ai_beta", "gemini"
@@ -202,7 +210,7 @@ class VertexBase(BaseLLM):
         gemini_api_key: Optional[str],
         vertex_project: Optional[str],
         vertex_location: Optional[str],
-        vertex_credentials: Optional[str],
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         stream: Optional[bool],
         custom_llm_provider: Literal["vertex_ai", "vertex_ai_beta", "gemini"],
         api_base: Optional[str],
@@ -253,7 +261,7 @@ class VertexBase(BaseLLM):
 
     async def _ensure_access_token_async(
         self,
-        credentials: Optional[str],
+        credentials: Optional[VERTEX_CREDENTIALS_TYPES],
         project_id: Optional[str],
         custom_llm_provider: Literal[
             "vertex_ai", "vertex_ai_beta", "gemini"

@@ -6,6 +6,7 @@ from litellm._logging import verbose_proxy_logger
 from litellm.proxy.vertex_ai_endpoints.vertex_endpoints import (
     VertexPassThroughCredentials,
 )
+from litellm.types.llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 
 
 class VertexPassThroughRouter:
@@ -58,7 +59,7 @@ class VertexPassThroughRouter:
         self,
         project_id: str,
         location: str,
-        vertex_credentials: str,
+        vertex_credentials: VERTEX_CREDENTIALS_TYPES,
     ):
         """
         Add the vertex credentials for the given project-id, location

@@ -481,3 +481,6 @@ class VertexBatchPredictionResponse(TypedDict, total=False):
     createTime: str
     updateTime: str
     modelVersionId: str
+
+
+VERTEX_CREDENTIALS_TYPES = Union[str, Dict[str, str]]
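For reference, the two shapes the new alias admits; a minimal sketch with placeholder values (not from this commit):

# Either form satisfies VERTEX_CREDENTIALS_TYPES:
creds_as_path: VERTEX_CREDENTIALS_TYPES = "/path/to/credentials.json"  # or a raw JSON string
creds_as_dict: VERTEX_CREDENTIALS_TYPES = {"type": "service_account", "project_id": "test-project"}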

@@ -6,6 +6,8 @@ from typing import Optional
 
 from pydantic import BaseModel
 
+from ..llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
+
 
 class VertexPassThroughCredentials(BaseModel):
     # Example: vertex_project = "my-project-123"
@@ -15,4 +17,4 @@ class VertexPassThroughCredentials(BaseModel):
     vertex_location: Optional[str] = None
 
     # Example: vertex_credentials = "/path/to/credentials.json" or "os.environ/GOOGLE_CREDS"
-    vertex_credentials: Optional[str] = None
+    vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None

@@ -18,6 +18,7 @@ from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from ..exceptions import RateLimitError
 from .completion import CompletionRequest
 from .embedding import EmbeddingRequest
+from .llms.vertex_ai import VERTEX_CREDENTIALS_TYPES
 from .utils import ModelResponse, ProviderSpecificModelInfo
@@ -171,7 +172,7 @@ class GenericLiteLLMParams(BaseModel):
     ## VERTEX AI ##
     vertex_project: Optional[str] = None
     vertex_location: Optional[str] = None
-    vertex_credentials: Optional[str] = None
+    vertex_credentials: Optional[Union[str, dict]] = None
     ## AWS BEDROCK / SAGEMAKER ##
     aws_access_key_id: Optional[str] = None
     aws_secret_access_key: Optional[str] = None
@@ -213,7 +214,7 @@ class GenericLiteLLMParams(BaseModel):
         ## VERTEX AI ##
         vertex_project: Optional[str] = None,
         vertex_location: Optional[str] = None,
-        vertex_credentials: Optional[str] = None,
+        vertex_credentials: Optional[VERTEX_CREDENTIALS_TYPES] = None,
         ## AWS BEDROCK / SAGEMAKER ##
         aws_access_key_id: Optional[str] = None,
         aws_secret_access_key: Optional[str] = None,

@@ -1518,7 +1518,7 @@ async def test_gemini_pro_json_schema_args_sent_httpx(
         )
     elif resp is not None:
-        assert resp.model == model.split("/")[1].split("@")[0]
+        assert resp.model == model.split("/")[1]
 
 
 @pytest.mark.parametrize(
@@ -2740,7 +2740,7 @@ async def test_partner_models_httpx_ai21():
                 "total_tokens": 194,
             },
             "meta": {"requestDurationMillis": 501},
-            "model": "jamba-1.5",
+            "model": "jamba-1.5-mini@001",
         }
 
         mock_response.json = return_val
@@ -2769,7 +2769,7 @@ async def test_partner_models_httpx_ai21():
         kwargs["data"] = json.loads(kwargs["data"])
 
         assert kwargs["data"] == {
-            "model": "jamba-1.5-mini",
+            "model": "jamba-1.5-mini@001",
             "messages": [
                 {
                     "role": "system",
@@ -3222,3 +3222,67 @@ def test_vertexai_code_gecko():
     for chunk in response:
         print(chunk)
+
+
+def vertex_ai_anthropic_thinking_mock_response(*args, **kwargs):
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.headers = {"Content-Type": "application/json"}
+    mock_response.json.return_value = {
+        "id": "msg_vrtx_011pL6Np3MKxXL3R8theMRJW",
+        "type": "message",
+        "role": "assistant",
+        "model": "claude-3-7-sonnet-20250219",
+        "content": [
+            {
+                "type": "thinking",
+                "thinking": 'This is a very simple and common greeting in programming and computing. "Hello, world!" is often the first program people write when learning a new programming language, where they create a program that outputs this phrase.\n\nI should respond in a friendly way and acknowledge this greeting. I can keep it simple and welcoming.',
+                "signature": "EugBCkYQAhgCIkAqCkezmsp8DG9Jjoc/CD7yXavPXVvP4TAuwjc/ZgHRIgroz5FzAYxic3CnNiW5w2fx/4+1f4ZYVxWJVLmrEA46EgwFsxbpN2jxMxjIzy0aDIAbMy9rW6B5lGVETCIw4r2UW0A7m5Df991SMSMPvHU9VdL8p9S/F2wajLnLVpl5tH89csm4NqnMpxnou61yKlCLldFGIto1Kvit5W1jqn2gx2dGIOyR4YaJ0c8AIFfQa5TIXf+EChVDzhPKLWZ8D/Q3gCGxBx+m/4dLI8HMZA8Ob3iCMI23eBKmh62FCWJGuA==",
+            },
+            {
+                "type": "text",
+                "text": "Hi there! 👋 \n\nIt's nice to meet you! \"Hello, world!\" is such a classic phrase in computing - it's often the first output from someone's very first program.\n\nHow are you doing today? Is there something specific I can help you with?",
+            },
+        ],
+        "stop_reason": "end_turn",
+        "stop_sequence": None,
+        "usage": {
+            "input_tokens": 39,
+            "cache_creation_input_tokens": 0,
+            "cache_read_input_tokens": 0,
+            "output_tokens": 134,
+        },
+    }
+    return mock_response
+
+
+def test_vertex_anthropic_completion():
+    from litellm import completion
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+
+    load_vertex_ai_credentials()
+
+    with patch.object(
+        client, "post", side_effect=vertex_ai_anthropic_thinking_mock_response
+    ):
+        response = completion(
+            model="vertex_ai/claude-3-7-sonnet@20250219",
+            messages=[{"role": "user", "content": "Hello, world!"}],
+            vertex_ai_location="us-east5",
+            vertex_ai_project="test-project",
+            thinking={"type": "enabled", "budget_tokens": 1024},
+            client=client,
+        )
+        print(response)
+
+        assert response.model == "claude-3-7-sonnet@20250219"
+        assert response._hidden_params["response_cost"] is not None
+        assert response._hidden_params["response_cost"] > 0
+
+        assert response.choices[0].message.reasoning_content is not None
+        assert isinstance(response.choices[0].message.reasoning_content, str)
+        assert response.choices[0].message.thinking_blocks is not None
+        assert isinstance(response.choices[0].message.thinking_blocks, list)
+        assert len(response.choices[0].message.thinking_blocks) > 0
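Outside the mocked test, the headline feature exercised above looks roughly like this in application code; a minimal sketch assuming Vertex AI credentials and project are already configured:

from litellm import completion

response = completion(
    model="vertex_ai/claude-3-7-sonnet@20250219",
    messages=[{"role": "user", "content": "Hello, world!"}],
    thinking={"type": "enabled", "budget_tokens": 1024},
)
print(response.choices[0].message.reasoning_content)  # extracted thinking text
print(response.choices[0].message.content)            # final assistant answer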