From f0c4ff6e605ebbd27d0a90a397e3e987d6d2c35f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 5 Apr 2024 09:27:48 -0700 Subject: [PATCH] fix(vertex_ai_anthropic.py): support streaming, async completion, async streaming for vertex ai anthropic --- litellm/__init__.py | 4 + litellm/llms/vertex_ai_anthropic.py | 172 +++++++++++++++++- ...odel_prices_and_context_window_backup.json | 4 +- .../tests/test_amazing_vertex_completion.py | 85 +++++++++ litellm/tests/vertex_key.json | 10 +- litellm/utils.py | 107 ++++++++++- model_prices_and_context_window.json | 4 +- requirements.txt | 1 + 8 files changed, 373 insertions(+), 14 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index d26826830..476f57e10 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -269,6 +269,7 @@ vertex_code_chat_models: List = [] vertex_text_models: List = [] vertex_code_text_models: List = [] vertex_embedding_models: List = [] +vertex_anthropic_models: List = [] ai21_models: List = [] nlp_cloud_models: List = [] aleph_alpha_models: List = [] @@ -302,6 +303,9 @@ for key, value in model_cost.items(): vertex_code_chat_models.append(key) elif value.get("litellm_provider") == "vertex_ai-embedding-models": vertex_embedding_models.append(key) + elif value.get("litellm_provider") == "vertex_ai-anthropic_models": + key = key.replace("vertex_ai/", "") + vertex_anthropic_models.append(key) elif value.get("litellm_provider") == "ai21": ai21_models.append(key) elif value.get("litellm_provider") == "nlp_cloud": diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index e1ab527b7..adedb647a 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_anthropic.py @@ -55,7 +55,9 @@ class VertexAIAnthropicConfig: Note: Please make sure to modify the default parameters as required for your use case. """ - max_tokens: Optional[int] = litellm.max_tokens + max_tokens: Optional[int] = ( + 4096 # anthropic max - setting this doesn't impact response, but is required by anthropic. + ) system: Optional[str] = None temperature: Optional[float] = None top_p: Optional[float] = None @@ -69,6 +71,8 @@ class VertexAIAnthropicConfig: ) -> None: locals_ = locals() for key, value in locals_.items(): + if key == "max_tokens" and value is None: + value = self.max_tokens if key != "self" and value is not None: setattr(self.__class__, key, value) @@ -158,8 +162,6 @@ def completion( message="""Upgrade vertex ai. 
Run `pip install "google-cloud-aiplatform>=1.38"`""", ) try: - import google.auth # type: ignore - from google.auth.transport.requests import Request from anthropic import AnthropicVertex ## Load Config @@ -224,6 +226,58 @@ def completion( else: vertex_ai_client = client + if acompletion == True: + """ + - async streaming + - async completion + """ + if stream is not None and stream == True: + return async_streaming( + model=model, + messages=messages, + data=data, + print_verbose=print_verbose, + model_response=model_response, + logging_obj=logging_obj, + vertex_project=vertex_project, + vertex_location=vertex_location, + optional_params=optional_params, + client=client, + ) + else: + return async_completion( + model=model, + messages=messages, + data=data, + print_verbose=print_verbose, + model_response=model_response, + logging_obj=logging_obj, + vertex_project=vertex_project, + vertex_location=vertex_location, + optional_params=optional_params, + client=client, + ) + if stream is not None and stream == True: + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + }, + ) + response = vertex_ai_client.messages.create(**data, stream=True) # type: ignore + return response + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + }, + ) + message = vertex_ai_client.messages.create(**data) # type: ignore text_content = message.content[0].text ## TOOL CALLING - OUTPUT PARSE @@ -267,3 +321,115 @@ def completion( return model_response except Exception as e: raise VertexAIError(status_code=500, message=str(e)) + + +async def async_completion( + model: str, + messages: list, + data: dict, + model_response: ModelResponse, + print_verbose: Callable, + logging_obj, + vertex_project=None, + vertex_location=None, + optional_params=None, + client=None, +): + from anthropic import AsyncAnthropicVertex + + if client is None: + vertex_ai_client = AsyncAnthropicVertex( + project_id=vertex_project, region=vertex_location + ) + else: + vertex_ai_client = client + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + }, + ) + message = await vertex_ai_client.messages.create(**data) # type: ignore + text_content = message.content[0].text + ## TOOL CALLING - OUTPUT PARSE + if text_content is not None and contains_tag("invoke", text_content): + function_name = extract_between_tags("tool_name", text_content)[0] + function_arguments_str = extract_between_tags("invoke", text_content)[0].strip() + function_arguments_str = f"{function_arguments_str}" + function_arguments = parse_xml_params(function_arguments_str) + _message = litellm.Message( + tool_calls=[ + { + "id": f"call_{uuid.uuid4()}", + "type": "function", + "function": { + "name": function_name, + "arguments": json.dumps(function_arguments), + }, + } + ], + content=None, + ) + model_response.choices[0].message = _message # type: ignore + else: + model_response.choices[0].message.content = text_content # type: ignore + model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason) + + ## CALCULATING USAGE + prompt_tokens = message.usage.input_tokens + completion_tokens = message.usage.output_tokens + + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + 
completion_tokens, + ) + model_response.usage = usage + return model_response + + +async def async_streaming( + model: str, + messages: list, + data: dict, + model_response: ModelResponse, + print_verbose: Callable, + logging_obj, + vertex_project=None, + vertex_location=None, + optional_params=None, + client=None, +): + from anthropic import AsyncAnthropicVertex + + if client is None: + vertex_ai_client = AsyncAnthropicVertex( + project_id=vertex_project, region=vertex_location + ) + else: + vertex_ai_client = client + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key=None, + additional_args={ + "complete_input_dict": optional_params, + }, + ) + response = await vertex_ai_client.messages.create(**data, stream=True) # type: ignore + logging_obj.post_call(input=messages, api_key=None, original_response=response) + + streamwrapper = CustomStreamWrapper( + completion_stream=response, + model=model, + custom_llm_provider="vertex_ai", + logging_obj=logging_obj, + ) + + return streamwrapper diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 356da8a46..e5e95e8d9 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -1007,7 +1007,7 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000003, "output_cost_per_token": 0.000015, - "litellm_provider": "vertex_ai", + "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat" }, "vertex_ai/claude-3-haiku@20240307": { @@ -1015,7 +1015,7 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.00000025, "output_cost_per_token": 0.00000125, - "litellm_provider": "vertex_ai", + "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat" }, "textembedding-gecko": { diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 5a954fe39..4eec5c35d 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -12,6 +12,7 @@ import pytest, asyncio import litellm from litellm import embedding, completion, completion_cost, Timeout, acompletion from litellm import RateLimitError +from litellm.tests.test_streaming import streaming_format_tests import json import os import tempfile @@ -102,6 +103,90 @@ def test_vertex_ai_anthropic(): print("\nModel Response", response) +@pytest.mark.skip( + reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd." +) +def test_vertex_ai_anthropic_streaming(): + load_vertex_ai_credentials() + + # litellm.set_verbose = True + + model = "claude-3-sonnet@20240229" + + vertex_ai_project = "adroit-crow-413218" + vertex_ai_location = "asia-southeast1" + + response = completion( + model="vertex_ai/" + model, + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + vertex_ai_project=vertex_ai_project, + vertex_ai_location=vertex_ai_location, + stream=True, + ) + # print("\nModel Response", response) + for chunk in response: + print(f"chunk: {chunk}") + + # raise Exception("it worked!") + + +# test_vertex_ai_anthropic_streaming() + + +@pytest.mark.skip( + reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd." 
+) +@pytest.mark.asyncio +async def test_vertex_ai_anthropic_async(): + load_vertex_ai_credentials() + + model = "claude-3-sonnet@20240229" + + vertex_ai_project = "adroit-crow-413218" + vertex_ai_location = "asia-southeast1" + + response = await acompletion( + model="vertex_ai/" + model, + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + vertex_ai_project=vertex_ai_project, + vertex_ai_location=vertex_ai_location, + ) + print(f"Model Response: {response}") + + +# asyncio.run(test_vertex_ai_anthropic_async()) + + +@pytest.mark.skip( + reason="Local test. Vertex AI Quota is low. Leads to rate limit errors on ci/cd." +) +@pytest.mark.asyncio +async def test_vertex_ai_anthropic_async_streaming(): + load_vertex_ai_credentials() + + model = "claude-3-sonnet@20240229" + + vertex_ai_project = "adroit-crow-413218" + vertex_ai_location = "asia-southeast1" + + response = await acompletion( + model="vertex_ai/" + model, + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + vertex_ai_project=vertex_ai_project, + vertex_ai_location=vertex_ai_location, + stream=True, + ) + + async for chunk in response: + print(f"chunk: {chunk}") + + +# asyncio.run(test_vertex_ai_anthropic_async_streaming()) + + def test_vertex_ai(): import random diff --git a/litellm/tests/vertex_key.json b/litellm/tests/vertex_key.json index e2fd8512b..bd319ac94 100644 --- a/litellm/tests/vertex_key.json +++ b/litellm/tests/vertex_key.json @@ -1,13 +1,13 @@ { "type": "service_account", - "project_id": "adroit-crow-413218", + "project_id": "reliablekeys", "private_key_id": "", "private_key": "", - "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com", - "client_id": "104886546564708740969", + "client_email": "73470430121-compute@developer.gserviceaccount.com", + "client_id": "108560959659377334173", "auth_uri": "https://accounts.google.com/o/oauth2/auth", "token_uri": "https://oauth2.googleapis.com/token", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com", "universe_domain": "googleapis.com" -} +} \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index 17a31751d..bcd2d20c3 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4849,6 +4849,17 @@ def get_optional_params( print_verbose( f"(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {optional_params}" ) + elif ( + custom_llm_provider == "vertex_ai" and model in litellm.vertex_anthropic_models + ): + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.VertexAIAnthropicConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) elif custom_llm_provider == "sagemaker": ## check if unsupported param passed in supported_params = get_supported_openai_params( @@ -5229,7 +5240,9 @@ def get_optional_params( extra_body # openai client supports `extra_body` param ) else: # assume passing in params for openai/azure openai - print_verbose(f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE") + print_verbose( + f"UNMAPPED PROVIDER, ASSUMING IT'S OPENAI/AZURE - model={model}, custom_llm_provider={custom_llm_provider}" + ) 
supported_params = get_supported_openai_params( model=model, custom_llm_provider="openai" ) @@ -8710,6 +8723,58 @@ class CustomStreamWrapper: "finish_reason": finish_reason, } + def handle_vertexai_anthropic_chunk(self, chunk): + """ + - MessageStartEvent(message=Message(id='msg_01LeRRgvX4gwkX3ryBVgtuYZ', content=[], model='claude-3-sonnet-20240229', role='assistant', stop_reason=None, stop_sequence=None, type='message', usage=Usage(input_tokens=8, output_tokens=1)), type='message_start'); custom_llm_provider: vertex_ai + - ContentBlockStartEvent(content_block=ContentBlock(text='', type='text'), index=0, type='content_block_start'); custom_llm_provider: vertex_ai + - ContentBlockDeltaEvent(delta=TextDelta(text='Hello', type='text_delta'), index=0, type='content_block_delta'); custom_llm_provider: vertex_ai + """ + text = "" + prompt_tokens = None + completion_tokens = None + is_finished = False + finish_reason = None + type_chunk = getattr(chunk, "type", None) + if type_chunk == "message_start": + message = getattr(chunk, "message", None) + text = "" # lets us return a chunk with usage to user + _usage = getattr(message, "usage", None) + if _usage is not None: + prompt_tokens = getattr(_usage, "input_tokens", None) + completion_tokens = getattr(_usage, "output_tokens", None) + elif type_chunk == "content_block_delta": + """ + Anthropic content chunk + chunk = {'type': 'content_block_delta', 'index': 0, 'delta': {'type': 'text_delta', 'text': 'Hello'}} + """ + delta = getattr(chunk, "delta", None) + if delta is not None: + text = getattr(delta, "text", "") + else: + text = "" + elif type_chunk == "message_delta": + """ + Anthropic + chunk = {'type': 'message_delta', 'delta': {'stop_reason': 'max_tokens', 'stop_sequence': None}, 'usage': {'output_tokens': 10}} + """ + # TODO - get usage from this chunk, set in response + delta = getattr(chunk, "delta", None) + if delta is not None: + finish_reason = getattr(delta, "stop_reason", "stop") + is_finished = True + _usage = getattr(chunk, "usage", None) + if _usage is not None: + prompt_tokens = getattr(_usage, "input_tokens", None) + completion_tokens = getattr(_usage, "output_tokens", None) + + return { + "text": text, + "is_finished": is_finished, + "finish_reason": finish_reason, + "prompt_tokens": prompt_tokens, + "completion_tokens": completion_tokens, + } + def handle_together_ai_chunk(self, chunk): chunk = chunk.decode("utf-8") text = "" @@ -9377,7 +9442,33 @@ class CustomStreamWrapper: else: completion_obj["content"] = str(chunk) elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"): - if hasattr(chunk, "candidates") == True: + if self.model.startswith("claude-3"): + response_obj = self.handle_vertexai_anthropic_chunk(chunk=chunk) + if response_obj is None: + return + completion_obj["content"] = response_obj["text"] + if response_obj.get("prompt_tokens", None) is not None: + model_response.usage.prompt_tokens = response_obj[ + "prompt_tokens" + ] + if response_obj.get("completion_tokens", None) is not None: + model_response.usage.completion_tokens = response_obj[ + "completion_tokens" + ] + if hasattr(model_response.usage, "prompt_tokens"): + model_response.usage.total_tokens = ( + getattr(model_response.usage, "total_tokens", 0) + + model_response.usage.prompt_tokens + ) + if hasattr(model_response.usage, "completion_tokens"): + model_response.usage.total_tokens = ( + getattr(model_response.usage, "total_tokens", 0) + + model_response.usage.completion_tokens + ) + + if response_obj["is_finished"]: + 
self.received_finish_reason = response_obj["finish_reason"] + elif hasattr(chunk, "candidates") == True: try: try: completion_obj["content"] = chunk.text @@ -9629,6 +9720,18 @@ class CustomStreamWrapper: print_verbose(f"self.sent_first_chunk: {self.sent_first_chunk}") ## RETURN ARG if ( + "content" in completion_obj + and isinstance(completion_obj["content"], str) + and len(completion_obj["content"]) == 0 + and hasattr(model_response.usage, "prompt_tokens") + ): + if self.sent_first_chunk == False: + completion_obj["role"] = "assistant" + self.sent_first_chunk = True + model_response.choices[0].delta = Delta(**completion_obj) + print_verbose(f"returning model_response: {model_response}") + return model_response + elif ( "content" in completion_obj and isinstance(completion_obj["content"], str) and len(completion_obj["content"]) > 0 diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 356da8a46..e5e95e8d9 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -1007,7 +1007,7 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.000003, "output_cost_per_token": 0.000015, - "litellm_provider": "vertex_ai", + "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat" }, "vertex_ai/claude-3-haiku@20240307": { @@ -1015,7 +1015,7 @@ "max_output_tokens": 4096, "input_cost_per_token": 0.00000025, "output_cost_per_token": 0.00000125, - "litellm_provider": "vertex_ai", + "litellm_provider": "vertex_ai-anthropic_models", "mode": "chat" }, "textembedding-gecko": { diff --git a/requirements.txt b/requirements.txt index 42af759e0..f09dd7501 100644 --- a/requirements.txt +++ b/requirements.txt @@ -15,6 +15,7 @@ prisma==0.11.0 # for db mangum==0.17.0 # for aws lambda functions pynacl==1.5.0 # for encrypting keys google-cloud-aiplatform==1.43.0 # for vertex ai calls +anthropic[vertex]==0.21.3 google-generativeai==0.3.2 # for vertex ai calls async_generator==1.10.0 # for async ollama calls langfuse>=2.6.3 # for langfuse self-hosted logging
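
For reference, a minimal usage sketch of what this patch enables (not part of the patch itself). It mirrors the new tests in litellm/tests/test_amazing_vertex_completion.py, so the project id "adroit-crow-413218" and region "asia-southeast1" are just the placeholder values those tests use, and it assumes Vertex AI credentials are already configured in the environment (the tests load a service-account key via load_vertex_ai_credentials()). Substitute your own project and a region where Claude 3 on Vertex AI is enabled.

import asyncio
from litellm import completion, acompletion

# Sync streaming: routed through the new AnthropicVertex path in
# litellm/llms/vertex_ai_anthropic.py; chunks are normalized by
# CustomStreamWrapper.handle_vertexai_anthropic_chunk in litellm/utils.py.
response = completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.7,
    vertex_ai_project="adroit-crow-413218",   # placeholder from the tests
    vertex_ai_location="asia-southeast1",     # placeholder from the tests
    stream=True,
)
for chunk in response:
    print(chunk)

# Async completion: routed through async_completion() / AsyncAnthropicVertex.
async def main():
    resp = await acompletion(
        model="vertex_ai/claude-3-sonnet@20240229",
        messages=[{"role": "user", "content": "hi"}],
        temperature=0.7,
        vertex_ai_project="adroit-crow-413218",
        vertex_ai_location="asia-southeast1",
    )
    print(resp)

asyncio.run(main())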