diff --git a/docs/my-website/docs/providers/anthropic.md b/docs/my-website/docs/providers/anthropic.md index deb640b17..496343f87 100644 --- a/docs/my-website/docs/providers/anthropic.md +++ b/docs/my-website/docs/providers/anthropic.md @@ -56,7 +56,7 @@ for chunk in response: print(chunk["choices"][0]["delta"]["content"]) # same as openai format ``` -## OpenAI Proxy Usage +## Usage with LiteLLM Proxy Here's how to call Anthropic with the LiteLLM Proxy Server @@ -69,14 +69,6 @@ export ANTHROPIC_API_KEY="your-api-key" ### 2. Start the proxy - - -```bash -$ litellm --model claude-3-opus-20240229 - -# Server running on http://0.0.0.0:4000 -``` - ```yaml @@ -91,6 +83,14 @@ model_list: litellm --config /path/to/config.yaml ``` + + +```bash +$ litellm --model claude-3-opus-20240229 + +# Server running on http://0.0.0.0:4000 +``` + ### 3. Test it diff --git a/docs/my-website/docs/providers/vertex.md b/docs/my-website/docs/providers/vertex.md index 19442e11b..f87597046 100644 --- a/docs/my-website/docs/providers/vertex.md +++ b/docs/my-website/docs/providers/vertex.md @@ -749,6 +749,85 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ + +## Llama 3 API + +| Model Name | Function Call | +|------------------|--------------------------------------| +| meta/llama3-405b-instruct-maas | `completion('vertex_ai/meta/llama3-405b-instruct-maas', messages)` | + +### Usage + + + + +```python +from litellm import completion +import os + +os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = "" + +model = "meta/llama3-405b-instruct-maas" + +vertex_ai_project = "your-vertex-project" # can also set this as os.environ["VERTEXAI_PROJECT"] +vertex_ai_location = "your-vertex-location" # can also set this as os.environ["VERTEXAI_LOCATION"] + +response = completion( + model="vertex_ai/" + model, + messages=[{"role": "user", "content": "hi"}], + temperature=0.7, + vertex_ai_project=vertex_ai_project, + vertex_ai_location=vertex_ai_location, +) +print("\nModel Response", response) +``` + + + +**1. Add to config** + +```yaml +model_list: + - model_name: anthropic-llama + litellm_params: + model: vertex_ai/meta/llama3-405b-instruct-maas + vertex_ai_project: "my-test-project" + vertex_ai_location: "us-east-1" + - model_name: anthropic-llama + litellm_params: + model: vertex_ai/meta/llama3-405b-instruct-maas + vertex_ai_project: "my-test-project" + vertex_ai_location: "us-west-1" +``` + +**2. Start proxy** + +```bash +litellm --config /path/to/config.yaml + +# RUNNING at http://0.0.0.0:4000 +``` + +**3. Test it!** + +```bash +curl --location 'http://0.0.0.0:4000/chat/completions' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "model": "anthropic-llama", # 👈 the 'model_name' in config + "messages": [ + { + "role": "user", + "content": "what llm are you" + } + ], + }' +``` + + + + ## Model Garden | Model Name | Function Call | |------------------|--------------------------------------| diff --git a/docs/my-website/docs/proxy/guardrails.md b/docs/my-website/docs/proxy/guardrails.md index 053fa8cab..2cfa3980e 100644 --- a/docs/my-website/docs/proxy/guardrails.md +++ b/docs/my-website/docs/proxy/guardrails.md @@ -266,6 +266,54 @@ curl --location 'http://0.0.0.0:4000/chat/completions' \ }' ``` +## Disable team from turning on/off guardrails + + +### 1. 
Disable team from modifying guardrails

