feat(vertex_ai_llama.py): vertex ai llama3.1 api support

Initial working commit for vertex ai llama 3.1 api support
This commit is contained in:
Krrish Dholakia 2024-07-23 17:07:30 -07:00
parent 169da8b8d0
commit 83ef52e180
5 changed files with 355 additions and 19 deletions

View file

@ -0,0 +1,270 @@
# What is this?
## Handler for calling llama 3.1 API on Vertex AI
import copy
import json
import os
import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Optional, Tuple, Union
import httpx # type: ignore
import requests # type: ignore
import litellm
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.types.llms.anthropic import (
AnthropicMessagesTool,
AnthropicMessagesToolChoice,
)
from litellm.types.llms.openai import (
ChatCompletionToolParam,
ChatCompletionToolParamFunctionChunk,
)
from litellm.types.utils import ResponseFormatChunk
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
from .base import BaseLLM
from .prompt_templates.factory import (
construct_tool_use_system_prompt,
contains_tag,
custom_prompt,
extract_between_tags,
parse_xml_params,
prompt_factory,
response_schema_prompt,
)
class VertexAIError(Exception):
def __init__(self, status_code, message):
self.status_code = status_code
self.message = message
self.request = httpx.Request(
method="POST", url=" https://cloud.google.com/vertex-ai/"
)
self.response = httpx.Response(status_code=status_code, request=self.request)
super().__init__(
self.message
) # Call the base class constructor with the parameters it needs
class VertexAILlama3Config:
"""
Reference:https://docs.anthropic.com/claude/reference/messages_post
Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
- `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
- `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
The class `VertexAIAnthropicConfig` provides configuration for the VertexAI's Anthropic API interface. Below are the parameters:
- `max_tokens` Required (integer) max tokens,
- `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
- `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
- `temperature` Optional (float) The amount of randomness injected into the response
- `top_p` Optional (float) Use nucleus sampling.
- `top_k` Optional (int) Only sample from the top K options for each subsequent token
- `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
Note: Please make sure to modify the default parameters as required for your use case.
"""
max_tokens: Optional[int] = (
4096 # anthropic max - setting this doesn't impact response, but is required by anthropic.
)
system: Optional[str] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
top_k: Optional[int] = None
stop_sequences: Optional[List[str]] = None
def __init__(
self,
max_tokens: Optional[int] = None,
anthropic_version: Optional[str] = None,
) -> None:
locals_ = locals()
for key, value in locals_.items():
if key == "max_tokens" and value is None:
value = self.max_tokens
if key != "self" and value is not None:
setattr(self.__class__, key, value)
@classmethod
def get_config(cls):
return {
k: v
for k, v in cls.__dict__.items()
if not k.startswith("__")
and not isinstance(
v,
(
types.FunctionType,
types.BuiltinFunctionType,
classmethod,
staticmethod,
),
)
and v is not None
}
def get_supported_openai_params(self):
return [
"max_tokens",
"tools",
"tool_choice",
"stream",
"stop",
"temperature",
"top_p",
"response_format",
]
def map_openai_params(self, non_default_params: dict, optional_params: dict):
for param, value in non_default_params.items():
if param == "max_tokens":
optional_params["max_tokens"] = value
if param == "tools":
optional_params["tools"] = value
if param == "tool_choice":
_tool_choice: Optional[AnthropicMessagesToolChoice] = None
if value == "auto":
_tool_choice = {"type": "auto"}
elif value == "required":
_tool_choice = {"type": "any"}
elif isinstance(value, dict):
_tool_choice = {"type": "tool", "name": value["function"]["name"]}
if _tool_choice is not None:
optional_params["tool_choice"] = _tool_choice
if param == "stream":
optional_params["stream"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "temperature":
optional_params["temperature"] = value
if param == "top_p":
optional_params["top_p"] = value
if param == "response_format" and "response_schema" in value:
"""
When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
- You usually want to provide a single tool
- You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
- Remember that the model will pass the input to the tool, so the name of the tool and description should be from the models perspective.
"""
_tool_choice = None
_tool_choice = {"name": "json_tool_call", "type": "tool"}
_tool = AnthropicMessagesTool(
name="json_tool_call",
input_schema={
"type": "object",
"properties": {"values": value["response_schema"]}, # type: ignore
},
)
optional_params["tools"] = [_tool]
optional_params["tool_choice"] = _tool_choice
optional_params["json_mode"] = True
return optional_params
class VertexAILlama3(BaseLLM):
def __init__(self) -> None:
pass
def create_vertex_llama3_url(
self, vertex_location: str, vertex_project: str
) -> str:
return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
def completion(
self,
model: str,
messages: list,
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict: dict,
headers: Optional[dict],
timeout: Union[float, httpx.Timeout],
vertex_project=None,
vertex_location=None,
vertex_credentials=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
client=None,
):
try:
import vertexai
from google.cloud import aiplatform
from litellm.llms.openai import OpenAIChatCompletion
from litellm.llms.vertex_httpx import VertexLLM
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
)
if not (
hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
):
raise VertexAIError(
status_code=400,
message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
)
try:
vertex_httpx_logic = VertexLLM()
access_token, project_id = vertex_httpx_logic._ensure_access_token(
credentials=vertex_credentials, project_id=vertex_project
)
openai_chat_completions = OpenAIChatCompletion()
## Load Config
# config = litellm.VertexAILlama3.get_config()
# for k, v in config.items():
# if k not in optional_params:
# optional_params[k] = v
## CONSTRUCT API BASE
stream: bool = optional_params.get("stream", False) or False
optional_params["stream"] = stream
api_base = self.create_vertex_llama3_url(
vertex_location=vertex_location or "us-central1",
vertex_project=vertex_project or project_id,
)
return openai_chat_completions.completion(
model=model,
messages=messages,
api_base=api_base,
api_key=access_token,
custom_prompt_dict=custom_prompt_dict,
model_response=model_response,
print_verbose=print_verbose,
logging_obj=logging_obj,
optional_params=optional_params,
acompletion=acompletion,
litellm_params=litellm_params,
logger_fn=logger_fn,
client=client,
timeout=timeout,
)
except Exception as e:
raise VertexAIError(status_code=500, message=str(e))

