mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(vertex_ai_anthropic.py): set vertex_credentials for vertex ai anthropic calls
Allows setting Vertex AI credentials as a JSON string for Vertex AI Anthropic calls.
This commit is contained in:
parent
3d645f95a5
commit
8c3c45fbb5
5 changed files with 36 additions and 14 deletions
|
@ -270,6 +270,7 @@ def completion(
|
||||||
logging_obj,
|
logging_obj,
|
||||||
vertex_project=None,
|
vertex_project=None,
|
||||||
vertex_location=None,
|
vertex_location=None,
|
||||||
|
vertex_credentials=None,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
litellm_params=None,
|
litellm_params=None,
|
||||||
logger_fn=None,
|
logger_fn=None,
|
||||||
|
@ -348,6 +349,7 @@ def completion(
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
|
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
|
||||||
)
|
)
|
||||||
|
|
||||||
creds, _ = google.auth.default(quota_project_id=vertex_project)
|
creds, _ = google.auth.default(quota_project_id=vertex_project)
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
|
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
|
||||||
|
|
|
@ -129,6 +129,18 @@ class VertexAIAnthropicConfig:
|
||||||
|
|
||||||
|
|
||||||
# makes headers for API call
|
# makes headers for API call
|
||||||
|
def refresh_auth(
|
||||||
|
credentials,
|
||||||
|
) -> str: # used when user passes in credentials as json string
|
||||||
|
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||||
|
|
||||||
|
if credentials.token is None:
|
||||||
|
credentials.refresh(Request())
|
||||||
|
|
||||||
|
if not credentials.token:
|
||||||
|
raise RuntimeError("Could not resolve API token from the credentials")
|
||||||
|
|
||||||
|
return credentials.token
|
||||||
|
|
||||||
|
|
||||||
def completion(
|
def completion(
|
||||||
|
@ -220,6 +232,7 @@ def completion(
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
|
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
|
||||||
)
|
)
|
||||||
|
access_token = None
|
||||||
if client is None:
|
if client is None:
|
||||||
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
||||||
import google.oauth2.service_account
|
import google.oauth2.service_account
|
||||||
|
@ -228,14 +241,17 @@ def completion(
|
||||||
|
|
||||||
creds = (
|
creds = (
|
||||||
google.oauth2.service_account.Credentials.from_service_account_info(
|
google.oauth2.service_account.Credentials.from_service_account_info(
|
||||||
json_obj
|
json_obj,
|
||||||
|
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
### CHECK IF ACCESS
|
||||||
|
access_token = refresh_auth(credentials=creds)
|
||||||
|
|
||||||
vertexai.init(credentials=creds)
|
|
||||||
vertex_ai_client = AnthropicVertex(
|
vertex_ai_client = AnthropicVertex(
|
||||||
project_id=vertex_project,
|
project_id=vertex_project,
|
||||||
region=vertex_location,
|
region=vertex_location,
|
||||||
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vertex_ai_client = client
|
vertex_ai_client = client
|
||||||
|
@ -257,6 +273,7 @@ def completion(
|
||||||
vertex_location=vertex_location,
|
vertex_location=vertex_location,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
client=client,
|
client=client,
|
||||||
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
return async_completion(
|
return async_completion(
|
||||||
|
@ -270,6 +287,7 @@ def completion(
|
||||||
vertex_location=vertex_location,
|
vertex_location=vertex_location,
|
||||||
optional_params=optional_params,
|
optional_params=optional_params,
|
||||||
client=client,
|
client=client,
|
||||||
|
access_token=access_token,
|
||||||
)
|
)
|
||||||
if stream is not None and stream == True:
|
if stream is not None and stream == True:
|
||||||
## LOGGING
|
## LOGGING
|
||||||
|
@ -348,12 +366,13 @@ async def async_completion(
|
||||||
vertex_location=None,
|
vertex_location=None,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
client=None,
|
client=None,
|
||||||
|
access_token=None,
|
||||||
):
|
):
|
||||||
from anthropic import AsyncAnthropicVertex
|
from anthropic import AsyncAnthropicVertex
|
||||||
|
|
||||||
if client is None:
|
if client is None:
|
||||||
vertex_ai_client = AsyncAnthropicVertex(
|
vertex_ai_client = AsyncAnthropicVertex(
|
||||||
project_id=vertex_project, region=vertex_location
|
project_id=vertex_project, region=vertex_location, access_token=access_token
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vertex_ai_client = client
|
vertex_ai_client = client
|
||||||
|
@ -418,12 +437,13 @@ async def async_streaming(
|
||||||
vertex_location=None,
|
vertex_location=None,
|
||||||
optional_params=None,
|
optional_params=None,
|
||||||
client=None,
|
client=None,
|
||||||
|
access_token=None,
|
||||||
):
|
):
|
||||||
from anthropic import AsyncAnthropicVertex
|
from anthropic import AsyncAnthropicVertex
|
||||||
|
|
||||||
if client is None:
|
if client is None:
|
||||||
vertex_ai_client = AsyncAnthropicVertex(
|
vertex_ai_client = AsyncAnthropicVertex(
|
||||||
project_id=vertex_project, region=vertex_location
|
project_id=vertex_project, region=vertex_location, access_token=access_token
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vertex_ai_client = client
|
vertex_ai_client = client
|
||||||
|
|
|
@ -1676,13 +1676,12 @@ def completion(
|
||||||
or litellm.vertex_location
|
or litellm.vertex_location
|
||||||
or get_secret("VERTEXAI_LOCATION")
|
or get_secret("VERTEXAI_LOCATION")
|
||||||
)
|
)
|
||||||
|
vertex_credentials = (
|
||||||
|
optional_params.pop("vertex_credentials", None)
|
||||||
|
or optional_params.pop("vertex_ai_credentials", None)
|
||||||
|
or get_secret("VERTEXAI_CREDENTIALS")
|
||||||
|
)
|
||||||
if "claude-3" in model:
|
if "claude-3" in model:
|
||||||
vertex_credentials = (
|
|
||||||
optional_params.pop("vertex_credentials", None)
|
|
||||||
or optional_params.pop("vertex_ai_credentials", None)
|
|
||||||
or get_secret("VERTEXAI_CREDENTIALS")
|
|
||||||
)
|
|
||||||
model_response = vertex_ai_anthropic.completion(
|
model_response = vertex_ai_anthropic.completion(
|
||||||
model=model,
|
model=model,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -40,7 +40,7 @@ general_settings:
|
||||||
master_key: sk-1234
|
master_key: sk-1234
|
||||||
allow_user_auth: true
|
allow_user_auth: true
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
# store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
|
store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
|
||||||
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||||
enable_jwt_auth: True
|
enable_jwt_auth: True
|
||||||
alerting: ["slack"]
|
alerting: ["slack"]
|
||||||
|
|
|
@ -123,8 +123,6 @@ async def get_response():
|
||||||
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||||
# )
|
# )
|
||||||
def test_vertex_ai_anthropic():
|
def test_vertex_ai_anthropic():
|
||||||
load_vertex_ai_credentials()
|
|
||||||
|
|
||||||
model = "claude-3-sonnet@20240229"
|
model = "claude-3-sonnet@20240229"
|
||||||
|
|
||||||
vertex_ai_project = "adroit-crow-413218"
|
vertex_ai_project = "adroit-crow-413218"
|
||||||
|
@ -179,12 +177,14 @@ def test_vertex_ai_anthropic_streaming():
|
||||||
# )
|
# )
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_vertex_ai_anthropic_async():
|
async def test_vertex_ai_anthropic_async():
|
||||||
load_vertex_ai_credentials()
|
# load_vertex_ai_credentials()
|
||||||
|
|
||||||
model = "claude-3-sonnet@20240229"
|
model = "claude-3-sonnet@20240229"
|
||||||
|
|
||||||
vertex_ai_project = "adroit-crow-413218"
|
vertex_ai_project = "adroit-crow-413218"
|
||||||
vertex_ai_location = "asia-southeast1"
|
vertex_ai_location = "asia-southeast1"
|
||||||
|
json_obj = get_vertex_ai_creds_json()
|
||||||
|
vertex_credentials = json.dumps(json_obj)
|
||||||
|
|
||||||
response = await acompletion(
|
response = await acompletion(
|
||||||
model="vertex_ai/" + model,
|
model="vertex_ai/" + model,
|
||||||
|
@ -192,6 +192,7 @@ async def test_vertex_ai_anthropic_async():
|
||||||
temperature=0.7,
|
temperature=0.7,
|
||||||
vertex_ai_project=vertex_ai_project,
|
vertex_ai_project=vertex_ai_project,
|
||||||
vertex_ai_location=vertex_ai_location,
|
vertex_ai_location=vertex_ai_location,
|
||||||
|
vertex_credentials=vertex_credentials,
|
||||||
)
|
)
|
||||||
print(f"Model Response: {response}")
|
print(f"Model Response: {response}")
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue