From 83ef52e18005db1e0b6ee9756c1edebf5820887e Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Tue, 23 Jul 2024 17:07:30 -0700
Subject: [PATCH] feat(vertex_ai_llama.py): vertex ai llama3.1 api support

Initial working commit for vertex ai llama 3.1 api support
---
 litellm/llms/vertex_ai_llama.py              | 270 ++++++++++++++++++
 litellm/llms/vertex_httpx.py                 |   2 +-
 litellm/main.py                              |  50 ++--
 .../tests/test_amazing_vertex_completion.py  |  46 +++
 litellm/utils.py                             |   6 +-
 5 files changed, 355 insertions(+), 19 deletions(-)
 create mode 100644 litellm/llms/vertex_ai_llama.py

diff --git a/litellm/llms/vertex_ai_llama.py b/litellm/llms/vertex_ai_llama.py
new file mode 100644
index 000000000..4b5407faa
--- /dev/null
+++ b/litellm/llms/vertex_ai_llama.py
@@ -0,0 +1,270 @@
+# What is this?
+## Handler for calling llama 3.1 API on Vertex AI
+import copy
+import json
+import os
+import time
+import types
+import uuid
+from enum import Enum
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import (
+    construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
+)
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url=" https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class VertexAILlama3Config:
+    """
+    Reference: https://docs.anthropic.com/claude/reference/messages_post
+
+    Note that the API for Claude on Vertex differs from the Anthropic API documentation in the following ways:
+
+    - `model` is not a valid parameter. The model is instead specified in the Google Cloud endpoint URL.
+    - `anthropic_version` is a required parameter and must be set to "vertex-2023-10-16".
+
+    The class `VertexAILlama3Config` provides configuration for Llama 3 on Vertex AI; it is adapted from `VertexAIAnthropicConfig`. Below are the parameters:
+
+    - `max_tokens` Required (integer) max tokens,
+    - `anthropic_version` Required (string) version of anthropic for bedrock - e.g. "bedrock-2023-05-31"
+    - `system` Optional (string) the system prompt, conversion from openai format to this is handled in factory.py
+    - `temperature` Optional (float) The amount of randomness injected into the response
+    - `top_p` Optional (float) Use nucleus sampling.
+    - `top_k` Optional (int) Only sample from the top K options for each subsequent token
+    - `stop_sequences` Optional (List[str]) Custom text sequences that cause the model to stop generating
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    max_tokens: Optional[int] = (
+        4096  # anthropic max - setting this doesn't impact response, but is required by anthropic.
+    )
+    system: Optional[str] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    top_k: Optional[int] = None
+    stop_sequences: Optional[List[str]] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+        anthropic_version: Optional[str] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key == "max_tokens" and value is None:
+                value = self.max_tokens
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "max_tokens",
+            "tools",
+            "tool_choice",
+            "stream",
+            "stop",
+            "temperature",
+            "top_p",
+            "response_format",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+            if param == "tools":
+                optional_params["tools"] = value
+            if param == "tool_choice":
+                _tool_choice: Optional[AnthropicMessagesToolChoice] = None
+                if value == "auto":
+                    _tool_choice = {"type": "auto"}
+                elif value == "required":
+                    _tool_choice = {"type": "any"}
+                elif isinstance(value, dict):
+                    _tool_choice = {"type": "tool", "name": value["function"]["name"]}
+
+                if _tool_choice is not None:
+                    optional_params["tool_choice"] = _tool_choice
+            if param == "stream":
+                optional_params["stream"] = value
+            if param == "stop":
+                optional_params["stop_sequences"] = value
+            if param == "temperature":
+                optional_params["temperature"] = value
+            if param == "top_p":
+                optional_params["top_p"] = value
+            if param == "response_format" and "response_schema" in value:
+                """
+                When using tools in this way: - https://docs.anthropic.com/en/docs/build-with-claude/tool-use#json-mode
+                - You usually want to provide a single tool
+                - You should set tool_choice (see Forcing tool use) to instruct the model to explicitly use that tool
+                - Remember that the model will pass the input to the tool, so the name of the tool and description should be from the model’s perspective.
+                """
+                _tool_choice = None
+                _tool_choice = {"name": "json_tool_call", "type": "tool"}
+
+                _tool = AnthropicMessagesTool(
+                    name="json_tool_call",
+                    input_schema={
+                        "type": "object",
+                        "properties": {"values": value["response_schema"]},  # type: ignore
+                    },
+                )
+
+                optional_params["tools"] = [_tool]
+                optional_params["tool_choice"] = _tool_choice
+                optional_params["json_mode"] = True
+
+        return optional_params
+
+
+class VertexAILlama3(BaseLLM):
+    def __init__(self) -> None:
+        pass
+
+    def create_vertex_llama3_url(
+        self, vertex_location: str, vertex_project: str
+    ) -> str:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        custom_prompt_dict: dict,
+        headers: Optional[dict],
+        timeout: Union[float, httpx.Timeout],
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        litellm_params=None,
+        logger_fn=None,
+        acompletion: bool = False,
+        client=None,
+    ):
+        try:
+            import vertexai
+            from google.cloud import aiplatform
+
+            from litellm.llms.openai import OpenAIChatCompletion
+            from litellm.llms.vertex_httpx import VertexLLM
+        except Exception:
+
+            raise VertexAIError(
+                status_code=400,
+                message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
+            )
+
+        if not (
+            hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
+        ):
+            raise VertexAIError(
+                status_code=400,
+                message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""",
+            )
+        try:
+
+            vertex_httpx_logic = VertexLLM()
+
+            access_token, project_id = vertex_httpx_logic._ensure_access_token(
+                credentials=vertex_credentials, project_id=vertex_project
+            )
+
+            openai_chat_completions = OpenAIChatCompletion()
+
+            ## Load Config
+            # config = litellm.VertexAILlama3.get_config()
+            # for k, v in config.items():
+            #     if k not in optional_params:
+            #         optional_params[k] = v
+
+            ## CONSTRUCT API BASE
+            stream: bool = optional_params.get("stream", False) or False
+
+            optional_params["stream"] = stream
+
+            api_base = self.create_vertex_llama3_url(
+                vertex_location=vertex_location or "us-central1",
+                vertex_project=vertex_project or project_id,
+            )
+
+            return openai_chat_completions.completion(
+                model=model,
+                messages=messages,
+                api_base=api_base,
+                api_key=access_token,
+                custom_prompt_dict=custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                acompletion=acompletion,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                client=client,
+                timeout=timeout,
+            )
+
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e))
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index a8de79aff..93d8f4282 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -1189,7 +1189,7 @@ class VertexLLM(BaseLLM):
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
-            raise VertexAIError(status_code=error_code, message=response.text)
+            raise VertexAIError(status_code=error_code, message=err.response.text)
         except httpx.TimeoutException:
             raise VertexAIError(status_code=408, message="Timeout error occurred.")
 
diff --git a/litellm/main.py b/litellm/main.py
index fad2e15cc..35fad5e02 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -120,6 +120,7 @@ from .llms.prompt_templates.factory import (
 )
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
+from .llms.vertex_ai_llama import VertexAILlama3
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
@@ -156,6 +157,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+vertex_llama_chat_completion = VertexAILlama3()
 watsonxai = IBMWatsonXAI()
 
 ####### COMPLETION ENDPOINTS ################
@@ -2064,7 +2066,26 @@ def completion(
                     timeout=timeout,
                     client=client,
                 )
-
+            elif model.startswith("meta/"):
+                model_response = vertex_llama_chat_completion.completion(
+                    model=model,
+                    messages=messages,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=new_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    vertex_location=vertex_ai_location,
+                    vertex_project=vertex_ai_project,
+                    vertex_credentials=vertex_credentials,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    headers=headers,
+                    custom_prompt_dict=custom_prompt_dict,
+                    timeout=timeout,
+                    client=client,
+                )
             else:
                 model_response = vertex_ai.completion(
                     model=model,
@@ -2478,28 +2499,25 @@ def completion(
                 return generator
 
             response = generator
-
+
         elif custom_llm_provider == "triton":
-            api_base = (
-                litellm.api_base or api_base
-            )
+            api_base = litellm.api_base or api_base
             model_response = triton_chat_completions.completion(
-            api_base=api_base,
-            timeout=timeout,  # type: ignore
-            model=model,
-            messages=messages,
-            model_response=model_response,
-            optional_params=optional_params,
-            logging_obj=logging,
-            stream=stream,
-            acompletion=acompletion
+                api_base=api_base,
+                timeout=timeout,  # type: ignore
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                optional_params=optional_params,
+                logging_obj=logging,
+                stream=stream,
+                acompletion=acompletion,
             )
 
             ## RESPONSE OBJECT
             response = model_response
             return response
-
-
+
         elif custom_llm_provider == "cloudflare":
             api_key = (
                 api_key
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 3def5a1ec..b9762afcb 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -895,6 +895,52 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
         pytest.fail("An unexpected exception occurred - {}".format(str(e)))
 
 
+from litellm.tests.test_completion import response_format_tests
+
+
+@pytest.mark.parametrize(
+    "model", ["vertex_ai/meta/llama3-405b-instruct-maas"]
+)  # "vertex_ai",
+@pytest.mark.parametrize("sync_mode", [True, False])  # "vertex_ai",
+@pytest.mark.asyncio
+async def test_llama_3_httpx(model, sync_mode):
+    try:
+        load_vertex_ai_credentials()
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+        ]
+
+        data = {
+            "model": model,
+            "messages": messages,
+        }
+        if sync_mode:
+            response = litellm.completion(**data)
+        else:
+            response = await litellm.acompletion(**data)
+
+        response_format_tests(response=response)
+
+        print(f"response: {response}")
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        if "429 Quota exceeded" in str(e):
+            pass
+        else:
+            pytest.fail("An unexpected exception occurred - {}".format(str(e)))
+
+
 def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
     mock_response = MagicMock()
     mock_response.status_code = 200
diff --git a/litellm/utils.py b/litellm/utils.py
index 7f615ab61..8baced4c5 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -5752,10 +5752,12 @@ def convert_to_model_response_object(
             model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore
 
         if "created" in response_object:
-            model_response_object.created = response_object["created"]
+            model_response_object.created = response_object["created"] or int(
+                time.time()
+            )
 
         if "id" in response_object:
-            model_response_object.id = response_object["id"]
+            model_response_object.id = response_object["id"] or str(uuid.uuid4())
 
         if "system_fingerprint" in response_object:
            model_response_object.system_fingerprint = response_object[
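
For reference, a minimal usage sketch of the new `meta/` route added by this patch, mirroring the `test_llama_3_httpx` test above; the model name and messages are taken from that test, and Vertex AI credentials/project configuration are assumed to already be set up in the environment.

# Minimal sketch (assumes this patch is applied and Vertex AI auth is configured).
import litellm

response = litellm.completion(
    model="vertex_ai/meta/llama3-405b-instruct-maas",  # dispatched via model.startswith("meta/")
    messages=[
        {"role": "system", "content": "Your name is Litellm Bot, you are a helpful assistant"},
        {"role": "user", "content": "Hello, what is your name and can you tell me the weather?"},
    ],
    # vertex_project / vertex_location can also be passed explicitly if they are
    # not picked up from the environment or application default credentials.
)
print(response.choices[0].message.content)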