View file

@ -1189,7 +1189,7 @@ class VertexLLM(BaseLLM):
response.raise_for_status()
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise VertexAIError(status_code=error_code, message=response.text)
raise VertexAIError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException:
raise VertexAIError(status_code=408, message="Timeout error occurred.")

View file

@ -120,6 +120,7 @@ from .llms.prompt_templates.factory import (
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_ai_llama import VertexAILlama3
from .llms.vertex_httpx import VertexLLM
from .llms.watsonx import IBMWatsonXAI
from .types.llms.openai import HttpxBinaryResponseContent
@ -156,6 +157,7 @@ triton_chat_completions = TritonChatCompletion()
bedrock_chat_completion = BedrockLLM()
bedrock_converse_chat_completion = BedrockConverseLLM()
vertex_chat_completion = VertexLLM()
vertex_llama_chat_completion = VertexAILlama3()
watsonxai = IBMWatsonXAI()
####### COMPLETION ENDPOINTS ################
@ -2064,7 +2066,26 @@ def completion(
timeout=timeout,
client=client,
)
elif model.startswith("meta/"):
model_response = vertex_llama_chat_completion.completion(
model=model,
messages=messages,
model_response=model_response,
print_verbose=print_verbose,
optional_params=new_params,
litellm_params=litellm_params,
logger_fn=logger_fn,
encoding=encoding,
vertex_location=vertex_ai_location,
vertex_project=vertex_ai_project,
vertex_credentials=vertex_credentials,
logging_obj=logging,
acompletion=acompletion,
headers=headers,
custom_prompt_dict=custom_prompt_dict,
timeout=timeout,
client=client,
)
else:
model_response = vertex_ai.completion(
model=model,
@ -2478,28 +2499,25 @@ def completion(
return generator
response = generator
elif custom_llm_provider == "triton":
api_base = (
litellm.api_base or api_base
)
api_base = litellm.api_base or api_base
model_response = triton_chat_completions.completion(
api_base=api_base,
timeout=timeout, # type: ignore
model=model,
messages=messages,
model_response=model_response,
optional_params=optional_params,
logging_obj=logging,
stream=stream,
acompletion=acompletion
api_base=api_base,
timeout=timeout, # type: ignore
model=model,
messages=messages,
model_response=model_response,
optional_params=optional_params,
logging_obj=logging,
stream=stream,
acompletion=acompletion,
)
## RESPONSE OBJECT
response = model_response
return response
elif custom_llm_provider == "cloudflare":
api_key = (
api_key

View file

@ -895,6 +895,52 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
from litellm.tests.test_completion import response_format_tests
@pytest.mark.parametrize(
"model", ["vertex_ai/meta/llama3-405b-instruct-maas"]
) # "vertex_ai",
@pytest.mark.parametrize("sync_mode", [True, False]) # "vertex_ai",
@pytest.mark.asyncio
async def test_llama_3_httpx(model, sync_mode):
try:
load_vertex_ai_credentials()
litellm.set_verbose = True
messages = [
{
"role": "system",
"content": "Your name is Litellm Bot, you are a helpful assistant",
},
# User asks for their name and weather in San Francisco
{
"role": "user",
"content": "Hello, what is your name and can you tell me the weather?",
},
]
data = {
"model": model,
"messages": messages,
}
if sync_mode:
response = litellm.completion(**data)
else:
response = await litellm.acompletion(**data)
response_format_tests(response=response)
print(f"response: {response}")
except litellm.RateLimitError as e:
pass
except Exception as e:
if "429 Quota exceeded" in str(e):
pass
else:
pytest.fail("An unexpected exception occurred - {}".format(str(e)))
def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
mock_response = MagicMock()
mock_response.status_code = 200

View file

@ -5752,10 +5752,12 @@ def convert_to_model_response_object(
model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore
if "created" in response_object:
model_response_object.created = response_object["created"]
model_response_object.created = response_object["created"] or int(
time.time()
)
if "id" in response_object:
model_response_object.id = response_object["id"]
model_response_object.id = response_object["id"] or str(uuid.uuid4())
if "system_fingerprint" in response_object:
model_response_object.system_fingerprint = response_object[