fix(vertex_ai_anthropic.py): support streaming, async completion, async streaming for vertex ai anthropic

Krrish Dholakia 2024-04-05 09:27:48 -07:00
parent eb34306099
commit f0c4ff6e60
8 changed files with 373 additions and 14 deletions
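For context, the three new call paths added by this commit are exercised roughly as follows (a minimal usage sketch; the model, project, and location values simply mirror the tests added further down and are not defaults):

import asyncio
from litellm import completion, acompletion

model = "vertex_ai/claude-3-sonnet@20240229"
common = dict(
    messages=[{"role": "user", "content": "hi"}],
    vertex_ai_project="adroit-crow-413218",  # values taken from the tests in this commit
    vertex_ai_location="asia-southeast1",
)

# sync streaming (new stream=True branch in vertex_ai_anthropic.completion)
for chunk in completion(model=model, stream=True, **common):
    print(chunk)

async def main():
    # async completion (routed to the new async_completion helper)
    print(await acompletion(model=model, **common))
    # async streaming (routed to the new async_streaming helper, wrapped in CustomStreamWrapper)
    stream = await acompletion(model=model, stream=True, **common)
    async for chunk in stream:
        print(chunk)

asyncio.run(main())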

View file

@@ -269,6 +269,7 @@ vertex_code_chat_models: List = []
vertex_text_models: List = []
vertex_code_text_models: List = []
vertex_embedding_models: List = []
vertex_anthropic_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
@@ -302,6 +303,9 @@ for key, value in model_cost.items():
vertex_code_chat_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-embedding-models":
vertex_embedding_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
key = key.replace("vertex_ai/", "")
vertex_anthropic_models.append(key)
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud":
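For reference, here is roughly what this new branch consumes and produces, using the two model-cost entries updated later in this commit (an illustrative sketch, not the real loading code):

model_cost = {
    "vertex_ai/claude-3-sonnet@20240229": {"litellm_provider": "vertex_ai-anthropic_models", "mode": "chat"},
    "vertex_ai/claude-3-haiku@20240307": {"litellm_provider": "vertex_ai-anthropic_models", "mode": "chat"},
}

vertex_anthropic_models = []
for key, value in model_cost.items():
    if value.get("litellm_provider") == "vertex_ai-anthropic_models":
        # strip the "vertex_ai/" prefix so the bare model id can be matched later
        vertex_anthropic_models.append(key.replace("vertex_ai/", ""))

print(vertex_anthropic_models)  # ['claude-3-sonnet@20240229', 'claude-3-haiku@20240307']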

View file

@@ -55,7 +55,9 @@ class VertexAIAnthropicConfig:
Note: Please make sure to modify the default parameters as required for your use case.
"""
-max_tokens: Optional[int] = litellm.max_tokens
max_tokens: Optional[int] = (
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
)
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
@ -69,6 +71,8 @@ class VertexAIAnthropicConfig:
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@@ -158,8 +162,6 @@ def completion(
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
-import google.auth # type: ignore
-from google.auth.transport.requests import Request
from anthropic import AnthropicVertex
## Load Config
@@ -224,6 +226,58 @@ def completion(
else:
vertex_ai_client = client
if acompletion == True:
"""
- async streaming
- async completion
"""
if stream is not None and stream == True:
return async_streaming(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
)
else:
return async_completion(
model=model,
messages=messages,
data=data,
print_verbose=print_verbose,
model_response=model_response,
logging_obj=logging_obj,
vertex_project=vertex_project,
vertex_location=vertex_location,
optional_params=optional_params,
client=client,
)
if stream is not None and stream == True:
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = vertex_ai_client.messages.create(**data, stream=True) # type: ignore
return response
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
message = vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
@ -267,3 +321,115 @@ def completion(
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))
async def async_completion(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
message = await vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
async def async_streaming(
model: str,
messages: list,
data: dict,
model_response: ModelResponse,
print_verbose: Callable,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
client=None,
):
from anthropic import AsyncAnthropicVertex
if client is None:
vertex_ai_client = AsyncAnthropicVertex(
project_id=vertex_project, region=vertex_location
)
else:
vertex_ai_client = client
## LOGGING
logging_obj.pre_call(
input=messages,
api_key=None,
additional_args={
"complete_input_dict": optional_params,
},
)
response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore
logging_obj.post_call(input=messages, api_key=None, original_response=response)
streamwrapper = CustomStreamWrapper(
completion_stream=response,
model=model,
custom_llm_provider="vertex_ai",
logging_obj=logging_obj,
)
return streamwrapper
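Both async helpers, like the sync path above them, end in vertex_ai_client.messages.create(**data). The data dict is assembled earlier in completion() and is not shown in this hunk; based on the config above it looks roughly like a standard Anthropic Messages payload (a sketch with hypothetical project/region values, not the exact construction code):

from anthropic import AsyncAnthropicVertex

async def example():
    # project_id / region here are placeholders, not defaults
    client = AsyncAnthropicVertex(project_id="my-gcp-project", region="us-east5")
    data = {
        "model": "claude-3-sonnet@20240229",
        "messages": [{"role": "user", "content": "hi"}],
        "max_tokens": 4096,  # default supplied by VertexAIAnthropicConfig above
    }
    message = await client.messages.create(**data)
    print(message.content[0].text, message.usage.input_tokens, message.usage.output_tokens)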

View file

@@ -1007,7 +1007,7 @@
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
-"litellm_provider": "vertex_ai",
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
},
"vertex_ai/claude-3-haiku@20240307": {
@@ -1015,7 +1015,7 @@
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
-"litellm_provider": "vertex_ai",
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
},
"textembedding-gecko": {

View file

@@ -12,6 +12,7 @@ import pytest, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout, acompletion
from litellm import RateLimitError
from litellm.tests.test_streaming import streaming_format_tests
import json
import os
import tempfile
@@ -102,6 +103,90 @@ def test_vertex_ai_anthropic():
print("\nModel Response", response)
@pytest.mark.skip(
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
)
def test_vertex_ai_anthropic_streaming():
load_vertex_ai_credentials()
# litellm.set_verbose = True
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "asia-southeast1"
response = completion(
model="vertex_ai/" + model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
stream=True,
)
# print("\nModel Response", response)
for chunk in response:
print(f"chunk: {chunk}")
# raise Exception("it worked!")
# test_vertex_ai_anthropic_streaming()
@pytest.mark.skip(
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
)
@pytest.mark.asyncio
async def test_vertex_ai_anthropic_async():
load_vertex_ai_credentials()
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "asia-southeast1"
response = await acompletion(
model="vertex_ai/" + model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
)
print(f"Model Response: {response}")
# asyncio.run(test_vertex_ai_anthropic_async())
@pytest.mark.skip(
reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
)
@pytest.mark.asyncio
async def test_vertex_ai_anthropic_async_streaming():
load_vertex_ai_credentials()
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "asia-southeast1"
response = await acompletion(
model="vertex_ai/" + model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
stream=True,
)
async for chunk in response:
print(f"chunk: {chunk}")
# asyncio.run(test_vertex_ai_anthropic_async_streaming())
def test_vertex_ai():
import random

View file

@@ -1,13 +1,13 @@
{
"type": "service_account",
-"project_id": "adroit-crow-413218",
"project_id": "reliablekeys",
"private_key_id": "",
"private_key": "",
-"client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
"client_email": "73470430121-compute@developer.gserviceaccount.com",
-"client_id": "104886546564708740969",
"client_id": "108560959659377334173",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
-"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com",
"universe_domain": "googleapis.com"
}

View file

@@ -4849,6 +4849,17 @@ def get_optional_params(
print_verbose(
f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
)
elif (
custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
):
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.VertexAIAnthropicConfig().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@ -5229,7 +5240,9 @@ def get_optional_params(
extra_body # openai client supports `extra_body` param
)
else: # assume passing in params for openai/azure openai
-print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE")
print_verbose(
f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
)
supported_params = get_supported_openai_params(
model=model, custom_llm_provider="openai"
)
@ -8710,6 +8723,58 @@ class CustomStreamWrapper:
"finish_reason": finish_reason, "finish_reason": finish_reason,
} }
def handle_vertexai_anthropic_chunk(self, chunk):
"""
- MessageStartEvent(message=Message(id='msg_01LeRRgvX4gwkX3ryBVgtuYZ', content=[], model='claude-3-sonnet-20240229', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start'); custom_llm_provider: vertex_ai
- ContentBlockStartEvent(content_block=ContentBlock(text='', type='text'), index=0, type='content_block_start'); custom_llm_provider: vertex_ai
- ContentBlockDeltaEvent(delta=TextDelta(text='Hello', type='text_delta'), index=0, type='content_block_delta'); custom_llm_provider: vertex_ai
"""
text = ""
prompt_tokens = None
completion_tokens = None
is_finished = False
finish_reason = None
type_chunk = getattr(chunk, "type", None)
if type_chunk == "message_start":
message = getattr(chunk, "message", None)
text = "" # lets us return a chunk with usage to user
_usage = getattr(message, "usage", None)
if _usage is not None:
prompt_tokens = getattr(_usage, "input_tokens", None)
completion_tokens = getattr(_usage, "output_tokens", None)
elif type_chunk == "content_block_delta":
"""
Anthropic content chunk
chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
"""
delta = getattr(chunk, "delta", None)
if delta is not None:
text = getattr(delta, "text", "")
else:
text = ""
elif type_chunk == "message_delta":
"""
Anthropic
chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}}
"""
# TODO - get usage from this chunk, set in response
delta = getattr(chunk, "delta", None)
if delta is not None:
finish_reason = getattr(delta, "stop_reason", "stop")
is_finished = True
_usage = getattr(chunk, "usage", None)
if _usage is not None:
prompt_tokens = getattr(_usage, "input_tokens", None)
completion_tokens = getattr(_usage, "output_tokens", None)
return {
"text": text,
"is_finished": is_finished,
"finish_reason": finish_reason,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
}
def handle_together_ai_chunk(self, chunk):
chunk = chunk.decode("utf-8")
text = ""
@ -9377,7 +9442,33 @@ class CustomStreamWrapper:
else:
completion_obj["content"] = str(chunk)
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
-if hasattr(chunk, "candidates") == True:
if self.model.startswith("claude-3"):
response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk)
if response_obj is None:
return
completion_obj["content"] = response_obj["text"]
if response_obj.get("prompt_tokens", None) is not None:
model_response.usage.prompt_tokens = response_obj[
"prompt_tokens"
]
if response_obj.get("completion_tokens", None) is not None:
model_response.usage.completion_tokens = response_obj[
"completion_tokens"
]
if hasattr(model_response.usage, "prompt_tokens"):
model_response.usage.total_tokens = (
getattr(model_response.usage, "total_tokens", 0)
+ model_response.usage.prompt_tokens
)
if hasattr(model_response.usage, "completion_tokens"):
model_response.usage.total_tokens = (
getattr(model_response.usage, "total_tokens", 0)
+ model_response.usage.completion_tokens
)
if response_obj["is_finished"]:
self.received_finish_reason = response_obj["finish_reason"]
elif hasattr(chunk, "candidates") == True:
try:
try:
completion_obj["content"] = chunk.text
@ -9629,6 +9720,18 @@ class CustomStreamWrapper:
print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}") print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
## RETURN ARG ## RETURN ARG
if ( if (
"content" in completion_obj
and isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) == 0
and hasattr(model_response.usage, "prompt_tokens")
):
if self.sent_first_chunk == False:
completion_obj["role"] = "assistant"
self.sent_first_chunk = True
model_response.choices[0].delta = Delta(**completion_obj)
print_verbose(f"returning model_response: {model_response}")
return model_response
elif (
"content" in completion_obj "content" in completion_obj
and isinstance(completion_obj["content"], str) and isinstance(completion_obj["content"], str)
and len(completion_obj["content"]) > 0 and len(completion_obj["content"]) > 0
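To make the new handle_vertexai_anthropic_chunk mapping concrete, here is how the three event types from its docstring translate into the dict it returns (stand-in objects shaped like the anthropic SDK events, not the real classes):

from types import SimpleNamespace

message_start = SimpleNamespace(
    type="message_start",
    message=SimpleNamespace(usage=SimpleNamespace(input_tokens=8, output_tokens=1)),
)
content_delta = SimpleNamespace(
    type="content_block_delta",
    delta=SimpleNamespace(type="text_delta", text="Hello"),
)
message_delta = SimpleNamespace(
    type="message_delta",
    delta=SimpleNamespace(stop_reason="max_tokens", stop_sequence=None),
    usage=SimpleNamespace(output_tokens=10),
)

# Expected results, per the handler above:
# message_start -> {"text": "", "is_finished": False, "finish_reason": None, "prompt_tokens": 8, "completion_tokens": 1}
# content_delta -> {"text": "Hello", "is_finished": False, "finish_reason": None, "prompt_tokens": None, "completion_tokens": None}
# message_delta -> {"text": "", "is_finished": True, "finish_reason": "max_tokens", "prompt_tokens": None, "completion_tokens": 10}

The empty-text message_start chunk is also why the RETURN ARG block above gains a branch for zero-length content when prompt_tokens is set on usage: it lets that usage-only chunk be returned to the caller instead of being dropped.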

View file

@@ -1007,7 +1007,7 @@
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
-"litellm_provider": "vertex_ai",
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
},
"vertex_ai/claude-3-haiku@20240307": {
@@ -1015,7 +1015,7 @@
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
-"litellm_provider": "vertex_ai",
"litellm_provider": "vertex_ai-anthropic_models",
"mode": "chat"
},
"textembedding-gecko": {

View file

@@ -15,6 +15,7 @@ prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
pynacl==1.5.0 # for encrypting keys
google-cloud-aiplatform==1.43.0 # for vertex ai calls
anthropic[vertex]==0.21.3
google-generativeai==0.3.2 # for vertex ai calls
async_generator==1.10.0 # for async ollama calls
langfuse>=2.6.3 # for langfuse self-hosted logging
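Note: the new vertex handler imports AnthropicVertex / AsyncAnthropicVertex lazily at call time, so a local development install needs the same extra, e.g. pip install "anthropic[vertex]==0.21.3".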