fix(vertex_ai_anthropic.py): set vertex_credentials for vertex ai anthropic calls

allows setting vertex credentials as a json string for vertex ai anthropic calls
This commit is contained in:
Krrish Dholakia 2024-04-15 14:16:28 -07:00
parent 3d645f95a5
commit 8c3c45fbb5
5 changed files with 36 additions and 14 deletions

View file

@@ -270,6 +270,7 @@ def completion(
logging_obj,
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
@@ -348,6 +349,7 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
creds, _ = google.auth.default(quota_project_id=vertex_project)
print_verbose(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"

View file

@@ -129,6 +129,18 @@ class VertexAIAnthropicConfig:
# makes headers for API call
def refresh_auth(
credentials,
) -> str: # used when user passes in credentials as json string
from google.auth.transport.requests import Request # type: ignore[import-untyped]
if credentials.token is None:
credentials.refresh(Request())
if not credentials.token:
raise RuntimeError("Could not resolve API token from the credentials")
return credentials.token
def completion(
@@ -220,6 +232,7 @@ def completion(
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
)
access_token = None
if client is None:
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
@@ -228,14 +241,17 @@ def completion(
creds = (
google.oauth2.service_account.Credentials.from_service_account_info(
json_obj
json_obj,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
)
### CHECK IF ACCESS
access_token = refresh_auth(credentials=creds)
vertexai.init(credentials=creds)
vertex_ai_client = AnthropicVertex(
project_id=vertex_project,
region=vertex_location,
access_token=access_token,
)
else:
vertex_ai_client = client
@@ -257,6 +273,7 @@ def completion(
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
else:
return async_completion(
@@ -270,6 +287,7 @@ def completion(
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
access_token=access_token,
)
if stream is not None and stream == True:
## LOGGING
@@ -348,12 +366,13 @@ async def async_completion(
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location
project_id=vertex_project, region=vertex_location, access_token=access_token
)
else:
vertex_ai_client = client
@@ -418,12 +437,13 @@ async def async_streaming(
vertex_location=None,
optional_params=None,
client=None,
access_token=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location
project_id=vertex_project, region=vertex_location, access_token=access_token
)
else:
vertex_ai_client = client

View file

@@ -1676,13 +1676,12 @@ def completion(
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
)
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
if "claude-3" in model:
vertex_credentials = (
optional_params.pop("vertex_credentials", None)
or optional_params.pop("vertex_ai_credentials", None)
or get_secret("VERTEXAI_CREDENTIALS")
)
model_response = vertex_ai_anthropic.completion(
model=model,
messages=messages,

View file

@@ -40,7 +40,7 @@ general_settings:
master_key: sk-1234
allow_user_auth: true
alerting: ["slack"]
# store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
store_model_in_db: True # set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
enable_jwt_auth: True
alerting: ["slack"]

View file

@@ -123,8 +123,6 @@ async def get_response():
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
# )
def test_vertex_ai_anthropic():
load_vertex_ai_credentials()
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
@@ -179,12 +177,14 @@ def test_vertex_ai_anthropic_streaming():
# )
@pytest.mark.asyncio
async def test_vertex_ai_anthropic_async():
load_vertex_ai_credentials()
# load_vertex_ai_credentials()
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "asia-southeast1"
json_obj = get_vertex_ai_creds_json()
vertex_credentials = json.dumps(json_obj)
response = await acompletion(
model="vertex_ai/" + model,
@@ -192,6 +192,7 @@ async def test_vertex_ai_anthropic_async():
temperature=0.7,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
vertex_credentials=vertex_credentials,
)
print(f"Model Response: {response}")