Merge pull request #2665 from BerriAI/litellm_claude_vertex_ai

[WIP] feat(vertex_ai_anthropic.py): Add support for claude 3 on vertex ai
This commit is contained in:
Krish Dholakia 2024-04-05 07:06:04 -07:00 committed by GitHub
commit eb34306099
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 434 additions and 19 deletions

View file

@ -603,6 +603,7 @@ from .llms.nlp_cloud import NLPCloudConfig
from .llms.aleph_alpha import AlephAlphaConfig
from .llms.petals import PetalsConfig
from .llms.vertex_ai import VertexAIConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig

View file

@ -0,0 +1,78 @@
import httpx, asyncio
from typing import Optional
class AsyncHTTPHandler:
def __init__(self, concurrent_limit=1000):
# Create a client with a connection pool
self.client = httpx.AsyncClient(
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
)
)
async def close(self):
# Close the client when you're done with it
await self.client.aclose()
async def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = await self.client.get(url, params=params, headers=headers)
return response
async def post(
self,
url: str,
data: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = await self.client.post(
url, data=data, params=params, headers=headers
)
return response
def __del__(self) -> None:
try:
asyncio.get_running_loop().create_task(self.close())
except Exception:
pass
class HTTPHandler:
def __init__(self, concurrent_limit=1000):
# Create a client with a connection pool
self.client = httpx.Client(
limits=httpx.Limits(
max_connections=concurrent_limit,
max_keepalive_connections=concurrent_limit,
)
)
def close(self):
# Close the client when you're done with it
self.client.close()
def get(
self, url: str, params: Optional[dict] = None, headers: Optional[dict] = None
):
response = self.client.get(url, params=params, headers=headers)
return response
def post(
self,
url: str,
data: Optional[dict] = None,
params: Optional[dict] = None,
headers: Optional[dict] = None,
):
response = self.client.post(url, data=data, params=params, headers=headers)
return response
def __del__(self) -> None:
try:
self.close()
except Exception:
pass

View file

@ -0,0 +1,269 @@
# What is this?
## Handler file for calling claude-3 on vertex ai
import os, types
import json
from enum import Enum
import requests, copy
import time, uuid
from typing import Callable, Optional, List
from litellm.utils import ModelResponse, Usage, map_finish_reason, CustomStreamWrapper
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from .prompt_templates.factory import (
contains_tag,
prompt_factory,
custom_prompt,
construct_tool_use_system_prompt,
extract_between_tags,
parse_xml_params,
)
import httpx
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAIAnthropicConfig:
"""
Reference: https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = litellm.max_tokens
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
return optional_params
"""
- Run client init
- Support async completion, streaming
"""
# makes headers for API call
def completion(
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
vertex_project=None,
vertex_location=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
):
try:
import vertexai
except:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
)
if not (
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
):
raise VertexAIError(
status_code=400,
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
import google.auth # type: ignore
from google.auth.transport.requests import Request
from anthropic import AnthropicVertex
## Load Config
config = litellm.VertexAIAnthropicConfig.get_config()
for k, v in config.items():
if k not in optional_params:
optional_params[k] = v
## Format Prompt
_is_function_call = False
messages = copy.deepcopy(messages)
optional_params = copy.deepcopy(optional_params)
# Separate system prompt from rest of message
system_prompt_indices = []
system_prompt = ""
for idx, message in enumerate(messages):
if message["role"] == "system":
system_prompt += message["content"]
system_prompt_indices.append(idx)
if len(system_prompt_indices) > 0:
for idx in reversed(system_prompt_indices):
messages.pop(idx)
if len(system_prompt) > 0:
optional_params["system"] = system_prompt
# Format rest of message according to anthropic guidelines
try:
messages = prompt_factory(
model=model, messages=messages, custom_llm_provider="anthropic"
)
except Exception as e:
raise VertexAIError(status_code=400, message=str(e))
## Handle Tool Calling
if "tools" in optional_params:
_is_function_call = True
tool_calling_system_prompt = construct_tool_use_system_prompt(
tools=optional_params["tools"]
)
optional_params["system"] = (
optional_params.get("system", "\n") + tool_calling_system_prompt
) # add the anthropic tool calling prompt to the system prompt
optional_params.pop("tools")
stream = optional_params.pop("stream", None)
data = {
"model": model,
"messages": messages,
**optional_params,
}
print_verbose(f"_is_function_call: {_is_function_call}")
## Completion Call
print_verbose(
f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}"
)
if client is None:
vertex_ai_client = AnthropicVertex(
project_id=vertex_project, region=vertex_location
)
else:
vertex_ai_client = client
message = vertex_ai_client.messages.create(**data) # type: ignore
text_content = message.content[0].text
## TOOL CALLING - OUTPUT PARSE
if text_content is not None and contains_tag("invoke", text_content):
function_name = extract_between_tags("tool_name", text_content)[0]
function_arguments_str = extract_between_tags("invoke", text_content)[
0
].strip()
function_arguments_str = f"<invoke>{function_arguments_str}</invoke>"
function_arguments = parse_xml_params(function_arguments_str)
_message = litellm.Message(
tool_calls=[
{
"id": f"call_{uuid.uuid4()}",
"type": "function",
"function": {
"name": function_name,
"arguments": json.dumps(function_arguments),
},
}
],
content=None,
)
model_response.choices[0].message = _message # type: ignore
else:
model_response.choices[0].message.content = text_content # type: ignore
model_response.choices[0].finish_reason = map_finish_reason(message.stop_reason)
## CALCULATING USAGE
prompt_tokens = message.usage.input_tokens
completion_tokens = message.usage.output_tokens
model_response["created"] = int(time.time())
model_response["model"] = model
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=prompt_tokens + completion_tokens,
)
model_response.usage = usage
return model_response
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))

View file

@ -62,6 +62,7 @@ from .llms import (
palm,
gemini,
vertex_ai,
vertex_ai_anthropic,
maritalk,
)
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
@ -1674,20 +1675,36 @@ def completion(
or get_secret("VERTEXAI_LOCATION")
)
model_response = vertex_ai.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
logging_obj=logging,
acompletion=acompletion,
)
if "claude-3" in model:
model_response = vertex_ai_anthropic.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
logging_obj=logging,
acompletion=acompletion,
)
else:
model_response = vertex_ai.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=optional_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
logging_obj=logging,
acompletion=acompletion,
)
if (
"stream" in optional_params

View file

@ -1002,6 +1002,22 @@
"supports_function_calling": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"vertex_ai/claude-3-sonnet@20240229": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "vertex_ai",
"mode": "chat"
},
"vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "vertex_ai",
"mode": "chat"
},
"textembedding-gecko": {
"max_tokens": 3072,
"max_input_tokens": 3072,

View file

@ -84,6 +84,24 @@ async def get_response():
pytest.fail(f"An error occurred - {str(e)}")
def test_vertex_ai_anthropic():
load_vertex_ai_credentials()
model = "claude-3-sonnet@20240229"
vertex_ai_project = "adroit-crow-413218"
vertex_ai_location = "asia-southeast1"
response = completion(
model="vertex_ai/" + model,
messages=[{"role": "user", "content": "hi"}],
temperature=0.7,
vertex_ai_project=vertex_ai_project,
vertex_ai_location=vertex_ai_location,
)
print("\nModel Response", response)
def test_vertex_ai():
import random

View file

@ -1,13 +1,13 @@
{
"type": "service_account",
"project_id": "reliablekeys",
"project_id": "adroit-crow-413218",
"private_key_id": "",
"private_key": "",
"client_email": "73470430121-compute@developer.gserviceaccount.com",
"client_id": "108560959659377334173",
"client_email": "test-adroit-crow@adroit-crow-413218.iam.gserviceaccount.com",
"client_id": "104886546564708740969",
"auth_uri": "https://accounts.google.com/o/oauth2/auth",
"token_uri": "https://oauth2.googleapis.com/token",
"auth_provider_x509_cert_url": "https://www.googleapis.com/oauth2/v1/certs",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/73470430121-compute%40developer.gserviceaccount.com",
"client_x509_cert_url": "https://www.googleapis.com/robot/v1/metadata/x509/test-adroit-crow%40adroit-crow-413218.iam.gserviceaccount.com",
"universe_domain": "googleapis.com"
}
}

View file

@ -1002,6 +1002,22 @@
"supports_function_calling": true,
"source": "https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models#foundation_models"
},
"vertex_ai/claude-3-sonnet@20240229": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.000003,
"output_cost_per_token": 0.000015,
"litellm_provider": "vertex_ai",
"mode": "chat"
},
"vertex_ai/claude-3-haiku@20240307": {
"max_tokens": 200000,
"max_output_tokens": 4096,
"input_cost_per_token": 0.00000025,
"output_cost_per_token": 0.00000125,
"litellm_provider": "vertex_ai",
"mode": "chat"
},
"textembedding-gecko": {
"max_tokens": 3072,
"max_input_tokens": 3072,