Merge pull request #4845 from BerriAI/litellm_vertex_ai_llama3_1_api

feat(vertex_ai_llama.py): vertex ai llama3.1 api support
This commit is contained in:
Krish Dholakia 2024-07-23 21:51:46 -07:00 committed by GitHub
commit 6c580ac8dc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
9 changed files with 338 additions and 19 deletions

View file

@ -357,6 +357,7 @@ vertex_text_models: List = []
vertex_code_text_models: List = []
vertex_embedding_models: List = []
vertex_anthropic_models: List = []
vertex_llama3_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
@ -399,6 +400,9 @@ for key, value in model_cost.items():
elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
key = key.replace("vertex_ai/", "")
vertex_anthropic_models.append(key)
elif value.get("litellm_provider") == "vertex_ai-llama_models":
key = key.replace("vertex_ai/", "")
vertex_llama3_models.append(key)
elif value.get("litellm_provider") == "ai21":
ai21_models.append(key)
elif value.get("litellm_provider") == "nlp_cloud":
@ -828,6 +832,7 @@ from .llms.petals import PetalsConfig
from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
from .llms.vertex_ai_llama import VertexAILlama3Config
from .llms.sagemaker import SagemakerConfig
from .llms.ollama import OllamaConfig
from .llms.ollama_chat import OllamaChatConfig

View file

@ -0,0 +1,203 @@
# What is this?
## Handler for calling llama 3.1 API on Vertex AI
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
)
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import (
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAILlama3Config:
"""
Reference:https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
The class `VertexAILlama3Config` provides configuration for the VertexAI's Llama API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = None
def __init__(
self,
max_tokens: Optional[int] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"stream",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
return optional_params
class VertexAILlama3(BaseLLM):
def __init__(self) -> None:
pass
def create_vertex_llama3_url(
self, vertex_location: str, vertex_project: str
) -> str:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
timeout: Union[float, httpx.Timeout],
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
):
try:
import vertexai
from google.cloud import aiplatform
from litellm.llms.openai import OpenAIChatCompletion
from litellm.llms.vertex_httpx import VertexLLM
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
)
if not (
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
):
raise VertexAIError(
status_code=400,
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
openai_chat_completions = OpenAIChatCompletion()
## Load Config
# config = litellm.VertexAILlama3.get_config()
# for k, v in config.items():
# if k not in optional_params:
# optional_params[k] = v
## CONSTRUCT API BASE
stream: bool = optional_params.get("stream", False) or False
optional_params["stream"] = stream
api_base = self.create_vertex_llama3_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
)
return openai_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
api_key=access_token,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
client=client,
timeout=timeout,
)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))

View file

@ -1189,7 +1189,7 @@ class VertexLLM(BaseLLM):
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=response.text)
raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")

View file

@ -120,6 +120,7 @@ from .llms.prompt_templates.factory import (
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_llama import VertexAILlama3
from .llms.vertex_httpx import VertexLLM
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
@ -156,6 +157,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_llama_chat_completion = VertexAILlama3()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@ -2064,7 +2066,26 @@ def completion(
timeout=timeout,
client=client,
)
elif model.startswith("meta/"):
model_response = vertex_llama_chat_completion.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
client=client,
)
else:
model_response = vertex_ai.completion(
model=model,
@ -2478,28 +2499,25 @@ def completion(
return generator
response = generator
elif custom_llm_provider == "triton":
api_base = (
litellm.api_base or api_base
)
api_base = litellm.api_base or api_base
model_response = triton_chat_completions.completion(
api_base=api_base,
timeout=timeout, # type: ignore
model=model,
messages=messages,
model_response=model_response,
optional_params=optional_params,
logging_obj=logging,
stream=stream,
acompletion=acompletion
api_base=api_base,
timeout=timeout, # type: ignore
model=model,
messages=messages,
model_response=model_response,
optional_params=optional_params,
logging_obj=logging,
stream=stream,
acompletion=acompletion,
)
## RESPONSE OBJECT
response = model_response
return response
elif custom_llm_provider == "cloudflare":
api_key = (
api_key

View file

@ -1948,6 +1948,16 @@
"supports_function_calling": true,
"supports_vision": true
},
"vertex_ai/meta/llama3-405b-instruct-maas": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "vertex_ai-llama_models",
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/imagegeneration@006": {
"cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models",

View file

@ -895,6 +895,52 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
from litellm.tests.test_completion import response_format_tests
@pytest.mark.parametrize(
"model", ["vertex_ai/meta/llama3-405b-instruct-maas"]
) # "vertex_ai",
@pytest.mark.parametrize("sync_mode", [True, False]) # "vertex_ai",
@pytest.mark.asyncio
async def test_llama_3_httpx(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
]
data = {
"model": model,
"messages": messages,
}
if sync_mode:
response = litellm.completion(**data)
else:
response = await litellm.acompletion(**data)
response_format_tests(response=response)
print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except Exception as e:
if "429 Quota exceeded" in str(e):
pass
else:
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200

View file

@ -128,6 +128,19 @@ def test_azure_ai_mistral_optional_params():
assert "user" not in optional_params
def test_vertex_ai_llama_3_optional_params():
litellm.vertex_llama3_models = ["meta/llama3-405b-instruct-maas"]
litellm.drop_params = True
optional_params = get_optional_params(
model="meta/llama3-405b-instruct-maas",
user="John",
custom_llm_provider="vertex_ai",
max_tokens=10,
temperature=0.2,
)
assert "user" not in optional_params
def test_azure_gpt_optional_params_gpt_vision():
# for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
optional_params = litellm.utils.get_optional_params(

View file

@ -3088,6 +3088,15 @@ def get_optional_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_llama3_models:
supported_params = get_supported_openai_params(
model=model, custom_llm_provider=custom_llm_provider
)
_check_valid_arg(supported_params=supported_params)
optional_params = litellm.VertexAILlama3Config().map_openai_params(
non_default_params=non_default_params,
optional_params=optional_params,
)
elif custom_llm_provider == "sagemaker":
## check if unsupported param passed in
supported_params = get_supported_openai_params(
@ -4189,6 +4198,9 @@ def get_supported_openai_params(
return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
elif custom_llm_provider == "vertex_ai":
if request_type == "chat_completion":
if model.startswith("meta/"):
return litellm.VertexAILlama3Config().get_supported_openai_params()
return litellm.VertexAIConfig().get_supported_openai_params()
elif request_type == "embeddings":
return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
@ -5752,10 +5764,12 @@ def convert_to_model_response_object(
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
if "created" in response_object:
model_response_object.created = response_object["created"]
model_response_object.created = response_object["created"] or int(
time.time()
)
if "id" in response_object:
model_response_object.id = response_object["id"]
model_response_object.id = response_object["id"] or str(uuid.uuid4())
if "system_fingerprint" in response_object:
model_response_object.system_fingerprint = response_object[

View file

@ -1948,6 +1948,16 @@
"supports_function_calling": true,
"supports_vision": true
},
"vertex_ai/meta/llama3-405b-instruct-maas": {
"max_tokens": 32000,
"max_input_tokens": 32000,
"max_output_tokens": 32000,
"input_cost_per_token": 0.0,
"output_cost_per_token": 0.0,
"litellm_provider": "vertex_ai-llama_models",
"mode": "chat",
"source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
},
"vertex_ai/imagegeneration@006": {
"cost_per_image": 0.020,
"litellm_provider": "vertex_ai-image-models",