+
+```bash
+curl -X POST 'http://0.0.0.0:4000/team/update' \
+-H 'Authorization: Bearer sk-1234' \
+-H 'Content-Type: application/json' \
+-d '{
+    "team_id": "4198d93c-d375-4c83-8d5a-71e7c5473e50",
+    "metadata": {"guardrails": {"modify_guardrails": false}}
+}'
+```
+
+### 2. Try to disable guardrails for a call
+
+```bash
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Content-Type: application/json' \
+--header 'Authorization: Bearer $LITELLM_VIRTUAL_KEY' \
+--data '{
+"model": "gpt-3.5-turbo",
+    "messages": [
+        {
+        "role": "user",
+        "content": "Think of 10 random colors."
+        }
+    ],
+    "metadata": {"guardrails": {"hide_secrets": false}}
+}'
+```
+
+### 3. Get 403 Error
+
+```
+{
+    "error": {
+        "message": {
+            "error": "Your team does not have permission to modify guardrails."
+        },
+        "type": "auth_error",
+        "param": "None",
+        "code": 403
+    }
+}
+```
+
 Expect to NOT see `+1 412-612-9992` in your server logs on your callback.

 :::info
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 9bb9a81cd..5a10ae77c 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -357,6 +357,7 @@ vertex_text_models: List = []
 vertex_code_text_models: List = []
 vertex_embedding_models: List = []
 vertex_anthropic_models: List = []
+vertex_llama3_models: List = []
 ai21_models: List = []
 nlp_cloud_models: List = []
 aleph_alpha_models: List = []
@@ -399,6 +400,9 @@ for key, value in model_cost.items():
     elif value.get("litellm_provider") == "vertex_ai-anthropic_models":
         key = key.replace("vertex_ai/", "")
         vertex_anthropic_models.append(key)
+    elif value.get("litellm_provider") == "vertex_ai-llama_models":
+        key = key.replace("vertex_ai/", "")
+        vertex_llama3_models.append(key)
     elif value.get("litellm_provider") == "ai21":
         ai21_models.append(key)
     elif value.get("litellm_provider") == "nlp_cloud":
@@ -828,6 +832,7 @@ from .llms.petals import PetalsConfig
 from .llms.vertex_httpx import VertexGeminiConfig, GoogleAIStudioGeminiConfig
 from .llms.vertex_ai import VertexAIConfig, VertexAITextEmbeddingConfig
 from .llms.vertex_ai_anthropic import VertexAIAnthropicConfig
+from .llms.vertex_ai_llama import VertexAILlama3Config
 from .llms.sagemaker import SagemakerConfig
 from .llms.ollama import OllamaConfig
 from .llms.ollama_chat import OllamaChatConfig
diff --git a/litellm/llms/anthropic.py b/litellm/llms/anthropic.py
index da51e887d..d3a3c38a4 100644
--- a/litellm/llms/anthropic.py
+++ b/litellm/llms/anthropic.py
@@ -385,6 +385,11 @@ class AnthropicConfig:
             if "user_id" in anthropic_message_request["metadata"]:
                 new_kwargs["user"] = anthropic_message_request["metadata"]["user_id"]

