From 05029fdcc77dc7e94b49436c37600a971b589fa9 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Sat, 23 Mar 2024 15:53:04 -0700 Subject: [PATCH 1/2] feat(vertex_ai_anthropic.py): Add support for claude 3 on vertex ai --- litellm/llms/custom_httpx/http_handler.py | 41 +++++++++ litellm/llms/vertex_ai_anthropic.py | 90 +++++++++++++++++++ ...odel_prices_and_context_window_backup.json | 16 ++++ model_prices_and_context_window.json | 16 ++++ 4 files changed, 163 insertions(+) create mode 100644 litellm/llms/custom_httpx/http_handler.py create mode 100644 litellm/llms/vertex_ai_anthropic.py diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py new file mode 100644 index 000000000..98fecb5d8 --- /dev/null +++ b/litellm/llms/custom_httpx/http_handler.py @@ -0,0 +1,41 @@ +import httpx, asyncio +from typing import Optional + + +class AsyncHTTPHandler: + def __init__(self, concurrent_limit=1000): + # Create a client with a connection pool + self.client = httpx.AsyncClient( + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ) + ) + + async def close(self): + # Close the client when you're done with it + await self.client.aclose() + + async def get( + self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None + ): + response = await self.client.get(url, params=params, headers=headers) + return response + + async def post( + self, + url: str, + data: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + ): + response = await self.client.post( + url, data=data, params=params, headers=headers + ) + return response + + def __del__(self) -> None: + try: + asyncio.get_running_loop().create_task(self.close()) + except Exception: + pass diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py new file mode 100644 index 000000000..4a54e087b --- /dev/null +++ b/litellm/llms/vertex_ai_anthropic.py @@ -0,0 +1,90 @@ +# What is this? +## Handler file for calling claude-3 on vertex ai +from typing import Callable, Optional, Any, Union, List +import litellm + + +class VertexAIAnthropicConfig: + """ + Reference: https://docs.anthropic.com/claude/reference/messages_post + + Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways: + + - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL. + - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16". + + The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters: + + - `max_tokens` Required (integer) max tokens, + - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31" + - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py + - `temperature` Optional (float) The amount of randomness injected into the response + - `top_p` Optional (float) Use nucleus sampling. + - `top_k` Optional (int) Only sample from the top K options for each subsequent token + - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating + + Note: Please make sure to modify the default parameters as required for your use case. 
+ """ + + max_tokens: Optional[int] = litellm.max_tokens + anthropic_version: Optional[str] = "bedrock-2023-05-31" + system: Optional[str] = None + temperature: Optional[float] = None + top_p: Optional[float] = None + top_k: Optional[int] = None + stop_sequences: Optional[List[str]] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + anthropic_version: Optional[str] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "max_tokens", + "tools", + "tool_choice", + "stream", + "stop", + "temperature", + "top_p", + ] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens"] = value + if param == "tools": + optional_params["tools"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + optional_params["stop_sequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["top_p"] = value + return optional_params diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index 7cbece528..a9fc993c9 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -904,6 +904,22 @@ "litellm_provider": "vertex_ai-vision-models", "mode": "chat" }, + "vertex_ai/claude-3-sonnet@20240229": { + "max_tokens": 200000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000003, + "output_cost_per_token": 0.000015, + "litellm_provider": "vertex_ai", + "mode": "chat" + }, + "vertex_ai/claude-3-haiku@20240307": { + "max_tokens": 200000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000025, + "output_cost_per_token": 0.00000125, + "litellm_provider": "vertex_ai", + "mode": "chat" + }, "textembedding-gecko": { "max_tokens": 3072, "max_input_tokens": 3072, diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 7cbece528..a9fc993c9 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -904,6 +904,22 @@ "litellm_provider": "vertex_ai-vision-models", "mode": "chat" }, + "vertex_ai/claude-3-sonnet@20240229": { + "max_tokens": 200000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.000003, + "output_cost_per_token": 0.000015, + "litellm_provider": "vertex_ai", + "mode": "chat" + }, + "vertex_ai/claude-3-haiku@20240307": { + "max_tokens": 200000, + "max_output_tokens": 4096, + "input_cost_per_token": 0.00000025, + "output_cost_per_token": 0.00000125, + "litellm_provider": "vertex_ai", + "mode": "chat" + }, "textembedding-gecko": { "max_tokens": 3072, "max_input_tokens": 3072, From 1d341970ba90e417d7b323823851000d87700752 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Tue, 2 Apr 2024 22:07:39 -0700 Subject: [PATCH 2/2] feat(vertex_ai_anthropic.py): add claude 3 on vertex ai support - working .completions call .completions() call works --- litellm/__init__.py | 1 + litellm/llms/custom_httpx/http_handler.py | 37 ++++ 
litellm/llms/vertex_ai_anthropic.py | 183 +++++++++++++++++- litellm/main.py | 45 +++-- .../tests/test_amazing_vertex_completion.py | 18 ++ litellm/tests/vertex_key.json | 10 +- 6 files changed, 273 insertions(+), 21 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index b14b07f5a..acae3a4ce 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -597,6 +597,7 @@ from .llms.nlp_cloud import NLPCloudConfig from .llms.aleph_alpha import AlephAlphaConfig from .llms.petals import PetalsConfig from .llms.vertex_ai import VertexAIConfig +from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 98fecb5d8..10314d831 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -39,3 +39,40 @@ class AsyncHTTPHandler: asyncio.get_running_loop().create_task(self.close()) except Exception: pass + + +class HTTPHandler: + def __init__(self, concurrent_limit=1000): + # Create a client with a connection pool + self.client = httpx.Client( + limits=httpx.Limits( + max_connections=concurrent_limit, + max_keepalive_connections=concurrent_limit, + ) + ) + + def close(self): + # Close the client when you're done with it + self.client.close() + + def get( + self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None + ): + response = self.client.get(url, params=params, headers=headers) + return response + + def post( + self, + url: str, + data: Optional[dict] = None, + params: Optional[dict] = None, + headers: Optional[dict] = None, + ): + response = self.client.post(url, data=data, params=params, headers=headers) + return response + + def __del__(self) -> None: + try: + self.close() + except Exception: + pass diff --git a/litellm/llms/vertex_ai_anthropic.py b/litellm/llms/vertex_ai_anthropic.py index 4a54e087b..e1ab527b7 100644 --- a/litellm/llms/vertex_ai_anthropic.py +++ b/litellm/llms/vertex_ai_anthropic.py @@ -1,7 +1,36 @@ # What is this? 
## Handler file for calling claude-3 on vertex ai -from typing import Callable, Optional, Any, Union, List +import os, types +import json +from enum import Enum +import requests, copy +import time, uuid +from typing import Callable, Optional, List +from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper import litellm +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from .prompt_templates.factory import ( + contains_tag, + prompt_factory, + custom_prompt, + construct_tool_use_system_prompt, + extract_between_tags, + parse_xml_params, +) +import httpx + + +class VertexAIError(Exception): + def __init__(self, status_code, message): + self.status_code = status_code + self.message = message + self.request = httpx.Request( + method="POST", url=" https://cloud.google.com/vertex-ai/" + ) + self.response = httpx.Response(status_code=status_code, request=self.request) + super().__init__( + self.message + ) # Call the base class constructor with the parameters it needs class VertexAIAnthropicConfig: @@ -27,7 +56,6 @@ class VertexAIAnthropicConfig: """ max_tokens: Optional[int] = litellm.max_tokens - anthropic_version: Optional[str] = "bedrock-2023-05-31" system: Optional[str] = None temperature: Optional[float] = None top_p: Optional[float] = None @@ -88,3 +116,154 @@ class VertexAIAnthropicConfig: if param == "top_p": optional_params["top_p"] = value return optional_params + + +""" +- Run client init +- Support async completion, streaming +""" + + +# makes headers for API call + + +def completion( + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + vertex_project=None, + vertex_location=None, + optional_params=None, + litellm_params=None, + logger_fn=None, + acompletion: bool = False, + client=None, +): + try: + import vertexai + except: + raise VertexAIError( + status_code=400, + message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""", + ) + + if not ( + hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") + ): + raise VertexAIError( + status_code=400, + message="""Upgrade vertex ai. 
Run `pip install "google-cloud-aiplatform>=1.38"`""", + ) + try: + import google.auth # type: ignore + from google.auth.transport.requests import Request + from anthropic import AnthropicVertex + + ## Load Config + config = litellm.VertexAIAnthropicConfig.get_config() + for k, v in config.items(): + if k not in optional_params: + optional_params[k] = v + + ## Format Prompt + _is_function_call = False + messages = copy.deepcopy(messages) + optional_params = copy.deepcopy(optional_params) + # Separate system prompt from rest of message + system_prompt_indices = [] + system_prompt = "" + for idx, message in enumerate(messages): + if message["role"] == "system": + system_prompt += message["content"] + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + if len(system_prompt) > 0: + optional_params["system"] = system_prompt + # Format rest of message according to anthropic guidelines + try: + messages = prompt_factory( + model=model, messages=messages, custom_llm_provider="anthropic" + ) + except Exception as e: + raise VertexAIError(status_code=400, message=str(e)) + + ## Handle Tool Calling + if "tools" in optional_params: + _is_function_call = True + tool_calling_system_prompt = construct_tool_use_system_prompt( + tools=optional_params["tools"] + ) + optional_params["system"] = ( + optional_params.get("system", "\n") + tool_calling_system_prompt + ) # add the anthropic tool calling prompt to the system prompt + optional_params.pop("tools") + + stream = optional_params.pop("stream", None) + + data = { + "model": model, + "messages": messages, + **optional_params, + } + print_verbose(f"_is_function_call: {_is_function_call}") + + ## Completion Call + + print_verbose( + f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" + ) + if client is None: + vertex_ai_client = AnthropicVertex( + project_id=vertex_project, region=vertex_location + ) + else: + vertex_ai_client = client + + message = vertex_ai_client.messages.create(**data) # type: ignore + text_content = message.content[0].text + ## TOOL CALLING - OUTPUT PARSE + if text_content is not None and contains_tag("invoke", text_content): + function_name = extract_between_tags("tool_name", text_content)[0] + function_arguments_str = extract_between_tags("invoke", text_content)[ + 0 + ].strip() + function_arguments_str = f"{function_arguments_str}" + function_arguments = parse_xml_params(function_arguments_str) + _message = litellm.Message( + tool_calls=[ + { + "id": f"call_{uuid.uuid4()}", + "type": "function", + "function": { + "name": function_name, + "arguments": json.dumps(function_arguments), + }, + } + ], + content=None, + ) + model_response.choices[0].message = _message # type: ignore + else: + model_response.choices[0].message.content = text_content # type: ignore + model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason) + + ## CALCULATING USAGE + prompt_tokens = message.usage.input_tokens + completion_tokens = message.usage.output_tokens + + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=prompt_tokens + completion_tokens, + ) + model_response.usage = usage + return model_response + except Exception as e: + raise VertexAIError(status_code=500, message=str(e)) diff --git a/litellm/main.py b/litellm/main.py index 5f2b34482..5b48d1f72 100644 --- a/litellm/main.py +++ 
b/litellm/main.py @@ -62,6 +62,7 @@ from .llms import ( palm, gemini, vertex_ai, + vertex_ai_anthropic, maritalk, ) from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion @@ -1627,20 +1628,36 @@ def completion( or get_secret("VERTEXAI_LOCATION") ) - model_response = vertex_ai.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - logging_obj=logging, - acompletion=acompletion, - ) + if "claude-3" in model: + model_response = vertex_ai_anthropic.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + logging_obj=logging, + acompletion=acompletion, + ) + else: + model_response = vertex_ai.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + logging_obj=logging, + acompletion=acompletion, + ) if ( "stream" in optional_params diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 264bb7a70..0ebe9ce70 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -84,6 +84,24 @@ async def get_response(): pytest.fail(f"An error occurred - {str(e)}") +def test_vertex_ai_anthropic(): + load_vertex_ai_credentials() + + model = "claude-3-sonnet@20240229" + + vertex_ai_project = "adroit-crow-413218" + vertex_ai_location = "asia-southeast1" + + response = completion( + model="vertex_ai/" + model, + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + vertex_ai_project=vertex_ai_project, + vertex_ai_location=vertex_ai_location, + ) + print("\nModel Response", response) + + def test_vertex_ai(): import random diff --git a/litellm/tests/vertex_key.json b/litellm/tests/vertex_key.json index bd319ac94..e2fd8512b 100644 --- a/litellm/tests/vertex_key.json +++ b/litellm/tests/vertex_key.json @@ -1,13 +1,13 @@ { "type": "service_account", - "project_id": "reliablekeys", + "project_id": "adroit-crow-413218", "private_key_id": "", "private_key": "", - "client_email": "73470430121-compute@developer.gserviceaccount.com", - "client_id": "108560959659377334173", + "client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com", + "client_id": "104886546564708740969", "auth_uri": "https://accounts.google.com/o/oauth2/auth", "token_uri": "https://oauth2.googleapis.com/token", "auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs", - "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com", + "client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com", "universe_domain": "googleapis.com" -} \ No newline at end of file +}
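
Usage sketch (assumes both patches above are applied and valid Vertex AI credentials are available, e.g. via GOOGLE_APPLICATION_CREDENTIALS): this mirrors test_vertex_ai_anthropic in litellm/tests/test_amazing_vertex_completion.py. The GCP project id and region below are placeholders, not values from the patch.

import litellm

# With the vertex_ai/ prefix, any model name containing "claude-3" is routed by
# litellm/main.py to the new vertex_ai_anthropic.completion handler, which calls
# the AnthropicVertex client from the anthropic SDK under the hood.
response = litellm.completion(
    model="vertex_ai/claude-3-sonnet@20240229",
    messages=[{"role": "user", "content": "hi"}],
    temperature=0.7,
    vertex_ai_project="my-gcp-project",   # placeholder project id
    vertex_ai_location="us-central1",     # placeholder region
)
print(response.choices[0].message.content)

The matching entries added to model_prices_and_context_window.json (vertex_ai/claude-3-sonnet@20240229 and vertex_ai/claude-3-haiku@20240307) supply the context-window and per-token pricing metadata litellm uses for cost tracking on these models.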