forked from phoenix/litellm-mirror

fix(vertex_ai_anthropic.py): support streaming, async completion, async streaming for vertex ai anthropic

parent eb34306099
commit f0c4ff6e60

8 changed files with 373 additions and 14 deletions
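
The added async and streaming paths are exercised the same way the new tests below do it; a minimal usage sketch, not part of the diff (the project id and location are placeholders, and Google application-default credentials are assumed):

import asyncio
import litellm

async def main():
    # Streamed async call through the new vertex_ai anthropic path;
    # project/location values here are placeholders.
    response = await litellm.acompletion(
        model="vertex_ai/claude-3-sonnet@20240229",
        messages=[{"role": "user", "content": "hi"}],
        vertex_ai_project="my-gcp-project",
        vertex_ai_location="asia-southeast1",
        stream=True,
    )
    async for chunk in response:
        print(chunk)

asyncio.run(main())
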
@@ -269,6 +269,7 @@ vertex_code_chat_models: List = []
vertex_text_models: List = []
vertex_code_text_models: List = []
vertex_embedding_models: List = []
+vertex_anthropic_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
@@ -302,6 +303,9 @@ for key, value in model_cost.items():
        vertex_code_chat_models.append(key)
    elif value.get("litellm_provider") == "vertex_ai-embedding-models":
        vertex_embedding_models.append(key)
+   elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
+       key = key.replace("vertex_ai/", "")
+       vertex_anthropic_models.append(key)
    elif value.get("litellm_provider") == "ai21":
        ai21_models.append(key)
    elif value.get("litellm_provider") == "nlp_cloud":
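
With the branch above, Vertex AI Claude entries from the pricing file are registered without their "vertex_ai/" prefix. Based on the two entries re-tagged later in this diff, the resulting list would look roughly like this (a sketch, not part of the diff):

import litellm

# Expected contents, per the two model-price entries changed in this commit
# (order follows the pricing file).
print(litellm.vertex_anthropic_models)
# ['claude-3-sonnet@20240229', 'claude-3-haiku@20240307']
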
@@ -55,7 +55,9 @@ class VertexAIAnthropicConfig:
    Note: Please make sure to modify the default parameters as required for your use case.
    """

-   max_tokens: Optional[int] = litellm.max_tokens
+   max_tokens: Optional[int] = (
+       4096  # anthropic max - setting this doesn't impact response, but is required by anthropic.
+   )
    system: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
@@ -69,6 +71,8 @@ class VertexAIAnthropicConfig:
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
+           if key == "max_tokens" and value is None:
+               value = self.max_tokens
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)
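
A small sketch, not part of the diff, of how the new default behaves (constructor arguments shown are illustrative; the config class is assumed to take all-optional keyword args like litellm's other config classes): leaving max_tokens unset keeps the 4096 class default, while a passed value is written onto the class via setattr.

import litellm

cfg = litellm.VertexAIAnthropicConfig()
print(cfg.max_tokens)  # 4096 - the anthropic-required default above

litellm.VertexAIAnthropicConfig(max_tokens=512)
print(litellm.VertexAIAnthropicConfig.max_tokens)  # 512 - set on the class by __init__
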
@@ -158,8 +162,6 @@ def completion(
            message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
        )
    try:
        import google.auth  # type: ignore
        from google.auth.transport.requests import Request
        from anthropic import AnthropicVertex

        ## Load Config
@@ -224,6 +226,58 @@ def completion(
        else:
            vertex_ai_client = client

        if acompletion == True:
            """
            - async streaming
            - async completion
            """
            if stream is not None and stream == True:
                return async_streaming(
                    model=model,
                    messages=messages,
                    data=data,
                    print_verbose=print_verbose,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    vertex_project=vertex_project,
                    vertex_location=vertex_location,
                    optional_params=optional_params,
                    client=client,
                )
            else:
                return async_completion(
                    model=model,
                    messages=messages,
                    data=data,
                    print_verbose=print_verbose,
                    model_response=model_response,
                    logging_obj=logging_obj,
                    vertex_project=vertex_project,
                    vertex_location=vertex_location,
                    optional_params=optional_params,
                    client=client,
                )
        if stream is not None and stream == True:
            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=None,
                additional_args={
                    "complete_input_dict": optional_params,
                },
            )
            response = vertex_ai_client.messages.create(**data, stream=True)  # type: ignore
            return response

        ## LOGGING
        logging_obj.pre_call(
            input=messages,
            api_key=None,
            additional_args={
                "complete_input_dict": optional_params,
            },
        )

        message = vertex_ai_client.messages.create(**data)  # type: ignore
        text_content = message.content[0].text
        ## TOOL CALLING - OUTPUT PARSE
@@ -267,3 +321,115 @@ def completion(
        return model_response
    except Exception as e:
        raise VertexAIError(status_code=500, message=str(e))


async def async_completion(
    model: str,
    messages: list,
    data: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    vertex_project=None,
    vertex_location=None,
    optional_params=None,
    client=None,
):
    from anthropic import AsyncAnthropicVertex

    if client is None:
        vertex_ai_client = AsyncAnthropicVertex(
            project_id=vertex_project, region=vertex_location
        )
    else:
        vertex_ai_client = client

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key=None,
        additional_args={
            "complete_input_dict": optional_params,
        },
    )
    message = await vertex_ai_client.messages.create(**data)  # type: ignore
    text_content = message.content[0].text
    ## TOOL CALLING - OUTPUT PARSE
    if text_content is not None and contains_tag("invoke", text_content):
        function_name = extract_between_tags("tool_name", text_content)[0]
        function_arguments_str = extract_between_tags("invoke", text_content)[0].strip()
        function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
        function_arguments = parse_xml_params(function_arguments_str)
        _message = litellm.Message(
            tool_calls=[
                {
                    "id": f"call_{uuid.uuid4()}",
                    "type": "function",
                    "function": {
                        "name": function_name,
                        "arguments": json.dumps(function_arguments),
                    },
                }
            ],
            content=None,
        )
        model_response.choices[0].message = _message  # type: ignore
    else:
        model_response.choices[0].message.content = text_content  # type: ignore
    model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)

    ## CALCULATING USAGE
    prompt_tokens = message.usage.input_tokens
    completion_tokens = message.usage.output_tokens

    model_response["created"] = int(time.time())
    model_response["model"] = model
    usage = Usage(
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=prompt_tokens + completion_tokens,
    )
    model_response.usage = usage
    return model_response


async def async_streaming(
    model: str,
    messages: list,
    data: dict,
    model_response: ModelResponse,
    print_verbose: Callable,
    logging_obj,
    vertex_project=None,
    vertex_location=None,
    optional_params=None,
    client=None,
):
    from anthropic import AsyncAnthropicVertex

    if client is None:
        vertex_ai_client = AsyncAnthropicVertex(
            project_id=vertex_project, region=vertex_location
        )
    else:
        vertex_ai_client = client

    ## LOGGING
    logging_obj.pre_call(
        input=messages,
        api_key=None,
        additional_args={
            "complete_input_dict": optional_params,
        },
    )
    response = await vertex_ai_client.messages.create(**data, stream=True)  # type: ignore
    logging_obj.post_call(input=messages, api_key=None, original_response=response)

    streamwrapper = CustomStreamWrapper(
        completion_stream=response,
        model=model,
        custom_llm_provider="vertex_ai",
        logging_obj=logging_obj,
    )

    return streamwrapper
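
For reference, a sketch (not part of the diff) of the Anthropic SDK call these helpers wrap, using the `anthropic[vertex]` extra pinned below in requirements.txt; project, region, and model values are placeholders:

import asyncio
from anthropic import AsyncAnthropicVertex

async def demo():
    client = AsyncAnthropicVertex(project_id="my-gcp-project", region="us-central1")
    stream = await client.messages.create(
        model="claude-3-sonnet@20240229",
        max_tokens=256,
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    async for event in stream:
        # these are the message_start / content_block_delta / message_delta
        # events parsed by CustomStreamWrapper further down in this diff
        print(getattr(event, "type", event))

asyncio.run(demo())
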
@@ -1007,7 +1007,7 @@
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000015,
-       "litellm_provider": "vertex_ai",
+       "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat"
    },
    "vertex_ai/claude-3-haiku@20240307": {
@@ -1015,7 +1015,7 @@
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.00000025,
        "output_cost_per_token": 0.00000125,
-       "litellm_provider": "vertex_ai",
+       "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat"
    },
    "textembedding-gecko": {
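
With the provider re-tagged, cost tracking keeps using the per-token prices above; the arithmetic for a claude-3-sonnet call is simply (illustrative token counts, litellm.completion_cost performs this lookup itself):

prompt_tokens, completion_tokens = 8, 120
cost = prompt_tokens * 0.000003 + completion_tokens * 0.000015
print(f"${cost:.6f}")  # $0.001824
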
@@ -12,6 +12,7 @@ import pytest, asyncio
import litellm
from litellm import embedding, completion, completion_cost, Timeout, acompletion
from litellm import RateLimitError
from litellm.tests.test_streaming import streaming_format_tests
import json
import os
import tempfile
@@ -102,6 +103,90 @@ def test_vertex_ai_anthropic():
    print("\nModel Response", response)


@pytest.mark.skip(
    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
)
def test_vertex_ai_anthropic_streaming():
    load_vertex_ai_credentials()

    # litellm.set_verbose = True

    model = "claude-3-sonnet@20240229"

    vertex_ai_project = "adroit-crow-413218"
    vertex_ai_location = "asia-southeast1"

    response = completion(
        model="vertex_ai/" + model,
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.7,
        vertex_ai_project=vertex_ai_project,
        vertex_ai_location=vertex_ai_location,
        stream=True,
    )
    # print("\nModel Response", response)
    for chunk in response:
        print(f"chunk: {chunk}")

    # raise Exception("it worked!")


# test_vertex_ai_anthropic_streaming()


@pytest.mark.skip(
    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
)
@pytest.mark.asyncio
async def test_vertex_ai_anthropic_async():
    load_vertex_ai_credentials()

    model = "claude-3-sonnet@20240229"

    vertex_ai_project = "adroit-crow-413218"
    vertex_ai_location = "asia-southeast1"

    response = await acompletion(
        model="vertex_ai/" + model,
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.7,
        vertex_ai_project=vertex_ai_project,
        vertex_ai_location=vertex_ai_location,
    )
    print(f"Model Response: {response}")


# asyncio.run(test_vertex_ai_anthropic_async())


@pytest.mark.skip(
    reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd."
)
@pytest.mark.asyncio
async def test_vertex_ai_anthropic_async_streaming():
    load_vertex_ai_credentials()

    model = "claude-3-sonnet@20240229"

    vertex_ai_project = "adroit-crow-413218"
    vertex_ai_location = "asia-southeast1"

    response = await acompletion(
        model="vertex_ai/" + model,
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.7,
        vertex_ai_project=vertex_ai_project,
        vertex_ai_location=vertex_ai_location,
        stream=True,
    )

    async for chunk in response:
        print(f"chunk: {chunk}")


# asyncio.run(test_vertex_ai_anthropic_async_streaming())


def test_vertex_ai():
    import random
@@ -1,13 +1,13 @@
{
    "type": "service_account",
-   "project_id": "adroit-crow-413218",
+   "project_id": "reliablekeys",
    "private_key_id": "",
    "private_key": "",
-   "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
-   "client_id": "104886546564708740969",
+   "client_email": "73470430121-compute@developer.gserviceaccount.com",
+   "client_id": "108560959659377334173",
    "auth_uri": "https://accounts.google.com/o/oauth2/auth",
    "token_uri": "https://oauth2.googleapis.com/token",
    "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
-   "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
+   "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com",
    "universe_domain": "googleapis.com"
}
litellm/utils.py
@@ -4849,6 +4849,17 @@ def get_optional_params(
        print_verbose(
            f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}"
        )
    elif (
        custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models
    ):
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider=custom_llm_provider
        )
        _check_valid_arg(supported_params=supported_params)
        optional_params = litellm.VertexAIAnthropicConfig().map_openai_params(
            non_default_params=non_default_params,
            optional_params=optional_params,
        )
    elif custom_llm_provider == "sagemaker":
        ## check if unsupported param passed in
        supported_params = get_supported_openai_params(
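
A minimal sketch, not part of the diff, of the mapping call added above (the parameter values are illustrative; the allowed set comes from get_supported_openai_params):

import litellm

mapped = litellm.VertexAIAnthropicConfig().map_openai_params(
    non_default_params={"max_tokens": 256, "temperature": 0.7, "stream": True},
    optional_params={},
)
print(mapped)  # expected to carry the OpenAI-style params over to the anthropic request
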
@@ -5229,7 +5240,9 @@ def get_optional_params(
            extra_body  # openai client supports `extra_body` param
        )
    else:  # assume passing in params for openai/azure openai
-       print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE")
+       print_verbose(
+           f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}"
+       )
        supported_params = get_supported_openai_params(
            model=model, custom_llm_provider="openai"
        )
@@ -8710,6 +8723,58 @@ class CustomStreamWrapper:
            "finish_reason": finish_reason,
        }

    def handle_vertexai_anthropic_chunk(self, chunk):
        """
        - MessageStartEvent(message=Message(id='msg_01LeRRgvX4gwkX3ryBVgtuYZ', content=[], model='claude-3-sonnet-20240229', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start'); custom_llm_provider: vertex_ai
        - ContentBlockStartEvent(content_block=ContentBlock(text='', type='text'), index=0, type='content_block_start'); custom_llm_provider: vertex_ai
        - ContentBlockDeltaEvent(delta=TextDelta(text='Hello', type='text_delta'), index=0, type='content_block_delta'); custom_llm_provider: vertex_ai
        """
        text = ""
        prompt_tokens = None
        completion_tokens = None
        is_finished = False
        finish_reason = None
        type_chunk = getattr(chunk, "type", None)
        if type_chunk == "message_start":
            message = getattr(chunk, "message", None)
            text = ""  # lets us return a chunk with usage to user
            _usage = getattr(message, "usage", None)
            if _usage is not None:
                prompt_tokens = getattr(_usage, "input_tokens", None)
                completion_tokens = getattr(_usage, "output_tokens", None)
        elif type_chunk == "content_block_delta":
            """
            Anthropic content chunk
            chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}}
            """
            delta = getattr(chunk, "delta", None)
            if delta is not None:
                text = getattr(delta, "text", "")
            else:
                text = ""
        elif type_chunk == "message_delta":
            """
            Anthropic
            chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}}
            """
            # TODO - get usage from this chunk, set in response
            delta = getattr(chunk, "delta", None)
            if delta is not None:
                finish_reason = getattr(delta, "stop_reason", "stop")
                is_finished = True
            _usage = getattr(chunk, "usage", None)
            if _usage is not None:
                prompt_tokens = getattr(_usage, "input_tokens", None)
                completion_tokens = getattr(_usage, "output_tokens", None)

        return {
            "text": text,
            "is_finished": is_finished,
            "finish_reason": finish_reason,
            "prompt_tokens": prompt_tokens,
            "completion_tokens": completion_tokens,
        }

    def handle_together_ai_chunk(self, chunk):
        chunk = chunk.decode("utf-8")
        text = ""
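
Because the handler above only inspects the chunk object, it can be exercised in isolation; a rough sketch (not part of the diff) using stand-in event objects, where SimpleNamespace mimics the Anthropic event types listed in the docstring:

from types import SimpleNamespace
from litellm.utils import CustomStreamWrapper

chunk = SimpleNamespace(
    type="content_block_delta",
    index=0,
    delta=SimpleNamespace(type="text_delta", text="Hello"),
)
# self is unused by the handler, so it can be called unbound for a quick check
parsed = CustomStreamWrapper.handle_vertexai_anthropic_chunk(None, chunk)
print(parsed)
# {'text': 'Hello', 'is_finished': False, 'finish_reason': None,
#  'prompt_tokens': None, 'completion_tokens': None}
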
@@ -9377,7 +9442,33 @@ class CustomStreamWrapper:
            else:
                completion_obj["content"] = str(chunk)
        elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
-           if hasattr(chunk, "candidates") == True:
            if self.model.startswith("claude-3"):
                response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk)
                if response_obj is None:
                    return
                completion_obj["content"] = response_obj["text"]
                if response_obj.get("prompt_tokens", None) is not None:
                    model_response.usage.prompt_tokens = response_obj[
                        "prompt_tokens"
                    ]
                if response_obj.get("completion_tokens", None) is not None:
                    model_response.usage.completion_tokens = response_obj[
                        "completion_tokens"
                    ]
                if hasattr(model_response.usage, "prompt_tokens"):
                    model_response.usage.total_tokens = (
                        getattr(model_response.usage, "total_tokens", 0)
                        + model_response.usage.prompt_tokens
                    )
                if hasattr(model_response.usage, "completion_tokens"):
                    model_response.usage.total_tokens = (
                        getattr(model_response.usage, "total_tokens", 0)
                        + model_response.usage.completion_tokens
                    )

                if response_obj["is_finished"]:
                    self.received_finish_reason = response_obj["finish_reason"]
            elif hasattr(chunk, "candidates") == True:
                try:
                    try:
                        completion_obj["content"] = chunk.text
@@ -9629,6 +9720,18 @@ class CustomStreamWrapper:
            print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}")
            ## RETURN ARG
            if (
                "content" in completion_obj
                and isinstance(completion_obj["content"], str)
                and len(completion_obj["content"]) == 0
                and hasattr(model_response.usage, "prompt_tokens")
            ):
                if self.sent_first_chunk == False:
                    completion_obj["role"] = "assistant"
                    self.sent_first_chunk = True
                model_response.choices[0].delta = Delta(**completion_obj)
                print_verbose(f"returning model_response: {model_response}")
                return model_response
            elif (
                "content" in completion_obj
                and isinstance(completion_obj["content"], str)
                and len(completion_obj["content"]) > 0
@@ -1007,7 +1007,7 @@
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.000003,
        "output_cost_per_token": 0.000015,
-       "litellm_provider": "vertex_ai",
+       "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat"
    },
    "vertex_ai/claude-3-haiku@20240307": {
@@ -1015,7 +1015,7 @@
        "max_output_tokens": 4096,
        "input_cost_per_token": 0.00000025,
        "output_cost_per_token": 0.00000125,
-       "litellm_provider": "vertex_ai",
+       "litellm_provider": "vertex_ai-anthropic_models",
        "mode": "chat"
    },
    "textembedding-gecko": {
@@ -15,6 +15,7 @@ prisma==0.11.0 # for db
mangum==0.17.0 # for aws lambda functions
pynacl==1.5.0 # for encrypting keys
google-cloud-aiplatform==1.43.0 # for vertex ai calls
+anthropic[vertex]==0.21.3
google-generativeai==0.3.2 # for vertex ai calls
async_generator==1.10.0 # for async ollama calls
langfuse>=2.6.3 # for langfuse self-hosted logging