+        # Pass litellm proxy specific metadata
+        if "litellm_metadata" in anthropic_message_request:
+            # metadata will be passed to litellm.acompletion(), it's a litellm_param
+            new_kwargs["metadata"] = anthropic_message_request.pop("litellm_metadata")
+
         ## CONVERT TOOL CHOICE
         if "tool_choice" in anthropic_message_request:
             new_kwargs["tool_choice"] = self.translate_anthropic_tool_choice_to_openai(
@@ -775,8 +780,17 @@ class AnthropicChatCompletion(BaseLLM):
             system_prompt = ""
             for idx, message in enumerate(messages):
                 if message["role"] == "system":
-                    system_prompt += message["content"]
-                    system_prompt_indices.append(idx)
+                    valid_content: bool = False
+                    if isinstance(message["content"], str):
+                        system_prompt += message["content"]
+                        valid_content = True
+                    elif isinstance(message["content"], list):
+                        for content in message["content"]:
+                            system_prompt += content.get("text", "")
+                            valid_content = True
+
+                    if valid_content:
+                        system_prompt_indices.append(idx)
             if len(system_prompt_indices) > 0:
                 for idx in reversed(system_prompt_indices):
                     messages.pop(idx)
diff --git a/litellm/llms/vertex_ai_llama.py b/litellm/llms/vertex_ai_llama.py
new file mode 100644
index 000000000..f33c127f7
--- /dev/null
+++ b/litellm/llms/vertex_ai_llama.py
@@ -0,0 +1,203 @@
+# What is this?
+## Handler for calling llama 3.1 API on Vertex AI
+import copy
+import json
+import os
+import time
+import types
+import uuid
+from enum import Enum
+from typing import Any, Callable, List, Optional, Tuple, Union
+
+import httpx  # type: ignore
+import requests  # type: ignore
+
+import litellm
+from litellm.litellm_core_utils.core_helpers import map_finish_reason
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.types.llms.anthropic import (
+    AnthropicMessagesTool,
+    AnthropicMessagesToolChoice,
+)
+from litellm.types.llms.openai import (
+    ChatCompletionToolParam,
+    ChatCompletionToolParamFunctionChunk,
+)
+from litellm.types.utils import ResponseFormatChunk
+from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
+
+from .base import BaseLLM
+from .prompt_templates.factory import (
+    construct_tool_use_system_prompt,
+    contains_tag,
+    custom_prompt,
+    extract_between_tags,
+    parse_xml_params,
+    prompt_factory,
+    response_schema_prompt,
+)
+
+
+class VertexAIError(Exception):
+    def __init__(self, status_code, message):
+        self.status_code = status_code
+        self.message = message
+        self.request = httpx.Request(
+            method="POST", url=" https://cloud.google.com/vertex-ai/"
+        )
+        self.response = httpx.Response(status_code=status_code, request=self.request)
+        super().__init__(
+            self.message
+        )  # Call the base class constructor with the parameters it needs
+
+
+class VertexAILlama3Config:
+    """
+    Reference: https://cloud.google.com/vertex-ai/generative-ai/docs/partner-models/llama#streaming
+
+    The class `VertexAILlama3Config` provides configuration for VertexAI's Llama API interface. Below are the parameters:
+
+    - `max_tokens` Required (integer) max tokens,
+
+    Note: Please make sure to modify the default parameters as required for your use case.
+ """ + + max_tokens: Optional[int] = None + + def __init__( + self, + max_tokens: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key == "max_tokens" and value is None: + value = self.max_tokens + if key != "self" and value is not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self): + return [ + "max_tokens", + "stream", + ] + + def map_openai_params(self, non_default_params: dict, optional_params: dict): + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["max_tokens"] = value + + return optional_params + + +class VertexAILlama3(BaseLLM): + def __init__(self) -> None: + pass + + def create_vertex_llama3_url( + self, vertex_location: str, vertex_project: str + ) -> str: + return f"https://{vertex_location}-aiplatform.googleapis.com/v1beta1/projects/{vertex_project}/locations/{vertex_location}/endpoints/openapi" + + def completion( + self, + model: str, + messages: list, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + custom_prompt_dict: dict, + headers: Optional[dict], + timeout: Union[float, httpx.Timeout], + vertex_project=None, + vertex_location=None, + vertex_credentials=None, + litellm_params=None, + logger_fn=None, + acompletion: bool = False, + client=None, + ): + try: + import vertexai + from google.cloud import aiplatform + + from litellm.llms.openai import OpenAIChatCompletion + from litellm.llms.vertex_httpx import VertexLLM + except Exception: + + raise VertexAIError( + status_code=400, + message="""vertexai import failed please run `pip install -U "google-cloud-aiplatform>=1.38"`""", + ) + + if not ( + hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") + ): + raise VertexAIError( + status_code=400, + message="""Upgrade vertex ai. 
Run `pip install "google-cloud-aiplatform>=1.38"`""", + ) + try: + + vertex_httpx_logic = VertexLLM() + + access_token, project_id = vertex_httpx_logic._ensure_access_token( + credentials=vertex_credentials, project_id=vertex_project + ) + + openai_chat_completions = OpenAIChatCompletion() + + ## Load Config + # config = litellm.VertexAILlama3.get_config() + # for k, v in config.items(): + # if k not in optional_params: + # optional_params[k] = v + + ## CONSTRUCT API BASE + stream: bool = optional_params.get("stream", False) or False + + optional_params["stream"] = stream + + api_base = self.create_vertex_llama3_url( + vertex_location=vertex_location or "us-central1", + vertex_project=vertex_project or project_id, + ) + + return openai_chat_completions.completion( + model=model, + messages=messages, + api_base=api_base, + api_key=access_token, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + logging_obj=logging_obj, + optional_params=optional_params, + acompletion=acompletion, + litellm_params=litellm_params, + logger_fn=logger_fn, + client=client, + timeout=timeout, + ) + + except Exception as e: + raise VertexAIError(status_code=500, message=str(e)) diff --git a/litellm/llms/vertex_httpx.py b/litellm/llms/vertex_httpx.py index a8de79aff..93d8f4282 100644 --- a/litellm/llms/vertex_httpx.py +++ b/litellm/llms/vertex_httpx.py @@ -1189,7 +1189,7 @@ class VertexLLM(BaseLLM): response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code - raise VertexAIError(status_code=error_code, message=response.text) + raise VertexAIError(status_code=error_code, message=err.response.text) except httpx.TimeoutException: raise VertexAIError(status_code=408, message="Timeout error occurred.") diff --git a/litellm/main.py b/litellm/main.py index fad2e15cc..35fad5e02 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -120,6 +120,7 @@ from .llms.prompt_templates.factory import ( ) from .llms.text_completion_codestral import CodestralTextCompletion from .llms.triton import TritonChatCompletion +from .llms.vertex_ai_llama import VertexAILlama3 from .llms.vertex_httpx import VertexLLM from .llms.watsonx import IBMWatsonXAI from .types.llms.openai import HttpxBinaryResponseContent @@ -156,6 +157,7 @@ triton_chat_completions = TritonChatCompletion() bedrock_chat_completion = BedrockLLM() bedrock_converse_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() +vertex_llama_chat_completion = VertexAILlama3() watsonxai = IBMWatsonXAI() ####### COMPLETION ENDPOINTS ################ @@ -2064,7 +2066,26 @@ def completion( timeout=timeout, client=client, ) - + elif model.startswith("meta/"): + model_response = vertex_llama_chat_completion.completion( + model=model, + messages=messages, + model_response=model_response, + print_verbose=print_verbose, + optional_params=new_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + vertex_location=vertex_ai_location, + vertex_project=vertex_ai_project, + vertex_credentials=vertex_credentials, + logging_obj=logging, + acompletion=acompletion, + headers=headers, + custom_prompt_dict=custom_prompt_dict, + timeout=timeout, + client=client, + ) else: model_response = vertex_ai.completion( model=model, @@ -2478,28 +2499,25 @@ def completion( return generator response = generator - + elif custom_llm_provider == "triton": - api_base = ( - litellm.api_base or api_base - ) + api_base = litellm.api_base or api_base model_response = 
triton_chat_completions.completion( - api_base=api_base, - timeout=timeout, # type: ignore - model=model, - messages=messages, - model_response=model_response, - optional_params=optional_params, - logging_obj=logging, - stream=stream, - acompletion=acompletion + api_base=api_base, + timeout=timeout, # type: ignore + model=model, + messages=messages, + model_response=model_response, + optional_params=optional_params, + logging_obj=logging, + stream=stream, + acompletion=acompletion, ) ## RESPONSE OBJECT response = model_response return response - - + elif custom_llm_provider == "cloudflare": api_key = ( api_key diff --git a/litellm/model_prices_and_context_window_backup.json b/litellm/model_prices_and_context_window_backup.json index f86ea8bd7..e9e599945 100644 --- a/litellm/model_prices_and_context_window_backup.json +++ b/litellm/model_prices_and_context_window_backup.json @@ -1948,6 +1948,16 @@ "supports_function_calling": true, "supports_vision": true }, + "vertex_ai/meta/llama3-405b-instruct-maas": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 32000, + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + "litellm_provider": "vertex_ai-llama_models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models" + }, "vertex_ai/imagegeneration@006": { "cost_per_image": 0.020, "litellm_provider": "vertex_ai-image-models", diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml index a1af38379..7e3c9a241 100644 --- a/litellm/proxy/_new_secret_config.yaml +++ b/litellm/proxy/_new_secret_config.yaml @@ -1,8 +1,7 @@ model_list: - - model_name: groq-llama3 + - model_name: anthropic-claude litellm_params: - model: groq/llama3-groq-70b-8192-tool-use-preview - api_key: os.environ/GROQ_API_KEY + model: claude-3-haiku-20240307 litellm_settings: callbacks: ["logfire"] diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 8909b1da3..7384dc30b 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -39,6 +39,9 @@ def _get_metadata_variable_name(request: Request) -> str: """ if "thread" in request.url.path or "assistant" in request.url.path: return "litellm_metadata" + if "/v1/messages" in request.url.path: + # anthropic API has a field called metadata + return "litellm_metadata" else: return "metadata" diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 0ac1d82e0..106b95453 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -657,7 +657,11 @@ async def _PROXY_track_cost_callback( global prisma_client, custom_db_client try: # check if it has collected an entire stream response - verbose_proxy_logger.debug("Proxy: In track_cost_callback for: %s", kwargs) + verbose_proxy_logger.debug( + "Proxy: In track_cost_callback for: kwargs=%s and completion_response: %s", + kwargs, + completion_response, + ) verbose_proxy_logger.debug( f"kwargs stream: {kwargs.get('stream', None)} + complete streaming response: {kwargs.get('complete_streaming_response', None)}" ) diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 528d7e98d..cf61635a0 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -183,12 +183,12 @@ model LiteLLM_SpendLogs { model String @default("") model_id String? @default("") // the model id stored in proxy model db model_group String? 
@default("") // public model_name / model_group - api_base String @default("") - user String @default("") - metadata Json @default("{}") - cache_hit String @default("") - cache_key String @default("") - request_tags Json @default("[]") + api_base String? @default("") + user String? @default("") + metadata Json? @default("{}") + cache_hit String? @default("") + cache_key String? @default("") + request_tags Json? @default("[]") team_id String? end_user String? requester_ip_address String? @@ -257,4 +257,4 @@ model LiteLLM_AuditLog { object_id String // id of the object being audited. This can be the key id, team id, user id, model id before_value Json? // value of the row updated_values Json? // value of the row after change -} \ No newline at end of file +} diff --git a/litellm/tests/test_amazing_vertex_completion.py b/litellm/tests/test_amazing_vertex_completion.py index 3def5a1ec..b9762afcb 100644 --- a/litellm/tests/test_amazing_vertex_completion.py +++ b/litellm/tests/test_amazing_vertex_completion.py @@ -895,6 +895,52 @@ async def test_gemini_pro_function_calling_httpx(model, sync_mode): pytest.fail("An unexpected exception occurred - {}".format(str(e))) +from litellm.tests.test_completion import response_format_tests + + +@pytest.mark.parametrize( + "model", ["vertex_ai/meta/llama3-405b-instruct-maas"] +) # "vertex_ai", +@pytest.mark.parametrize("sync_mode", [True, False]) # "vertex_ai", +@pytest.mark.asyncio +async def test_llama_3_httpx(model, sync_mode): + try: + load_vertex_ai_credentials() + litellm.set_verbose = True + + messages = [ + { + "role": "system", + "content": "Your name is Litellm Bot, you are a helpful assistant", + }, + # User asks for their name and weather in San Francisco + { + "role": "user", + "content": "Hello, what is your name and can you tell me the weather?", + }, + ] + + data = { + "model": model, + "messages": messages, + } + if sync_mode: + response = litellm.completion(**data) + else: + response = await litellm.acompletion(**data) + + response_format_tests(response=response) + + print(f"response: {response}") + except litellm.RateLimitError as e: + pass + except Exception as e: + if "429 Quota exceeded" in str(e): + pass + else: + pytest.fail("An unexpected exception occurred - {}".format(str(e))) + + def vertex_httpx_mock_reject_prompt_post(*args, **kwargs): mock_response = MagicMock() mock_response.status_code = 200 diff --git a/litellm/tests/test_anthropic_completion.py b/litellm/tests/test_anthropic_completion.py index cac0945d8..15d150a56 100644 --- a/litellm/tests/test_anthropic_completion.py +++ b/litellm/tests/test_anthropic_completion.py @@ -48,6 +48,42 @@ def test_anthropic_completion_input_translation(): ] +def test_anthropic_completion_input_translation_with_metadata(): + """ + Tests that cost tracking works as expected with LiteLLM Proxy + + LiteLLM Proxy will insert litellm_metadata for anthropic endpoints to track user_api_key and user_api_key_team_id + + This test ensures that the `litellm_metadata` is not present in the translated input + It ensures that `litellm.acompletion()` will receieve metadata which is a litellm specific param + """ + data = { + "model": "gpt-3.5-turbo", + "messages": [{"role": "user", "content": "Hey, how's it going?"}], + "litellm_metadata": { + "user_api_key": "88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b", + "user_api_key_alias": None, + "user_api_end_user_max_budget": None, + "litellm_api_version": "1.40.19", + "global_max_parallel_requests": None, + "user_api_key_user_id": 
"default_user_id", + "user_api_key_org_id": None, + "user_api_key_team_id": None, + "user_api_key_team_alias": None, + "user_api_key_team_max_budget": None, + "user_api_key_team_spend": None, + "user_api_key_spend": 0.0, + "user_api_key_max_budget": None, + "user_api_key_metadata": {}, + }, + } + translated_input = anthropic_adapter.translate_completion_input_params(kwargs=data) + + assert "litellm_metadata" not in translated_input + assert "metadata" in translated_input + assert translated_input["metadata"] == data["litellm_metadata"] + + def test_anthropic_completion_e2e(): litellm.set_verbose = True diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index c2ce836ef..31b7b8355 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -346,7 +346,7 @@ def test_completion_claude_3_empty_response(): messages = [ { "role": "system", - "content": "You are 2twNLGfqk4GMOn3ffp4p.", + "content": [{"type": "text", "text": "You are 2twNLGfqk4GMOn3ffp4p."}], }, {"role": "user", "content": "Hi gm!", "name": "ishaan"}, {"role": "assistant", "content": "Good morning! How are you doing today?"}, diff --git a/litellm/tests/test_embedding.py b/litellm/tests/test_embedding.py index 39a9e7f39..940f10e88 100644 --- a/litellm/tests/test_embedding.py +++ b/litellm/tests/test_embedding.py @@ -196,6 +196,28 @@ def test_openai_azure_embedding(): except Exception as e: pytest.fail(f"Error occurred: {e}") +@pytest.mark.skipif( + os.environ.get("CIRCLE_OIDC_TOKEN") is None, + reason="Cannot run without being in CircleCI Runner", +) +def test_openai_azure_embedding_with_oidc_and_cf(): + # TODO: Switch to our own Azure account, currently using ai.moda's account + os.environ["AZURE_TENANT_ID"] = "17c0a27a-1246-4aa1-a3b6-d294e80e783c" + os.environ["AZURE_CLIENT_ID"] = "4faf5422-b2bd-45e8-a6d7-46543a38acd0" + + try: + response = embedding( + model="azure/text-embedding-ada-002", + input=["Hello"], + azure_ad_token="oidc/circleci/", + api_base="https://gateway.ai.cloudflare.com/v1/0399b10e77ac6668c80404a5ff49eb37/litellm-test/azure-openai/eastus2-litellm", + api_version="2024-06-01", + ) + print(response) + + except Exception as e: + pytest.fail(f"Error occurred: {e}") + def test_openai_azure_embedding_optional_arg(mocker): mocked_create_embeddings = mocker.patch.object( diff --git a/litellm/tests/test_optional_params.py b/litellm/tests/test_optional_params.py index bbfc88710..b8011960e 100644 --- a/litellm/tests/test_optional_params.py +++ b/litellm/tests/test_optional_params.py @@ -128,6 +128,19 @@ def test_azure_ai_mistral_optional_params(): assert "user" not in optional_params +def test_vertex_ai_llama_3_optional_params(): + litellm.vertex_llama3_models = ["meta/llama3-405b-instruct-maas"] + litellm.drop_params = True + optional_params = get_optional_params( + model="meta/llama3-405b-instruct-maas", + user="John", + custom_llm_provider="vertex_ai", + max_tokens=10, + temperature=0.2, + ) + assert "user" not in optional_params + + def test_azure_gpt_optional_params_gpt_vision(): # for OpenAI, Azure all extra params need to get passed as extra_body to OpenAI python. 
We assert we actually set extra_body here optional_params = litellm.utils.get_optional_params( diff --git a/litellm/types/llms/anthropic.py b/litellm/types/llms/anthropic.py index 33f413ece..b41980afd 100644 --- a/litellm/types/llms/anthropic.py +++ b/litellm/types/llms/anthropic.py @@ -1,4 +1,4 @@ -from typing import Iterable, List, Optional, Union +from typing import Any, Dict, Iterable, List, Optional, Union from pydantic import BaseModel, validator from typing_extensions import Literal, Required, TypedDict @@ -113,6 +113,9 @@ class AnthropicMessagesRequest(TypedDict, total=False): top_k: int top_p: float + # litellm param - used for tracking litellm proxy metadata in the request + litellm_metadata: dict + class ContentTextBlockDelta(TypedDict): """ diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index 294e299db..35e442119 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -436,6 +436,7 @@ class ChatCompletionRequest(TypedDict, total=False): function_call: Union[str, dict] functions: List user: str + metadata: dict # litellm specific param class ChatCompletionDeltaChunk(TypedDict, total=False): diff --git a/litellm/utils.py b/litellm/utils.py index 7f615ab61..035c1c72f 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -3088,6 +3088,15 @@ def get_optional_params( non_default_params=non_default_params, optional_params=optional_params, ) + elif custom_llm_provider == "vertex_ai" and model in litellm.vertex_llama3_models: + supported_params = get_supported_openai_params( + model=model, custom_llm_provider=custom_llm_provider + ) + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.VertexAILlama3Config().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) elif custom_llm_provider == "sagemaker": ## check if unsupported param passed in supported_params = get_supported_openai_params( @@ -4189,6 +4198,9 @@ def get_supported_openai_params( return litellm.GoogleAIStudioGeminiConfig().get_supported_openai_params() elif custom_llm_provider == "vertex_ai": if request_type == "chat_completion": + if model.startswith("meta/"): + return litellm.VertexAILlama3Config().get_supported_openai_params() + return litellm.VertexAIConfig().get_supported_openai_params() elif request_type == "embeddings": return litellm.VertexAITextEmbeddingConfig().get_supported_openai_params() @@ -5752,10 +5764,12 @@ def convert_to_model_response_object( model_response_object.usage.total_tokens = response_object["usage"].get("total_tokens", 0) # type: ignore if "created" in response_object: - model_response_object.created = response_object["created"] + model_response_object.created = response_object["created"] or int( + time.time() + ) if "id" in response_object: - model_response_object.id = response_object["id"] + model_response_object.id = response_object["id"] or str(uuid.uuid4()) if "system_fingerprint" in response_object: model_response_object.system_fingerprint = response_object[ diff --git a/model_prices_and_context_window.json b/model_prices_and_context_window.json index 4c6dd8fdb..e8ca6f74d 100644 --- a/model_prices_and_context_window.json +++ b/model_prices_and_context_window.json @@ -1975,6 +1975,16 @@ "supports_function_calling": true, "supports_vision": true }, + "vertex_ai/meta/llama3-405b-instruct-maas": { + "max_tokens": 32000, + "max_input_tokens": 32000, + "max_output_tokens": 32000, + "input_cost_per_token": 0.0, + "output_cost_per_token": 0.0, + "litellm_provider": 
"vertex_ai-llama_models", + "mode": "chat", + "source": "https://cloud.google.com/vertex-ai/generative-ai/pricing#partner-models" + }, "vertex_ai/imagegeneration@006": { "cost_per_image": 0.020, "litellm_provider": "vertex_ai-image-models", diff --git a/pyproject.toml b/pyproject.toml index 5dc8ab62d..10246abd7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "1.41.27" +version = "1.42.0" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" @@ -91,7 +91,7 @@ requires = ["poetry-core", "wheel"] build-backend = "poetry.core.masonry.api" [tool.commitizen] -version = "1.41.27" +version = "1.42.0" version_files = [ "pyproject.toml:^version" ] diff --git a/schema.prisma b/schema.prisma index 970a1197e..8f4125104 100644 --- a/schema.prisma +++ b/schema.prisma @@ -172,7 +172,7 @@ model LiteLLM_Config { model LiteLLM_SpendLogs { request_id String @id call_type String - api_key String @default ("") + api_key String @default ("") // Hashed API Token. Not the actual Virtual Key. Equivalent to 'token' column in LiteLLM_VerificationToken spend Float @default(0.0) total_tokens Int @default(0) prompt_tokens Int @default(0) @@ -183,12 +183,12 @@ model LiteLLM_SpendLogs { model String @default("") model_id String? @default("") // the model id stored in proxy model db model_group String? @default("") // public model_name / model_group - api_base String @default("") - user String @default("") - metadata Json @default("{}") - cache_hit String @default("") - cache_key String @default("") - request_tags Json @default("[]") + api_base String? @default("") + user String? @default("") + metadata Json? @default("{}") + cache_hit String? @default("") + cache_key String? @default("") + request_tags Json? @default("[]") team_id String? end_user String? requester_ip_address String?