mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(vertex_ai_anthropic.py): set vertex_credentials for vertex ai anthropic calls
allows setting vertex credentials as a json string for vertex ai anthropic calls
This commit is contained in:
parent
3d645f95a5
commit
8c3c45fbb5
5 changed files with 36 additions and 14 deletions
|
@ -270,6 +270,7 @@ def completion(
|
|||
logging_obj,
|
||||
vertex_project=None,
|
||||
vertex_location=None,
|
||||
vertex_credentials=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -348,6 +349,7 @@ def completion(
|
|||
print_verbose(
|
||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
|
||||
)
|
||||
|
||||
creds, _ = google.auth.default(quota_project_id=vertex_project)
|
||||
print_verbose(
|
||||
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
|
||||
|
|
|
@ -129,6 +129,18 @@ class VertexAIAnthropicConfig:
|
|||
|
||||
|
||||
# makes headers for API call
|
||||
def refresh_auth(
|
||||
credentials,
|
||||
) -> str: # used when user passes in credentials as json string
|
||||
from google.auth.transport.requests import Request # type: ignore[import-untyped]
|
||||
|
||||
if credentials.token is None:
|
||||
credentials.refresh(Request())
|
||||
|
||||
if not credentials.token:
|
||||
raise RuntimeError("Could not resolve API token from the credentials")
|
||||
|
||||
return credentials.token
|
||||
|
||||
|
||||
def completion(
|
||||
|
@ -220,6 +232,7 @@ def completion(
|
|||
print_verbose(
|
||||
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}; vertex_credentials={vertex_credentials}"
|
||||
)
|
||||
access_token = None
|
||||
if client is None:
|
||||
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
||||
import google.oauth2.service_account
|
||||
|
@ -228,14 +241,17 @@ def completion(
|
|||
|
||||
creds = (
|
||||
google.oauth2.service_account.Credentials.from_service_account_info(
|
||||
json_obj
|
||||
json_obj,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
)
|
||||
### CHECK IF ACCESS
|
||||
access_token = refresh_auth(credentials=creds)
|
||||
|
||||
vertexai.init(credentials=creds)
|
||||
vertex_ai_client = AnthropicVertex(
|
||||
project_id=vertex_project,
|
||||
region=vertex_location,
|
||||
access_token=access_token,
|
||||
)
|
||||
else:
|
||||
vertex_ai_client = client
|
||||
|
@ -257,6 +273,7 @@ def completion(
|
|||
vertex_location=vertex_location,
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
access_token=access_token,
|
||||
)
|
||||
else:
|
||||
return async_completion(
|
||||
|
@ -270,6 +287,7 @@ def completion(
|
|||
vertex_location=vertex_location,
|
||||
optional_params=optional_params,
|
||||
client=client,
|
||||
access_token=access_token,
|
||||
)
|
||||
if stream is not None and stream == True:
|
||||
## LOGGING
|
||||
|
@ -348,12 +366,13 @@ async def async_completion(
|
|||
vertex_location=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
access_token=None,
|
||||
):
|
||||
from anthropic import AsyncAnthropicVertex
|
||||
|
||||
if client is None:
|
||||
vertex_ai_client = AsyncAnthropicVertex(
|
||||
project_id=vertex_project, region=vertex_location
|
||||
project_id=vertex_project, region=vertex_location, access_token=access_token
|
||||
)
|
||||
else:
|
||||
vertex_ai_client = client
|
||||
|
@ -418,12 +437,13 @@ async def async_streaming(
|
|||
vertex_location=None,
|
||||
optional_params=None,
|
||||
client=None,
|
||||
access_token=None,
|
||||
):
|
||||
from anthropic import AsyncAnthropicVertex
|
||||
|
||||
if client is None:
|
||||
vertex_ai_client = AsyncAnthropicVertex(
|
||||
project_id=vertex_project, region=vertex_location
|
||||
project_id=vertex_project, region=vertex_location, access_token=access_token
|
||||
)
|
||||
else:
|
||||
vertex_ai_client = client
|
||||
|
|
|
@ -1676,13 +1676,12 @@ def completion(
|
|||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
)
|
||||
|
||||
if "claude-3" in model:
|
||||
vertex_credentials = (
|
||||
optional_params.pop("vertex_credentials", None)
|
||||
or optional_params.pop("vertex_ai_credentials", None)
|
||||
or get_secret("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
if "claude-3" in model:
|
||||
model_response = vertex_ai_anthropic.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
|
|
@ -40,7 +40,7 @@ general_settings:
|
|||
master_key: sk-1234
|
||||
allow_user_auth: true
|
||||
alerting: ["slack"]
|
||||
# store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
|
||||
store_model_in_db: True // set via environment variable - os.environ["STORE_MODEL_IN_DB"] = "True"
|
||||
proxy_batch_write_at: 5 # 👈 Frequency of batch writing logs to server (in seconds)
|
||||
enable_jwt_auth: True
|
||||
alerting: ["slack"]
|
||||
|
|
|
@ -123,8 +123,6 @@ async def get_response():
|
|||
# reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
|
||||
# )
|
||||
def test_vertex_ai_anthropic():
|
||||
load_vertex_ai_credentials()
|
||||
|
||||
model = "claude-3-sonnet@20240229"
|
||||
|
||||
vertex_ai_project = "adroit-crow-413218"
|
||||
|
@ -179,12 +177,14 @@ def test_vertex_ai_anthropic_streaming():
|
|||
# )
|
||||
@pytest.mark.asyncio
|
||||
async def test_vertex_ai_anthropic_async():
|
||||
load_vertex_ai_credentials()
|
||||
# load_vertex_ai_credentials()
|
||||
|
||||
model = "claude-3-sonnet@20240229"
|
||||
|
||||
vertex_ai_project = "adroit-crow-413218"
|
||||
vertex_ai_location = "asia-southeast1"
|
||||
json_obj = get_vertex_ai_creds_json()
|
||||
vertex_credentials = json.dumps(json_obj)
|
||||
|
||||
response = await acompletion(
|
||||
model="vertex_ai/" + model,
|
||||
|
@ -192,6 +192,7 @@ async def test_vertex_ai_anthropic_async():
|
|||
temperature=0.7,
|
||||
vertex_ai_project=vertex_ai_project,
|
||||
vertex_ai_location=vertex_ai_location,
|
||||
vertex_credentials=vertex_credentials,
|
||||
)
|
||||
print(f"Model Response: {response}")
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue