diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9bb9a81cd..5a10ae77c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -357,6 +357,7 @@ vertex_text_models: List = []
 vertex_code_text_models: List = []
 vertex_embedding_models: List = []
 vertex_anthropic_models: List = []
+vertex_llama3_models: List = []
 ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
@@ -399,6 +400,9 @@ for key, value in model_cost.items():
     elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
         key = key.replace("vertex_ai/", "")
         vertex_anthropic_models.append(key)
+    elif value.get("litellm_provider") == "vertex_ai-llama_models":
+        key = key.replace("vertex_ai/", "")
+        vertex_llama3_models.append(key)
     elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
@@ -828,6 +832,7 @@ from .llms.petals import PetalsConfig
 from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
 from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
+from .llms.vertex_ai_llama import VertexAILlama3Config
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
diff --git a/litellm/llms/vertex_ai_llama.py b/litellm/llms/vertex_ai_llama.py
new file mode 100644
index 000000000..f33c127f7
--- /dev/null
+++ b/litellm/llms/vertex_ai_llama.py
@@ -0,0 +1,203 @@
+# What is this?
+## Handler for calling the Llama 3.1 API on Vertex AI
+import copy
+import json
+import os
+import time
+import types
+import uuid
+from enum import Enum
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import (
+    construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
+)
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url="https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class VertexAILlama3Config:
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
+
+    The class `VertexAILlama3Config` provides configuration for Vertex AI's Llama API interface. Below are the parameters:
+
+    - `max_tokens` Required (integer): the maximum number of tokens to generate.
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+    """
+
+    max_tokens: Optional[int] = None
+
+    def __init__(
+        self,
+        max_tokens: Optional[int] = None,
+    ) -> None:
+        locals_ = locals()
+        for key, value in locals_.items():
+            if key == "max_tokens" and value is None:
+                value = self.max_tokens
+            if key != "self" and value is not None:
+                setattr(self.__class__, key, value)
+
+    @classmethod
+    def get_config(cls):
+        return {
+            k: v
+            for k, v in cls.__dict__.items()
+            if not k.startswith("__")
+            and not isinstance(
+                v,
+                (
+                    types.FunctionType,
+                    types.BuiltinFunctionType,
+                    classmethod,
+                    staticmethod,
+                ),
+            )
+            and v is not None
+        }
+
+    def get_supported_openai_params(self):
+        return [
+            "max_tokens",
+            "stream",
+        ]
+
+    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+        for param, value in non_default_params.items():
+            if param == "max_tokens":
+                optional_params["max_tokens"] = value
+
+        return optional_params
+
+
+class VertexAILlama3(BaseLLM):
+    def __init__(self) -> None:
+        pass
+
+    def create_vertex_llama3_url(
+        self, vertex_location: str, vertex_project: str
+    ) -> str:
+        return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi"
+
+    def completion(
+        self,
+        model: str,
+        messages: list,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        encoding,
+        logging_obj,
+        optional_params: dict,
+        custom_prompt_dict: dict,
+        headers: Optional[dict],
+        timeout: Union[float, httpx.Timeout],
+        vertex_project=None,
+        vertex_location=None,
+        vertex_credentials=None,
+        litellm_params=None,
+        logger_fn=None,
+        acompletion: bool = False,
+        client=None,
+    ):
+        try:
+            import vertexai
+            from google.cloud import aiplatform
+
+            from litellm.llms.openai import OpenAIChatCompletion
+            from litellm.llms.vertex_httpx import VertexLLM
+        except Exception:
+
+            raise VertexAIError(
+                status_code=400,
+                message="""vertexai import failed. Please run `pip install -U "google-cloud-aiplatform>=1.38"`""",
+            )
+
+        if not (
+            hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models")
+        ):
+            raise VertexAIError(
+                status_code=400,
+                message="""Upgrade vertex ai.
+Run `pip install "google-cloud-aiplatform>=1.38"`""",
+            )
+        try:
+
+            vertex_httpx_logic = VertexLLM()
+
+            access_token, project_id = vertex_httpx_logic._ensure_access_token(
+                credentials=vertex_credentials, project_id=vertex_project
+            )
+
+            openai_chat_completions = OpenAIChatCompletion()
+
+            ## Load Config
+            # config = litellm.VertexAILlama3.get_config()
+            # for k, v in config.items():
+            #     if k not in optional_params:
+            #         optional_params[k] = v
+
+            ## CONSTRUCT API BASE
+            stream: bool = optional_params.get("stream", False) or False
+
+            optional_params["stream"] = stream
+
+            api_base = self.create_vertex_llama3_url(
+                vertex_location=vertex_location or "us-central1",
+                vertex_project=vertex_project or project_id,
+            )
+
+            return openai_chat_completions.completion(
+                model=model,
+                messages=messages,
+                api_base=api_base,
+                api_key=access_token,
+                custom_prompt_dict=custom_prompt_dict,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                acompletion=acompletion,
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                client=client,
+                timeout=timeout,
+            )
+
+        except Exception as e:
+            raise VertexAIError(status_code=500, message=str(e))
diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py
index a8de79aff..93d8f4282 100644
--- a/litellm/llms/vertex_httpx.py
+++ b/litellm/llms/vertex_httpx.py
@@ -1189,7 +1189,7 @@ class VertexLLM(BaseLLM):
             response.raise_for_status()
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
-            raise VertexAIError(status_code=error_code, message=response.text)
+            raise VertexAIError(status_code=error_code, message=err.response.text)
         except httpx.TimeoutException:
             raise VertexAIError(status_code=408, message="Timeout error occurred.")
diff --git a/litellm/main.py b/litellm/main.py
index fad2e15cc..35fad5e02 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -120,6 +120,7 @@ from .llms.prompt_templates.factory import (
 )
 from .llms.text_completion_codestral import CodestralTextCompletion
 from .llms.triton import TritonChatCompletion
+from .llms.vertex_ai_llama import VertexAILlama3
 from .llms.vertex_httpx import VertexLLM
 from .llms.watsonx import IBMWatsonXAI
 from .types.llms.openai import HttpxBinaryResponseContent
@@ -156,6 +157,7 @@ triton_chat_completions = TritonChatCompletion()
 bedrock_chat_completion = BedrockLLM()
 bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
+vertex_llama_chat_completion = VertexAILlama3()
 watsonxai = IBMWatsonXAI()
 
 ####### COMPLETION ENDPOINTS ################
@@ -2064,7 +2066,26 @@ def completion(
                     timeout=timeout,
                     client=client,
                 )
-
+            elif model.startswith("meta/"):
+                model_response = vertex_llama_chat_completion.completion(
+                    model=model,
+                    messages=messages,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    optional_params=new_params,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    encoding=encoding,
+                    vertex_location=vertex_ai_location,
+                    vertex_project=vertex_ai_project,
+                    vertex_credentials=vertex_credentials,
+                    logging_obj=logging,
+                    acompletion=acompletion,
+                    headers=headers,
+                    custom_prompt_dict=custom_prompt_dict,
+                    timeout=timeout,
+                    client=client,
+                )
             else:
                 model_response = vertex_ai.completion(
                     model=model,
@@ -2478,28 +2499,25 @@
                 return generator
 
             response = generator
-
+
         elif custom_llm_provider == "triton":
-            api_base = (
-                litellm.api_base or api_base
-            )
+            api_base = litellm.api_base or api_base
             model_response = triton_chat_completions.completion(
-                api_base=api_base,
-                timeout=timeout,  # type: ignore
-                model=model,
-                messages=messages,
-                model_response=model_response,
-                optional_params=optional_params,
-                logging_obj=logging,
-                stream=stream,
-                acompletion=acompletion
+                api_base=api_base,
+                timeout=timeout,  # type: ignore
+                model=model,
+                messages=messages,
+                model_response=model_response,
+                optional_params=optional_params,
+                logging_obj=logging,
+                stream=stream,
+                acompletion=acompletion,
             )
             ## RESPONSE OBJECT
             response = model_response
             return response
-
-
+
         elif custom_llm_provider == "cloudflare":
             api_key = (
                 api_key
diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json
index f86ea8bd7..e9e599945 100644
--- a/litellm/model_prices_and_context_window_backup.json
+++ b/litellm/model_prices_and_context_window_backup.json
@@ -1948,6 +1948,16 @@
         "supports_function_calling": true,
         "supports_vision": true
     },
+    "vertex_ai/meta/llama3-405b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
     "vertex_ai/imagegeneration@006": {
         "cost_per_image": 0.020,
         "litellm_provider": "vertex_ai-image-models",
diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py
index 3def5a1ec..b9762afcb 100644
--- a/litellm/tests/test_amazing_vertex_completion.py
+++ b/litellm/tests/test_amazing_vertex_completion.py
@@ -895,6 +895,52 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode):
         pytest.fail("An unexpected exception occurred - {}".format(str(e)))
 
 
+from litellm.tests.test_completion import response_format_tests
+
+
+@pytest.mark.parametrize(
+    "model", ["vertex_ai/meta/llama3-405b-instruct-maas"]
+)  # "vertex_ai",
+@pytest.mark.parametrize("sync_mode", [True, False])  # "vertex_ai",
+@pytest.mark.asyncio
+async def test_llama_3_httpx(model, sync_mode):
+    try:
+        load_vertex_ai_credentials()
+        litellm.set_verbose = True
+
+        messages = [
+            {
+                "role": "system",
+                "content": "Your name is Litellm Bot, you are a helpful assistant",
+            },
+            # User asks for their name and weather in San Francisco
+            {
+                "role": "user",
+                "content": "Hello, what is your name and can you tell me the weather?",
+            },
+        ]
+
+        data = {
+            "model": model,
+            "messages": messages,
+        }
+        if sync_mode:
+            response = litellm.completion(**data)
+        else:
+            response = await litellm.acompletion(**data)
+
+        response_format_tests(response=response)
+
+        print(f"response: {response}")
+    except litellm.RateLimitError as e:
+        pass
+    except Exception as e:
+        if "429 Quota exceeded" in str(e):
+            pass
+        else:
+            pytest.fail("An unexpected exception occurred - {}".format(str(e)))
+
+
 def vertex_httpx_mock_reject_prompt_post(*args, **kwargs):
     mock_response = MagicMock()
     mock_response.status_code = 200
diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py
index bbfc88710..b8011960e 100644
--- a/litellm/tests/test_optional_params.py
+++ b/litellm/tests/test_optional_params.py
@@ -128,6 +128,19 @@ def test_azure_ai_mistral_optional_params():
     assert "user" not in optional_params
 
 
+def test_vertex_ai_llama_3_optional_params():
+    litellm.vertex_llama3_models = ["meta/llama3-405b-instruct-maas"]
+    litellm.drop_params = True
+    optional_params = get_optional_params(
+        model="meta/llama3-405b-instruct-maas",
+        user="John",
+        custom_llm_provider="vertex_ai",
+        max_tokens=10,
+        temperature=0.2,
+    )
+    assert "user" not in optional_params
+
+
 def test_azure_gpt_optional_params_gpt_vision():
     # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. We assert we actually set extra_body here
     optional_params = litellm.utils.get_optional_params(
diff --git a/litellm/utils.py b/litellm/utils.py
index 7f615ab61..035c1c72f 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3088,6 +3088,15 @@ def get_optional_params(
             non_default_params=non_default_params,
             optional_params=optional_params,
         )
+    elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_llama3_models:
+        supported_params = get_supported_openai_params(
+            model=model, custom_llm_provider=custom_llm_provider
+        )
+        _check_valid_arg(supported_params=supported_params)
+        optional_params = litellm.VertexAILlama3Config().map_openai_params(
+            non_default_params=non_default_params,
+            optional_params=optional_params,
+        )
     elif custom_llm_provider == "sagemaker":
         ## check if unsupported param passed in
         supported_params = get_supported_openai_params(
@@ -4189,6 +4198,9 @@ def get_supported_openai_params(
         return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params()
     elif custom_llm_provider == "vertex_ai":
         if request_type == "chat_completion":
+            if model.startswith("meta/"):
+                return litellm.VertexAILlama3Config().get_supported_openai_params()
+
             return litellm.VertexAIConfig().get_supported_openai_params()
         elif request_type == "embeddings":
             return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params()
@@ -5752,10 +5764,12 @@ def convert_to_model_response_object(
                 model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0)  # type: ignore
 
         if "created" in response_object:
-            model_response_object.created = response_object["created"]
+            model_response_object.created = response_object["created"] or int(
+                time.time()
+            )
 
         if "id" in response_object:
-            model_response_object.id = response_object["id"]
+            model_response_object.id = response_object["id"] or str(uuid.uuid4())
 
         if "system_fingerprint" in response_object:
             model_response_object.system_fingerprint = response_object[
diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json
index f86ea8bd7..e9e599945 100644
--- a/model_prices_and_context_window.json
+++ b/model_prices_and_context_window.json
@@ -1948,6 +1948,16 @@
         "supports_function_calling": true,
         "supports_vision": true
     },
+    "vertex_ai/meta/llama3-405b-instruct-maas": {
+        "max_tokens": 32000,
+        "max_input_tokens": 32000,
+        "max_output_tokens": 32000,
+        "input_cost_per_token": 0.0,
+        "output_cost_per_token": 0.0,
+        "litellm_provider": "vertex_ai-llama_models",
+        "mode": "chat",
+        "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models"
+    },
     "vertex_ai/imagegeneration@006": {
         "cost_per_image": 0.020,
         "litellm_provider": "vertex_ai-image-models",
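
A minimal usage sketch, distilled from the new test_llama_3_httpx test in this patch: it assumes Vertex AI Application Default Credentials are already configured (for example via GOOGLE_APPLICATION_CREDENTIALS), that the project has access to the Llama 3.1 405B MaaS endpoint, and that the project ID below is a placeholder.

import litellm

# Assumptions: ADC credentials are configured, and the placeholder project below
# has the Llama 3.1 MaaS partner model enabled in Vertex AI.
litellm.vertex_project = "my-gcp-project"  # placeholder project ID
litellm.vertex_location = "us-central1"    # matches the handler's default region

messages = [
    {"role": "system", "content": "Your name is Litellm Bot, you are a helpful assistant"},
    {"role": "user", "content": "Hello, what is your name and can you tell me the weather?"},
]

# Once the "vertex_ai/" prefix is stripped, the model name starts with "meta/",
# so main.py routes the call to vertex_llama_chat_completion.
response = litellm.completion(
    model="vertex_ai/meta/llama3-405b-instruct-maas",
    messages=messages,
    max_tokens=256,  # example value; max_tokens and stream are the params VertexAILlama3Config maps
)
print(response.choices[0].message.content)

# The async variant exercised by the test is: response = await litellm.acompletion(**data)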