From dcbdbd2e0b13e69d8316f8a70ea43b85d1f1230d Mon Sep 17 00:00:00 2001 From: Michael Slattery Date: Thu, 30 May 2024 14:14:41 -0400 Subject: [PATCH 01/52] feature - Types for mypy - #360 --- litellm/py.typed | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 litellm/py.typed diff --git a/litellm/py.typed b/litellm/py.typed new file mode 100644 index 0000000000..5686005abc --- /dev/null +++ b/litellm/py.typed @@ -0,0 +1,2 @@ +# Marker file to instruct type checkers to look for inline type annotations in this package. +# See PEP 561 for more information. From e56ba0226e448fdfb86d42a9a5c85ba4b407bbf9 Mon Sep 17 00:00:00 2001 From: Michael Slattery Date: Thu, 30 May 2024 14:36:04 -0400 Subject: [PATCH 02/52] feature #360 - Distribute py.typed --- pyproject.toml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index a6a6a86429..627209e713 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,6 +5,10 @@ description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT" readme = "README.md" +packages = [ + { include = "litellm" }, + { include = "litellm/py.typed"}, +] [tool.poetry.urls] homepage = "https://litellm.ai" From 96b556f385e97b247b4fe44f843c821c9fc52672 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Wed, 5 Jun 2024 21:20:36 -0700 Subject: [PATCH 03/52] feat(bedrock_httpx.py): add support for bedrock converse api closes https://github.com/BerriAI/litellm/issues/4000 --- litellm/__init__.py | 2 +- litellm/llms/bedrock_httpx.py | 382 ++++++++++++++++++++++ litellm/llms/prompt_templates/factory.py | 388 ++++++++++++++++++++++- litellm/tests/test_completion.py | 13 +- litellm/types/llms/bedrock.py | 77 ++++- litellm/utils.py | 15 +- 6 files changed, 846 insertions(+), 31 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index f67a252ebc..2fc47a9926 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -765,7 +765,7 @@ from .llms.sagemaker import SagemakerConfig from .llms.ollama import OllamaConfig from .llms.ollama_chat import OllamaChatConfig from .llms.maritalk import MaritTalkConfig -from .llms.bedrock_httpx import AmazonCohereChatConfig +from .llms.bedrock_httpx import AmazonCohereChatConfig, AmazonConverseConfig from .llms.bedrock import ( AmazonTitanConfig, AmazonAI21Config, diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index dbd7e7c695..e212650064 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -38,6 +38,8 @@ from .prompt_templates.factory import ( extract_between_tags, parse_xml_params, contains_tag, + _bedrock_converse_messages_pt, + _bedrock_tools_pt, ) from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from .base import BaseLLM @@ -118,6 +120,8 @@ class AmazonCohereChatConfig: "presence_penalty", "seed", "stop", + "tools", + "tool_choice", ] def map_openai_params( @@ -1069,6 +1073,384 @@ class BedrockLLM(BaseLLM): return super().embedding(*args, **kwargs) +class AmazonConverseConfig: + """ + Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + """ + + maxTokens: Optional[int] + stopSequences: Optional[List[str]] + temperature: Optional[int] + topP: Optional[int] + + def __init__( + self, + maxTokens: Optional[int] = None, + stopSequences: Optional[List[str]] = None, + temperature: Optional[int] = None, + top_p: Optional[int] = None, + ) -> None: + locals_ = locals() + for key, value in locals_.items(): + if key != "self" and value is 
not None: + setattr(self.__class__, key, value) + + @classmethod + def get_config(cls): + return { + k: v + for k, v in cls.__dict__.items() + if not k.startswith("__") + and not isinstance( + v, + ( + types.FunctionType, + types.BuiltinFunctionType, + classmethod, + staticmethod, + ), + ) + and v is not None + } + + def get_supported_openai_params(self) -> List[str]: + return [ + "max_tokens", + "stream", + "stream_options", + "stop", + "temperature", + "top_p", + "tools", + "tool_choice", + ] + + def get_supported_image_types(self) -> List[str]: + return ["png", "jpeg", "gif", "webp"] + + def map_openai_params( + self, non_default_params: dict, optional_params: dict + ) -> dict: + for param, value in non_default_params.items(): + if param == "max_tokens": + optional_params["maxTokens"] = value + if param == "stream": + optional_params["stream"] = value + if param == "stop": + if isinstance(value, str): + value = [value] + optional_params["stop_sequences"] = value + if param == "temperature": + optional_params["temperature"] = value + if param == "top_p": + optional_params["topP"] = value + return optional_params + + +class BedrockConverseLLM(BaseLLM): + def __init__(self) -> None: + super().__init__() + + def encode_model_id(self, model_id: str) -> str: + """ + Double encode the model ID to ensure it matches the expected double-encoded format. + Args: + model_id (str): The model ID to encode. + Returns: + str: The double-encoded model ID. + """ + return urllib.parse.quote(model_id, safe="") + + def get_credentials( + self, + aws_access_key_id: Optional[str] = None, + aws_secret_access_key: Optional[str] = None, + aws_region_name: Optional[str] = None, + aws_session_name: Optional[str] = None, + aws_profile_name: Optional[str] = None, + aws_role_name: Optional[str] = None, + aws_web_identity_token: Optional[str] = None, + ): + """ + Return a boto3.Credentials object + """ + import boto3 + + ## CHECK IS 'os.environ/' passed in + params_to_check: List[Optional[str]] = [ + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + ] + + # Iterate over parameters and update if needed + for i, param in enumerate(params_to_check): + if param and param.startswith("os.environ/"): + _v = get_secret(param) + if _v is not None and isinstance(_v, str): + params_to_check[i] = _v + # Assign updated values back to parameters + ( + aws_access_key_id, + aws_secret_access_key, + aws_region_name, + aws_session_name, + aws_profile_name, + aws_role_name, + aws_web_identity_token, + ) = params_to_check + + ### CHECK STS ### + if ( + aws_web_identity_token is not None + and aws_role_name is not None + and aws_session_name is not None + ): + oidc_token = get_secret(aws_web_identity_token) + + if oidc_token is None: + raise BedrockError( + message="OIDC token could not be retrieved from secret manager.", + status_code=401, + ) + + sts_client = boto3.client("sts") + + # https://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRoleWithWebIdentity.html + # https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/sts/client/assume_role_with_web_identity.html + sts_response = sts_client.assume_role_with_web_identity( + RoleArn=aws_role_name, + RoleSessionName=aws_session_name, + WebIdentityToken=oidc_token, + DurationSeconds=3600, + ) + + session = boto3.Session( + aws_access_key_id=sts_response["Credentials"]["AccessKeyId"], + aws_secret_access_key=sts_response["Credentials"]["SecretAccessKey"], + 
aws_session_token=sts_response["Credentials"]["SessionToken"], + region_name=aws_region_name, + ) + + return session.get_credentials() + elif aws_role_name is not None and aws_session_name is not None: + sts_client = boto3.client( + "sts", + aws_access_key_id=aws_access_key_id, # [OPTIONAL] + aws_secret_access_key=aws_secret_access_key, # [OPTIONAL] + ) + + sts_response = sts_client.assume_role( + RoleArn=aws_role_name, RoleSessionName=aws_session_name + ) + + # Extract the credentials from the response and convert to Session Credentials + sts_credentials = sts_response["Credentials"] + from botocore.credentials import Credentials + + credentials = Credentials( + access_key=sts_credentials["AccessKeyId"], + secret_key=sts_credentials["SecretAccessKey"], + token=sts_credentials["SessionToken"], + ) + return credentials + elif aws_profile_name is not None: ### CHECK SESSION ### + # uses auth values from AWS profile usually stored in ~/.aws/credentials + client = boto3.Session(profile_name=aws_profile_name) + + return client.get_credentials() + else: + session = boto3.Session( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + region_name=aws_region_name, + ) + + return session.get_credentials() + + def completion( + self, + model: str, + messages: list, + custom_prompt_dict: dict, + model_response: ModelResponse, + print_verbose: Callable, + encoding, + logging_obj, + optional_params: dict, + acompletion: bool, + timeout: Optional[Union[float, httpx.Timeout]], + litellm_params=None, + logger_fn=None, + extra_headers: Optional[dict] = None, + client: Optional[Union[AsyncHTTPHandler, HTTPHandler]] = None, + ): + try: + import boto3 + + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + except ImportError as e: + raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") + + ## SETUP ## + stream = optional_params.pop("stream", None) + modelId = optional_params.pop("model_id", None) + if modelId is not None: + modelId = self.encode_model_id(model_id=modelId) + else: + modelId = model + + provider = model.split(".")[0] + + ## CREDENTIALS ## + # pop aws_secret_access_key, aws_access_key_id, aws_region_name from kwargs, since completion calls fail with them + aws_secret_access_key = optional_params.pop("aws_secret_access_key", None) + aws_access_key_id = optional_params.pop("aws_access_key_id", None) + aws_region_name = optional_params.pop("aws_region_name", None) + aws_role_name = optional_params.pop("aws_role_name", None) + aws_session_name = optional_params.pop("aws_session_name", None) + aws_profile_name = optional_params.pop("aws_profile_name", None) + aws_bedrock_runtime_endpoint = optional_params.pop( + "aws_bedrock_runtime_endpoint", None + ) # https://bedrock-runtime.{region_name}.amazonaws.com + aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) + + ### SET REGION NAME ### + if aws_region_name is None: + # check env # + litellm_aws_region_name = get_secret("AWS_REGION_NAME", None) + + if litellm_aws_region_name is not None and isinstance( + litellm_aws_region_name, str + ): + aws_region_name = litellm_aws_region_name + + standard_aws_region_name = get_secret("AWS_REGION", None) + if standard_aws_region_name is not None and isinstance( + standard_aws_region_name, str + ): + aws_region_name = standard_aws_region_name + + if aws_region_name is None: + aws_region_name = "us-west-2" + + credentials: Credentials = self.get_credentials( + aws_access_key_id=aws_access_key_id, + aws_secret_access_key=aws_secret_access_key, + aws_region_name=aws_region_name, + aws_session_name=aws_session_name, + aws_profile_name=aws_profile_name, + aws_role_name=aws_role_name, + aws_web_identity_token=aws_web_identity_token, + ) + + ### SET RUNTIME ENDPOINT ### + endpoint_url = "" + env_aws_bedrock_runtime_endpoint = get_secret("AWS_BEDROCK_RUNTIME_ENDPOINT") + if aws_bedrock_runtime_endpoint is not None and isinstance( + aws_bedrock_runtime_endpoint, str + ): + endpoint_url = aws_bedrock_runtime_endpoint + elif env_aws_bedrock_runtime_endpoint and isinstance( + env_aws_bedrock_runtime_endpoint, str + ): + endpoint_url = env_aws_bedrock_runtime_endpoint + else: + endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" + + if (stream is not None and stream == True) and provider != "ai21": + endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" + else: + endpoint_url = f"{endpoint_url}/model/{modelId}/converse" + + sigv4 = SigV4Auth(credentials, "bedrock", aws_region_name) + + # Separate system prompt from rest of message + system_prompt_indices = [] + system_content_blocks: List[SystemContentBlock] = [] + for idx, message in enumerate(messages): + if message["role"] == "system": + _system_content_block = SystemContentBlock(text=message["content"]) + system_content_blocks.append(_system_content_block) + system_prompt_indices.append(idx) + if len(system_prompt_indices) > 0: + for idx in reversed(system_prompt_indices): + messages.pop(idx) + + inference_params = copy.deepcopy(optional_params) + additional_request_keys = [] + additional_request_params = {} + supported_converse_params = AmazonConverseConfig().get_config().keys() + + ## TRANSFORMATION ## + # send all model-specific params in 'additional_request_params' + for k, v in inference_params.items(): + if k not in supported_converse_params: + 
additional_request_params[k] = v + additional_request_keys.append(k) + for key in additional_request_keys: + inference_params.pop(key, None) + + bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt( + messages=messages + ) + bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( + inference_params.get("tools", []) + ) + bedrock_tool_config: Optional[ToolConfigBlock] = None + if len(bedrock_tools) > 0: + bedrock_tool_config = ToolConfigBlock( + tools=bedrock_tools, + toolChoice=inference_params.get("tool_choice", None), + ) + + data: RequestObject = { + "messages": bedrock_messages, + "additionalModelRequestFields": additional_request_params, + "system": system_content_blocks, + } + if bedrock_tool_config is not None: + data["toolConfig"] = bedrock_tool_config + + ## COMPLETION CALL + + headers = {"Content-Type": "application/json"} + if extra_headers is not None: + headers = {"Content-Type": "application/json", **extra_headers} + request = AWSRequest( + method="POST", url=endpoint_url, data=data, headers=headers + ) + sigv4.add_auth(request) + prepped = request.prepare() + + ## LOGGING + logging_obj.pre_call( + input=messages, + api_key="", + additional_args={ + "complete_input_dict": data, + "api_base": prepped.url, + "headers": prepped.headers, + }, + ) + + ### ROUTING (ASYNC, STREAMING, SYNC) + try: + response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=response.text) + except httpx.TimeoutException as e: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + def get_response_stream_shape(): from botocore.model import ServiceModel from botocore.loaders import Loader diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index 41ecb486ce..d5ef696879 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -3,14 +3,7 @@ import requests, traceback import json, re, xml.etree.ElementTree as ET from jinja2 import Template, exceptions, meta, BaseLoader from jinja2.sandbox import ImmutableSandboxedEnvironment -from typing import ( - Any, - List, - Mapping, - MutableMapping, - Optional, - Sequence, -) +from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple import litellm import litellm.types from litellm.types.completion import ( @@ -24,7 +17,7 @@ from litellm.types.completion import ( import litellm.types.llms from litellm.types.llms.anthropic import * import uuid - +from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock import litellm.types.llms.vertex_ai @@ -1460,9 +1453,7 @@ def _load_image_from_url(image_url): try: from PIL import Image except: - raise Exception( - "gemini image conversion failed please run `pip install Pillow`" - ) + raise Exception("image conversion failed please run `pip install Pillow`") from io import BytesIO try: @@ -1613,6 +1604,379 @@ def azure_text_pt(messages: list): return prompt +###### AMAZON BEDROCK ####### + +from litellm.types.llms.bedrock import ( + ToolResultContentBlock as BedrockToolResultContentBlock, + ToolResultBlock as BedrockToolResultBlock, + ToolConfigBlock as BedrockToolConfigBlock, + ToolUseBlock as BedrockToolUseBlock, + ImageSourceBlock as BedrockImageSourceBlock, + ImageBlock as BedrockImageBlock, + ContentBlock as BedrockContentBlock, + ToolInputSchemaBlock as 
BedrockToolInputSchemaBlock, + ToolSpecBlock as BedrockToolSpecBlock, + ToolBlock as BedrockToolBlock, +) + + +def get_image_details(image_url) -> Tuple[bytes, str]: + try: + import base64 + + # Send a GET request to the image URL + response = requests.get(image_url) + response.raise_for_status() # Raise an exception for HTTP errors + + # Check the response's content type to ensure it is an image + content_type = response.headers.get("content-type") + if not content_type or "image" not in content_type: + raise ValueError( + f"URL does not point to a valid image (content-type: {content_type})" + ) + + # Convert the image content to base64 bytes + base64_bytes = base64.b64encode(response.content) + + # Get mime-type + mime_type = content_type.split("/")[ + 1 + ] # Extract mime-type from content-type header + + return base64_bytes, mime_type + + except requests.RequestException as e: + raise Exception(f"Request failed: {e}") + except Exception as e: + raise e + + +def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock: + if "base64" in image_url: + # Case 1: Images with base64 encoding + import base64, re + + # base 64 is passed as data:image/jpeg;base64, + image_metadata, img_without_base_64 = image_url.split(",") + image_format = image_metadata.split("/")[1] + + # read mime_type from img_without_base_64=data:image/jpeg;base64 + # Extract MIME type using regular expression + mime_type_match = re.match(r"data:(.*?);base64", image_metadata) + + if mime_type_match: + mime_type = mime_type_match.group(1) + else: + mime_type = "jpeg" + decoded_img = base64.b64decode(img_without_base_64) + _blob = BedrockImageSourceBlock(bytes=decoded_img) + supported_image_formats = ( + litellm.AmazonConverseConfig().get_supported_image_types() + ) + if image_format in supported_image_formats: + return BedrockImageBlock(source=_blob, format=image_format) # type: ignore + else: + # Handle the case when the image format is not supported + raise ValueError( + "Unsupported image format: {}. Supported formats: {}".format( + image_format, supported_image_formats + ) + ) + elif "https:/" in image_url: + # Case 2: Images with direct links + image_bytes, image_format = get_image_details(image_url) + _blob = BedrockImageSourceBlock(bytes=image_bytes) + supported_image_formats = ( + litellm.AmazonConverseConfig().get_supported_image_types() + ) + if image_format in supported_image_formats: + return BedrockImageBlock(source=_blob, format=image_format) # type: ignore + else: + # Handle the case when the image format is not supported + raise ValueError( + "Unsupported image format: {}. Supported formats: {}".format( + image_format, supported_image_formats + ) + ) + else: + raise ValueError( + "Unsupported image type. 
Expected either image url or base64 encoded string" + ) + + +def _convert_to_bedrock_tool_call_invoke( + tool_calls: list, +) -> List[BedrockContentBlock]: + """ + OpenAI tool invokes: + { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n\"location\": \"Boston, MA\"\n}" + } + } + ] + }, + """ + """ + Bedrock tool invokes: + [ + { + "role": "assistant", + "toolUse": { + "input": {"location": "Boston, MA", ..}, + "name": "get_current_weather", + "toolUseId": "call_abc123" + } + } + ] + """ + """ + - json.loads argument + - extract name + - extract id + """ + + try: + _parts_list: List[BedrockContentBlock] = [] + for tool in tool_calls: + if "function" in tool: + id = tool["id"] + name = tool["function"].get("name", "") + arguments = tool["function"].get("arguments", "") + arguments_dict = json.loads(arguments) + bedrock_tool = BedrockToolUseBlock( + input=arguments_dict, name=name, toolUseId=id + ) + bedrock_content_block = BedrockContentBlock(toolUse=bedrock_tool) + _parts_list.append(bedrock_content_block) + return _parts_list + except Exception as e: + raise Exception( + "Unable to convert openai tool calls={} to bedrock tool calls. Received error={}".format( + tool_calls, str(e) + ) + ) + + +def _convert_to_bedrock_tool_call_result( + message: dict, +) -> BedrockMessageBlock: + """ + OpenAI message with a tool result looks like: + { + "tool_call_id": "tool_1", + "role": "tool", + "name": "get_current_weather", + "content": "function result goes here", + }, + + OpenAI message with a function call result looks like: + { + "role": "function", + "name": "get_current_weather", + "content": "function result goes here", + } + """ + """ + Bedrock result looks like this: + { + "role": "user", + "content": [ + { + "toolResult": { + "toolUseId": "tooluse_kZJMlvQmRJ6eAyJE5GIl7Q", + "content": [ + { + "json": { + "song": "Elemental Hotel", + "artist": "8 Storey Hike" + } + } + ] + } + } + ] + } + """ + """ + - + """ + content = message.get("content", "") + name = message.get("name", "") + id = message.get("tool_call_id", str(uuid.uuid4())) + + tool_result_content_block = BedrockToolResultContentBlock(text=content) + tool_result = BedrockToolResultBlock( + content=tool_result_content_block, + toolUseId=id, + ) + content_block = BedrockContentBlock(toolResult=tool_result) + + return BedrockMessageBlock(role="user", content=[content_block]) + + +def _bedrock_converse_messages_pt(messages: List) -> List[BedrockMessageBlock]: + """ + Converts given messages from OpenAI format to Bedrock format + + - Roles must alternate b/w 'user' and 'model' (same as anthropic -> merge consecutive roles) + - Please ensure that function response turn comes immediately after a function call turn + """ + + contents: List[BedrockMessageBlock] = [] + msg_i = 0 + while msg_i < len(messages): + user_content: List[BedrockContentBlock] = [] + init_msg_i = msg_i + ## MERGE CONSECUTIVE USER CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "user": + if isinstance(messages[msg_i]["content"], list): + _parts: List[BedrockContentBlock] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + _part = BedrockContentBlock(text=element["text"]) + _parts.append(_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + _part = _process_bedrock_converse_image_block( # type: ignore + image_url=image_url + 
) + _parts.append(BedrockContentBlock(image=_part)) # type: ignore + user_content.extend(_parts) + else: + _part = BedrockContentBlock(text=messages[msg_i]["content"]) + user_content.append(_part) + + msg_i += 1 + + if user_content: + contents.append(BedrockMessageBlock(role="user", content=user_content)) + assistant_content: List[BedrockContentBlock] = [] + ## MERGE CONSECUTIVE ASSISTANT CONTENT ## + while msg_i < len(messages) and messages[msg_i]["role"] == "assistant": + if isinstance(messages[msg_i]["content"], list): + assistants_parts: List[BedrockContentBlock] = [] + for element in messages[msg_i]["content"]: + if isinstance(element, dict): + if element["type"] == "text": + assistants_part = BedrockContentBlock(text=element["text"]) + assistants_parts.append(assistants_part) + elif element["type"] == "image_url": + image_url = element["image_url"]["url"] + assistants_part = _process_bedrock_converse_image_block( # type: ignore + image_url=image_url + ) + assistants_parts.append( + BedrockContentBlock(image=assistants_part) # type: ignore + ) + assistant_content.extend(assistants_parts) + elif messages[msg_i].get( + "tool_calls", [] + ): # support assistant tool invoke convertion + assistant_content.extend( + _convert_to_bedrock_tool_call_invoke(messages[msg_i]["tool_calls"]) + ) + else: + assistant_text = ( + messages[msg_i].get("content") or "" + ) # either string or none + if assistant_text: + assistant_content.append(BedrockContentBlock(text=assistant_text)) + + msg_i += 1 + + if assistant_content: + contents.append( + BedrockMessageBlock(role="assistant", content=assistant_content) + ) + + ## APPEND TOOL CALL MESSAGES ## + if msg_i < len(messages) and messages[msg_i]["role"] == "tool": + tool_call_result = _convert_to_bedrock_tool_call_result(messages[msg_i]) + contents.append(tool_call_result) + msg_i += 1 + if msg_i == init_msg_i: # prevent infinite loops + raise Exception( + "Invalid Message passed in - {}. File an issue https://github.com/BerriAI/litellm/issues".format( + messages[msg_i] + ) + ) + + return contents + + +def _bedrock_tools_pt(tools: List) -> List[BedrockToolBlock]: + """ + OpenAI tools looks like: + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + """ + """ + Bedrock toolConfig looks like: + "tools": [ + { + "toolSpec": { + "name": "top_song", + "description": "Get the most popular song played on a radio station.", + "inputSchema": { + "json": { + "type": "object", + "properties": { + "sign": { + "type": "string", + "description": "The call sign for the radio station for which you want the most popular song. Example calls signs are WZPZ, and WKRP." 
+ } + }, + "required": [ + "sign" + ] + } + } + } + } + ] + """ + tool_block_list: List[BedrockToolBlock] = [] + for tool in tools: + parameters = tool.get("function", {}).get("parameters", None) + name = tool.get("function", {}).get("name", "") + description = tool.get("function", {}).get("description", "") + tool_input_schema = BedrockToolInputSchemaBlock(json=parameters) + tool_spec = BedrockToolSpecBlock( + inputSchema=tool_input_schema, name=name, description=description + ) + tool_block = BedrockToolBlock(toolSpec=tool_spec) + tool_block_list.append(tool_block) + + return tool_block_list + + # Function call template def function_call_prompt(messages: list, functions: list): function_prompt = """Produce JSON OUTPUT ONLY! Adhere to this format {"name": "function_name", "arguments":{"argument_name": "argument_value"}} The following functions are available to you:""" diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 47c55ca4f3..1befa1392e 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -300,7 +300,14 @@ def test_completion_claude_3(): pytest.fail(f"Error occurred: {e}") -def test_completion_claude_3_function_call(): +@pytest.mark.parametrize( + "model", + [ + # "anthropic/claude-3-opus-20240229", + "cohere.command-r-plus-v1:0" + ], +) +def test_completion_claude_3_function_call(model): litellm.set_verbose = True tools = [ { @@ -331,7 +338,7 @@ def test_completion_claude_3_function_call(): try: # test without max tokens response = completion( - model="anthropic/claude-3-opus-20240229", + model=model, messages=messages, tools=tools, tool_choice={ @@ -364,7 +371,7 @@ def test_completion_claude_3_function_call(): ) # In the second response, Claude should deduce answer from tool results second_response = completion( - model="anthropic/claude-3-opus-20240229", + model=model, messages=messages, tools=tools, tool_choice="auto", diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 0c82596827..9333ea1b9e 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -1,4 +1,4 @@ -from typing import TypedDict, Any, Union, Optional +from typing import TypedDict, Any, Union, Optional, Literal, List import json from typing_extensions import ( Self, @@ -11,6 +11,81 @@ from typing_extensions import ( ) +class SystemContentBlock(TypedDict): + text: str + + +class ImageSourceBlock(TypedDict): + bytes: Optional[bytes] + + +class ImageBlock(TypedDict): + format: Literal["png", "jpeg", "gif", "webp"] + source: ImageSourceBlock + + +class ToolResultContentBlock(TypedDict, total=False): + image: ImageBlock + json: dict + text: str + + +class ToolResultBlock(TypedDict, total=False): + content: Required[ToolResultContentBlock] + toolUseId: Required[str] + status: Literal["success", "error"] + + +class ToolUseBlock(TypedDict): + input: dict + name: str + toolUseId: str + + +class ContentBlock(TypedDict, total=False): + text: str + image: ImageBlock + toolResult: ToolResultBlock + toolUse: ToolUseBlock + + +class MessageBlock(TypedDict): + content: List[ContentBlock] + role: Literal["user", "assistant"] + + +class ToolInputSchemaBlock(TypedDict): + json: Optional[dict] + + +class ToolSpecBlock(TypedDict, total=False): + inputSchema: Required[ToolInputSchemaBlock] + name: Required[str] + description: str + + +class ToolBlock(TypedDict): + toolSpec: Optional[ToolSpecBlock] + + +class SpecificToolChoiceBlock(TypedDict): + name: str + + +class ToolConfigBlock(TypedDict, total=False): + tools: 
Required[List[ToolBlock]] + toolChoice: Union[str, SpecificToolChoiceBlock] + + +class RequestObject(TypedDict, total=False): + additionalModelRequestFields: dict + additionalModelResponseFieldPaths: List[str] + inferenceConfig: dict + messages: Required[List[MessageBlock]] + system: List[SystemContentBlock] + toolConfig: ToolConfigBlock + + class GenericStreamingChunk(TypedDict): text: Required[str] is_finished: Required[bool] diff --git a/litellm/utils.py b/litellm/utils.py index 1788600941..65a34058b2 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -6401,20 +6401,7 @@ def get_supported_openai_params( - None if unmapped """ if custom_llm_provider == "bedrock": - if model.startswith("anthropic.claude-3"): - return litellm.AmazonAnthropicClaude3Config().get_supported_openai_params() - elif model.startswith("anthropic"): - return litellm.AmazonAnthropicConfig().get_supported_openai_params() - elif model.startswith("ai21"): - return ["max_tokens", "temperature", "top_p", "stream"] - elif model.startswith("amazon"): - return ["max_tokens", "temperature", "stop", "top_p", "stream"] - elif model.startswith("meta"): - return ["max_tokens", "temperature", "top_p", "stream"] - elif model.startswith("cohere"): - return ["stream", "temperature", "max_tokens"] - elif model.startswith("mistral"): - return ["max_tokens", "temperature", "stop", "top_p", "stream"] + return litellm.AmazonConverseConfig().get_supported_openai_params() elif custom_llm_provider == "ollama": return litellm.OllamaConfig().get_supported_openai_params() elif custom_llm_provider == "ollama_chat": From 3df177d0d00583b5fd33d94439de86ab4594c815 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 10:38:15 -0700 Subject: [PATCH 04/52] feat - redact messages from slack alerting --- litellm/__init__.py | 1 + litellm/proxy/proxy_config.yaml | 4 +++- litellm/utils.py | 10 +++++++--- 3 files changed, 11 insertions(+), 4 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 9fb614396e..33ae7deaba 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -60,6 +60,7 @@ _async_failure_callback: List[Callable] = ( pre_call_rules: List[Callable] = [] post_call_rules: List[Callable] = [] turn_off_message_logging: Optional[bool] = False +redact_messages_in_exceptions: Optional[bool] = False store_audit_logs = False # Enterprise feature, allow users to see audit logs ## end of callbacks ############# diff --git a/litellm/proxy/proxy_config.yaml b/litellm/proxy/proxy_config.yaml index 88fc0e9136..2bef95acfd 100644 --- a/litellm/proxy/proxy_config.yaml +++ b/litellm/proxy/proxy_config.yaml @@ -21,7 +21,9 @@ model_list: general_settings: master_key: sk-1234 + alerting: ["slack"] litellm_settings: callbacks: ["otel"] - store_audit_logs: true \ No newline at end of file + store_audit_logs: true + redact_messages_in_exceptions: True \ No newline at end of file diff --git a/litellm/utils.py b/litellm/utils.py index ba6a374674..91e6d9faba 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -7351,10 +7351,10 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]: if custom_llm_provider == "databricks": return litellm.DatabricksConfig().get_required_params() - + elif custom_llm_provider == "ollama": return litellm.OllamaConfig().get_required_params() - + else: return [] @@ -8514,7 +8514,11 @@ def exception_type( extra_information = f"\nModel: {model}" if _api_base: extra_information += f"\nAPI Base: `{_api_base}`" - if messages and len(messages) > 0: + if ( + messages + and len(messages) 
> 0 + and litellm.redact_messages_in_exceptions is False + ): extra_information += f"\nMessages: `{messages}`" if _model_group is not None: From 383d58a3f835780f88b2ecf230b82de714390b49 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 10:48:08 -0700 Subject: [PATCH 05/52] fix - turn of message logging --- litellm/integrations/slack_alerting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py index c98d60f1fd..ce7039ef15 100644 --- a/litellm/integrations/slack_alerting.py +++ b/litellm/integrations/slack_alerting.py @@ -326,7 +326,7 @@ class SlackAlerting(CustomLogger): end_time=end_time, ) ) - if litellm.turn_off_message_logging: + if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions: messages = "Message not logged. `litellm.turn_off_message_logging=True`." request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" @@ -567,7 +567,10 @@ class SlackAlerting(CustomLogger): except: messages = "" - if litellm.turn_off_message_logging: + if ( + litellm.turn_off_message_logging + or litellm.redact_messages_in_exceptions + ): messages = ( "Message not logged. `litellm.turn_off_message_logging=True`." ) From 99429457976c0852bf26997d60bfb16639dba2d9 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 12:30:56 -0700 Subject: [PATCH 06/52] docs - redacting messages from slack alerting --- docs/my-website/docs/proxy/alerting.md | 17 +++++++++++++++++ litellm/integrations/slack_alerting.py | 4 ++-- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/my-website/docs/proxy/alerting.md b/docs/my-website/docs/proxy/alerting.md index 3ef676bbd6..402de410c9 100644 --- a/docs/my-website/docs/proxy/alerting.md +++ b/docs/my-website/docs/proxy/alerting.md @@ -62,6 +62,23 @@ curl -X GET 'http://localhost:4000/health/services?service=slack' \ -H 'Authorization: Bearer sk-1234' ``` +## Advanced - Redacting Messages from Alerts + +By default alerts show the `messages/input` passed to the LLM. If you want to redact this from slack alerting set the following setting on your config + + +```shell +general_settings: + alerting: ["slack"] + alert_types: ["spend_reports"] + +litellm_settings: + redact_messages_in_exceptions: True +``` + + + + ## Advanced - Opting into specific alert types Set `alert_types` if you want to Opt into only specific alert types diff --git a/litellm/integrations/slack_alerting.py b/litellm/integrations/slack_alerting.py index ce7039ef15..21415fb6d6 100644 --- a/litellm/integrations/slack_alerting.py +++ b/litellm/integrations/slack_alerting.py @@ -327,7 +327,7 @@ class SlackAlerting(CustomLogger): ) ) if litellm.turn_off_message_logging or litellm.redact_messages_in_exceptions: - messages = "Message not logged. `litellm.turn_off_message_logging=True`." + messages = "Message not logged. litellm.redact_messages_in_exceptions=True" request_info = f"\nRequest Model: `{model}`\nAPI Base: `{api_base}`\nMessages: `{messages}`" slow_message = f"`Responses are slow - {round(time_difference_float,2)}s response time > Alerting threshold: {self.alerting_threshold}s`" if time_difference_float > self.alerting_threshold: @@ -572,7 +572,7 @@ class SlackAlerting(CustomLogger): or litellm.redact_messages_in_exceptions ): messages = ( - "Message not logged. 
`litellm.turn_off_message_logging=True`." + "Message not logged. litellm.redact_messages_in_exceptions=True" ) request_info = f"\nRequest Model: `{model}`\nMessages: `{messages}`" else: From b4eba4bddd67590f8b2765fd40dc668447e91204 Mon Sep 17 00:00:00 2001 From: nick-rackauckas Date: Thu, 6 Jun 2024 16:15:01 -0700 Subject: [PATCH 07/52] Fix to work with all supported Gemini file types --- litellm/llms/vertex_ai.py | 37 +++--- litellm/types/files.py | 267 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 286 insertions(+), 18 deletions(-) create mode 100644 litellm/types/files.py diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index aabe084b8e..81b02da433 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -12,6 +12,7 @@ from litellm.llms.prompt_templates.factory import ( convert_to_gemini_tool_call_result, convert_to_gemini_tool_call_invoke, ) +from litellm.types.files import get_file_mime_type_for_file_type, get_file_type_from_extension, is_gemini_1_5_accepted_file_type, is_video_file_type class VertexAIError(Exception): @@ -297,29 +298,29 @@ def _convert_gemini_role(role: str) -> Literal["user", "model"]: def _process_gemini_image(image_url: str) -> PartType: try: - if ".mp4" in image_url and "gs://" in image_url: - # Case 1: Videos with Cloud Storage URIs - part_mime = "video/mp4" - _file_data = FileDataType(mime_type=part_mime, file_uri=image_url) - return PartType(file_data=_file_data) - elif ".pdf" in image_url and "gs://" in image_url: - # Case 2: PDF's with Cloud Storage URIs - part_mime = "application/pdf" - _file_data = FileDataType(mime_type=part_mime, file_uri=image_url) - return PartType(file_data=_file_data) - elif "gs://" in image_url: - # Case 3: Images with Cloud Storage URIs - # The supported MIME types for images include image/png and image/jpeg. 
- part_mime = "image/png" if "png" in image_url else "image/jpeg" - _file_data = FileDataType(mime_type=part_mime, file_uri=image_url) - return PartType(file_data=_file_data) + # GCS URIs + if "gs://" in image_url: + # Figure out file type + extension = os.path.splitext(image_url)[-1] + file_type = get_file_type_from_extension(extension) + + # Validate the file type is supported by Gemini + if not is_gemini_1_5_accepted_file_type(file_type): + raise Exception(f"File type not supported by gemini - {file_type}") + + mime_type = get_file_mime_type_for_file_type(file_type) + file_data = FileDataType(mime_type=mime_type, file_uri=image_url) + + return PartType(file_data=file_data) + + # Direct links elif "https:/" in image_url: - # Case 4: Images with direct links image = _load_image_from_url(image_url) _blob = BlobType(data=image.data, mime_type=image._mime_type) return PartType(inline_data=_blob) + + # Base64 encoding elif "base64" in image_url: - # Case 5: Images with base64 encoding import base64, re # base 64 is passed as data:image/jpeg;base64, diff --git a/litellm/types/files.py b/litellm/types/files.py new file mode 100644 index 0000000000..ea229373d5 --- /dev/null +++ b/litellm/types/files.py @@ -0,0 +1,267 @@ +from enum import Enum +from types import MappingProxyType +from typing import List, Set + +""" +Base Enums/Consts +""" +class FileType(Enum): + AAC = "AAC" + CSV = "CSV" + DOC = "DOC" + DOCX = "DOCX" + FLAC = "FLAC" + FLV = "FLV" + GIF = "GIF" + GOOGLE_DOC = "GOOGLE_DOC" + GOOGLE_DRAWINGS = "GOOGLE_DRAWINGS" + GOOGLE_SHEETS = "GOOGLE_SHEETS" + GOOGLE_SLIDES = "GOOGLE_SLIDES" + HEIC = "HEIC" + HEIF = "HEIF" + HTML = "HTML" + JPEG = "JPEG" + JSON = "JSON" + M4A = "M4A" + M4V = "M4V" + MOV = "MOV" + MP3 = "MP3" + MP4 = "MP4" + MPEG = "MPEG" + MPEGPS = "MPEGPS" + MPG = "MPG" + MPA = "MPA" + MPGA = "MPGA" + OGG = "OGG" + OPUS = "OPUS" + PDF = "PDF" + PCM = "PCM" + PNG = "PNG" + PPT = "PPT" + PPTX = "PPTX" + RTF = "RTF" + THREE_GPP = "3GPP" + TXT = "TXT" + WAV = "WAV" + WEBM = "WEBM" + WEBP = "WEBP" + WMV = "WMV" + XLS = "XLS" + XLSX = "XLSX" + +FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType({ + FileType.AAC: ["aac"], + FileType.CSV: ["csv"], + FileType.DOC: ["doc"], + FileType.DOCX: ["docx"], + FileType.FLAC: ["flac"], + FileType.FLV: ["flv"], + FileType.GIF: ["gif"], + FileType.GOOGLE_DOC: ["gdoc"], + FileType.GOOGLE_DRAWINGS: ["gdraw"], + FileType.GOOGLE_SHEETS: ["gsheet"], + FileType.GOOGLE_SLIDES: ["gslides"], + FileType.HEIC: ["heic"], + FileType.HEIF: ["heif"], + FileType.HTML: ["html", "htm"], + FileType.JPEG: ["jpeg", "jpg"], + FileType.JSON: ["json"], + FileType.M4A: ["m4a"], + FileType.M4V: ["m4v"], + FileType.MOV: ["mov"], + FileType.MP3: ["mp3"], + FileType.MP4: ["mp4"], + FileType.MPEG: ["mpeg"], + FileType.MPEGPS: ["mpegps"], + FileType.MPG: ["mpg"], + FileType.MPA: ["mpa"], + FileType.MPGA: ["mpga"], + FileType.OGG: ["ogg"], + FileType.OPUS: ["opus"], + FileType.PDF: ["pdf"], + FileType.PCM: ["pcm"], + FileType.PNG: ["png"], + FileType.PPT: ["ppt"], + FileType.PPTX: ["pptx"], + FileType.RTF: ["rtf"], + FileType.THREE_GPP: ["3gpp"], + FileType.TXT: ["txt"], + FileType.WAV: ["wav"], + FileType.WEBM: ["webm"], + FileType.WEBP: ["webp"], + FileType.WMV: ["wmv"], + FileType.XLS: ["xls"], + FileType.XLSX: ["xlsx"], +}) + +FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType({ + FileType.AAC: "audio/aac", + FileType.CSV: "text/csv", + FileType.DOC: "application/msword", + FileType.DOCX: 
"application/vnd.openxmlformats-officedocument.wordprocessingml.document", + FileType.FLAC: "audio/flac", + FileType.FLV: "video/x-flv", + FileType.GIF: "image/gif", + FileType.GOOGLE_DOC: "application/vnd.google-apps.document", + FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing", + FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet", + FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation", + FileType.HEIC: "image/heic", + FileType.HEIF: "image/heif", + FileType.HTML: "text/html", + FileType.JPEG: "image/jpeg", + FileType.JSON: "application/json", + FileType.M4A: "audio/x-m4a", + FileType.M4V: "video/x-m4v", + FileType.MOV: "video/quicktime", + FileType.MP3: "audio/mpeg", + FileType.MP4: "video/mp4", + FileType.MPEG: "video/mpeg", + FileType.MPEGPS: "video/mpegps", + FileType.MPG: "video/mpg", + FileType.MPA: "audio/m4a", + FileType.MPGA: "audio/mpga", + FileType.OGG: "audio/ogg", + FileType.OPUS: "audio/opus", + FileType.PDF: "application/pdf", + FileType.PCM: "audio/pcm", + FileType.PNG: "image/png", + FileType.PPT: "application/vnd.ms-powerpoint", + FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation", + FileType.RTF: "application/rtf", + FileType.THREE_GPP: "video/3gpp", + FileType.TXT: "text/plain", + FileType.WAV: "audio/wav", + FileType.WEBM: "video/webm", + FileType.WEBP: "image/webp", + FileType.WMV: "video/wmv", + FileType.XLS: "application/vnd.ms-excel", + FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", +}) + +""" +Util Functions +""" +def get_file_mime_type_from_extension(extension: str) -> str: + for file_type, extensions in FILE_EXTENSIONS.items(): + if extension in extensions: + return FILE_MIME_TYPES[file_type] + raise ValueError(f"Unknown mime type for extension: {extension}") + + +def get_file_extension_from_mime_type(mime_type: str) -> str: + for file_type, mime in FILE_MIME_TYPES.items(): + if mime == mime_type: + return FILE_EXTENSIONS[file_type][0] + raise ValueError(f"Unknown extension for mime type: {mime_type}") + + +def get_file_type_from_extension(extension: str) -> FileType: + for file_type, extensions in FILE_EXTENSIONS.items(): + if extension in extensions: + return file_type + + raise ValueError(f"Unknown file type for extension: {extension}") + + +def get_file_extension_for_file_type(file_type: FileType) -> str: + return FILE_EXTENSIONS[file_type][0] + +def get_file_mime_type_for_file_type(file_type: FileType) -> str: + return FILE_MIME_TYPES[file_type] + + +""" +FileType Type Groupings (Videos, Images, etc) +""" + +# Images +IMAGE_FILE_TYPES = { + FileType.PNG, + FileType.JPEG, + FileType.GIF, + FileType.WEBP, + FileType.HEIC, + FileType.HEIF +} + +def is_image_file_type(file_type): + return file_type in IMAGE_FILE_TYPES + +# Videos +VIDEO_FILE_TYPES = { + FileType.MOV, + FileType.MP4, + FileType.MPEG, + FileType.M4V, + FileType.FLV, + FileType.MPEGPS, + FileType.MPG, + FileType.WEBM, + FileType.WMV, + FileType.THREE_GPP +} + +def is_video_file_type(file_type): + return file_type in VIDEO_FILE_TYPES + +# Audio +AUDIO_FILE_TYPES = { + FileType.AAC, + FileType.FLAC, + FileType.MP3, + FileType.MPA, + FileType.MPGA, + FileType.OPUS, + FileType.PCM, + FileType.WAV, +} + +def is_audio_file_type(file_type): + return file_type in AUDIO_FILE_TYPES + +# Text +TEXT_FILE_TYPES = { + FileType.CSV, + FileType.HTML, + FileType.RTF, + FileType.TXT +} + +def is_text_file_type(file_type): + return file_type in TEXT_FILE_TYPES + +""" +Other FileType Groupings 
+""" +# Accepted file types for GEMINI 1.5 through Vertex AI +# https://cloud.google.com/vertex-ai/generative-ai/docs/multimodal/send-multimodal-prompts#gemini-send-multimodal-samples-images-nodejs +GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = { + # Image + FileType.PNG, + FileType.JPEG, + # Audio + FileType.AAC, + FileType.FLAC, + FileType.MP3, + FileType.MPA, + FileType.MPGA, + FileType.OPUS, + FileType.PCM, + FileType.WAV, + # Video + FileType.FLV, + FileType.MOV, + FileType.MPEG, + FileType.MPEGPS, + FileType.MPG, + FileType.MP4, + FileType.WEBM, + FileType.WMV, + FileType.THREE_GPP, + # Document + FileType.PDF, +} + +def is_gemini_1_5_accepted_file_type(file_type: FileType) -> bool: + return file_type in GEMINI_1_5_ACCEPTED_FILE_TYPES From e205381ba0c407f15c25f9abf901f1121976af1e Mon Sep 17 00:00:00 2001 From: nick-rackauckas Date: Thu, 6 Jun 2024 16:30:23 -0700 Subject: [PATCH 08/52] Remove unused function --- litellm/llms/vertex_ai.py | 108 +------------------------------------- 1 file changed, 1 insertion(+), 107 deletions(-) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 81b02da433..05dc718dc2 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -310,7 +310,7 @@ def _process_gemini_image(image_url: str) -> PartType: mime_type = get_file_mime_type_for_file_type(file_type) file_data = FileDataType(mime_type=mime_type, file_uri=image_url) - + return PartType(file_data=file_data) # Direct links @@ -427,112 +427,6 @@ def _gemini_convert_messages_with_history(messages: list) -> List[ContentType]: return contents -def _gemini_vision_convert_messages(messages: list): - """ - Converts given messages for GPT-4 Vision to Gemini format. - - Args: - messages (list): The messages to convert. Each message can be a dictionary with a "content" key. The content can be a string or a list of elements. If it is a string, it will be concatenated to the prompt. If it is a list, each element will be processed based on its type: - - If the element is a dictionary with a "type" key equal to "text", its "text" value will be concatenated to the prompt. - - If the element is a dictionary with a "type" key equal to "image_url", its "image_url" value will be added to the list of images. - - Returns: - tuple: A tuple containing the prompt (a string) and the processed images (a list of objects representing the images). - - Raises: - VertexAIError: If the import of the 'vertexai' module fails, indicating that 'google-cloud-aiplatform' needs to be installed. - Exception: If any other exception occurs during the execution of the function. - - Note: - This function is based on the code from the 'gemini/getting-started/intro_gemini_python.ipynb' notebook in the 'generative-ai' repository on GitHub. - The supported MIME types for images include 'image/png' and 'image/jpeg'. - - Examples: - >>> messages = [ - ... {"content": "Hello, world!"}, - ... {"content": [{"type": "text", "text": "This is a text message."}, {"type": "image_url", "image_url": "example.com/image.png"}]}, - ... 
] - >>> _gemini_vision_convert_messages(messages) - ('Hello, world!This is a text message.', [, ]) - """ - try: - import vertexai - except: - raise VertexAIError( - status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", - ) - try: - from vertexai.preview.language_models import ( - ChatModel, - CodeChatModel, - InputOutputTextPair, - ) - from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import ( - GenerativeModel, - Part, - GenerationConfig, - Image, - ) - - # given messages for gpt-4 vision, convert them for gemini - # https://github.com/GoogleCloudPlatform/generative-ai/blob/main/gemini/getting-started/intro_gemini_python.ipynb - prompt = "" - images = [] - for message in messages: - if isinstance(message["content"], str): - prompt += message["content"] - elif isinstance(message["content"], list): - # see https://docs.litellm.ai/docs/providers/openai#openai-vision-models - for element in message["content"]: - if isinstance(element, dict): - if element["type"] == "text": - prompt += element["text"] - elif element["type"] == "image_url": - image_url = element["image_url"]["url"] - images.append(image_url) - # processing images passed to gemini - processed_images = [] - for img in images: - if "gs://" in img: - # Case 1: Images with Cloud Storage URIs - # The supported MIME types for images include image/png and image/jpeg. - part_mime = "image/png" if "png" in img else "image/jpeg" - google_clooud_part = Part.from_uri(img, mime_type=part_mime) - processed_images.append(google_clooud_part) - elif "https:/" in img: - # Case 2: Images with direct links - image = _load_image_from_url(img) - processed_images.append(image) - elif ".mp4" in img and "gs://" in img: - # Case 3: Videos with Cloud Storage URIs - part_mime = "video/mp4" - google_clooud_part = Part.from_uri(img, mime_type=part_mime) - processed_images.append(google_clooud_part) - elif "base64" in img: - # Case 4: Images with base64 encoding - import base64, re - - # base 64 is passed as data:image/jpeg;base64, - image_metadata, img_without_base_64 = img.split(",") - - # read mime_type from img_without_base_64=data:image/jpeg;base64 - # Extract MIME type using regular expression - mime_type_match = re.match(r"data:(.*?);base64", image_metadata) - - if mime_type_match: - mime_type = mime_type_match.group(1) - else: - mime_type = "image/jpeg" - decoded_img = base64.b64decode(img_without_base_64) - processed_image = Part.from_data(data=decoded_img, mime_type=mime_type) - processed_images.append(processed_image) - return prompt, processed_images - except Exception as e: - raise e - - def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str): _cache_key = f"{model}-{vertex_project}-{vertex_location}" return _cache_key From 93bf678026df3ea35aede04452d7bc97ec2fe6aa Mon Sep 17 00:00:00 2001 From: nick-rackauckas Date: Thu, 6 Jun 2024 16:35:39 -0700 Subject: [PATCH 09/52] Comment --- litellm/types/files.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/types/files.py b/litellm/types/files.py index ea229373d5..0545567ece 100644 --- a/litellm/types/files.py +++ b/litellm/types/files.py @@ -259,7 +259,7 @@ GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = { FileType.WEBM, FileType.WMV, FileType.THREE_GPP, - # Document + # PDF FileType.PDF, } From f8ffaa7fc87b38cc6d11f38fd7fbaa7adf3b14d0 Mon Sep 17 00:00:00 2001 From: nick-rackauckas Date: Thu, 6 Jun 2024 16:57:42 -0700 Subject: [PATCH 
10/52] Fix --- litellm/llms/vertex_ai.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/litellm/llms/vertex_ai.py b/litellm/llms/vertex_ai.py index 05dc718dc2..bd9cfaa8d6 100644 --- a/litellm/llms/vertex_ai.py +++ b/litellm/llms/vertex_ai.py @@ -301,7 +301,9 @@ def _process_gemini_image(image_url: str) -> PartType: # GCS URIs if "gs://" in image_url: # Figure out file type - extension = os.path.splitext(image_url)[-1] + extension_with_dot = os.path.splitext(image_url)[-1] # Ex: ".png" + extension = extension_with_dot[1:] # Ex: "png" + file_type = get_file_type_from_extension(extension) # Validate the file type is supported by Gemini From cb5ebba6fa689b81451f9c0bcb260fc2ac2a3bfb Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 19:49:18 -0700 Subject: [PATCH 11/52] feat -v0 parent_otel_span in basic db reads --- litellm/proxy/auth/auth_checks.py | 5 +++++ litellm/proxy/proxy_server.py | 20 +++++++++++++++++--- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index e4b8e6c8a8..f184589f4f 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -21,6 +21,7 @@ from typing import Optional, Literal, Union from litellm.proxy.utils import PrismaClient from litellm.caching import DualCache import litellm +from opentelemetry.trace import Span all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value @@ -186,6 +187,7 @@ async def get_end_user_object( end_user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, ) -> Optional[LiteLLM_EndUserTable]: """ Returns end user object, if in db. @@ -250,6 +252,7 @@ async def get_user_object( prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, user_id_upsert: bool, + parent_otel_span: Optional[Span] = None, ) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table @@ -300,6 +303,7 @@ async def get_team_object( team_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, ) -> LiteLLM_TeamTable: """ - Check if team id in proxy Team Table @@ -342,6 +346,7 @@ async def get_org_object( org_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, + parent_otel_span: Optional[Span] = None, ): """ - Check if org id in proxy Org Table diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 8cf2fa118d..05eb515739 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -7,6 +7,7 @@ import secrets, subprocess import hashlib, uuid import warnings import importlib +from opentelemetry.trace import Span def showwarning(message, category, filename, lineno, file=None, line=None): @@ -398,6 +399,7 @@ disable_spend_logs = False jwt_handler = JWTHandler() prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None store_model_in_db: bool = False +open_telemetry_logger = None ### INITIALIZE GLOBAL LOGGING OBJECT ### proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) ### REDIS QUEUE ### @@ -498,7 +500,12 @@ async def user_api_key_auth( if isinstance(api_key, str): passed_in_key = api_key api_key = _get_bearer_token(api_key=api_key) - + parent_otel_span: Optional[Span] = None + if open_telemetry_logger is not None: + parent_otel_span = open_telemetry_logger.tracer.start_span( + name="Received Proxy Server 
Request", + start_time=time.time(), + ) ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth is not None: response = await user_custom_auth(request=request, api_key=api_key) @@ -580,6 +587,7 @@ async def user_api_key_auth( team_id=team_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, ) # [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable` @@ -591,6 +599,7 @@ async def user_api_key_auth( org_id=org_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, ) # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable` user_object = None @@ -604,6 +613,7 @@ async def user_api_key_auth( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, user_id_upsert=jwt_handler.is_upsert_user_id(), + parent_otel_span=parent_otel_span, ) # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` @@ -617,6 +627,7 @@ async def user_api_key_auth( end_user_id=end_user_id, prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, ) global_proxy_spend = None @@ -715,6 +726,7 @@ async def user_api_key_auth( end_user_id=request_data["user"], prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, + parent_otel_span=parent_otel_span, ) if _end_user_object is not None: end_user_params["allowed_model_region"] = ( @@ -2306,7 +2318,7 @@ class ProxyConfig: """ Load config values into proxy global state """ - global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user + global master_key, user_config_file_path, otel_logging, user_custom_auth, user_custom_auth_path, user_custom_key_generate, use_background_health_checks, health_check_interval, use_queue, custom_db_client, proxy_budget_rescheduler_max_time, proxy_budget_rescheduler_min_time, ui_access_mode, litellm_master_key_hash, proxy_batch_write_at, disable_spend_logs, prompt_injection_detection_obj, redis_usage_cache, store_model_in_db, premium_user, open_telemetry_logger # Load existing config config = await self.get_config(config_file_path=config_file_path) @@ -2430,7 +2442,9 @@ class ProxyConfig: OpenTelemetry, ) - imported_list.append(OpenTelemetry()) + open_telemetry_logger = OpenTelemetry() + + imported_list.append(open_telemetry_logger) elif isinstance(callback, str) and callback == "presidio": from litellm.proxy.hooks.presidio_pii_masking import ( _OPTIONAL_PresidioPIIMasking, From f8b5aa3df6852cd9306b7611f0f3583dac5d0d77 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 6 Jun 2024 20:12:41 -0700 Subject: [PATCH 12/52] fix(bedrock_httpx.py): working claude 3 function calling --- .pre-commit-config.yaml | 16 +- litellm/llms/bedrock_httpx.py | 230 ++++++++++++++++++++-- litellm/llms/custom_httpx/http_handler.py | 3 +- litellm/llms/prompt_templates/factory.py | 3 +- litellm/main.py | 4 +- litellm/tests/test_completion.py | 7 +- litellm/tests/test_prompt_factory.py | 27 +++ litellm/types/llms/bedrock.py | 34 +++- litellm/types/llms/openai.py | 17 ++ litellm/utils.py | 85 +------- ruff.toml | 3 + 11 files changed, 321 insertions(+), 108 deletions(-) create 
mode 100644 ruff.toml diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc41d85f14..e8bb1ff66a 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,11 +16,11 @@ repos: name: Check if files match entry: python3 ci_cd/check_files_match.py language: system -- repo: local - hooks: - - id: mypy - name: mypy - entry: python3 -m mypy --ignore-missing-imports - language: system - types: [python] - files: ^litellm/ \ No newline at end of file +# - repo: local +# hooks: +# - id: mypy +# name: mypy +# entry: python3 -m mypy --ignore-missing-imports +# language: system +# types: [python] +# files: ^litellm/ \ No newline at end of file diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index e212650064..ce6a931747 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -47,6 +47,11 @@ import httpx # type: ignore from .bedrock import BedrockError, convert_messages_to_prompt, ModelResponseIterator from litellm.types.llms.bedrock import * import urllib.parse +from litellm.types.llms.openai import ( + ChatCompletionResponseMessage, + ChatCompletionToolCallChunk, + ChatCompletionToolCallFunctionChunk, +) class AmazonCohereChatConfig: @@ -1004,12 +1009,12 @@ class BedrockLLM(BaseLLM): if isinstance(timeout, float) or isinstance(timeout, int): timeout = httpx.Timeout(timeout) _params["timeout"] = timeout - self.client = AsyncHTTPHandler(**_params) # type: ignore + client = AsyncHTTPHandler(**_params) # type: ignore else: - self.client = client # type: ignore + client = client # type: ignore try: - response = await self.client.post(api_base, headers=headers, data=data) # type: ignore + response = await client.post(api_base, headers=headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code @@ -1125,11 +1130,55 @@ class AmazonConverseConfig: "tool_choice", ] + def map_tool_choice_values( + self, model: str, tool_choice: Union[str, dict], drop_params: bool + ) -> Optional[ToolChoiceValuesBlock]: + if not model.startswith("anthropic") and not model.startswith("mistral"): + # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + if drop_params == True or litellm.drop_params == True: + return None + else: + raise litellm.utils.UnsupportedParamsError( + message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`", + status_code=400, + ) + if tool_choice == "none": + if litellm.drop_params is True or drop_params is True: + return None + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + elif tool_choice == "required": + return ToolChoiceValuesBlock(any={}) + elif tool_choice == "auto": + return ToolChoiceValuesBlock(auto={}) + elif isinstance(tool_choice, dict): + # only supported for anthropic + mistral models - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + specific_tool = SpecificToolChoiceBlock( + name=tool_choice.get("function", {}).get("name", "") + ) + return ToolChoiceValuesBlock(tool=specific_tool) + else: + raise litellm.utils.UnsupportedParamsError( + message="Bedrock doesn't support tool_choice={}. 
Supported tool_choice values=['auto', 'required', json object]. To drop it from the call, set `litellm.drop_params = True.".format( + tool_choice + ), + status_code=400, + ) + def get_supported_image_types(self) -> List[str]: return ["png", "jpeg", "gif", "webp"] def map_openai_params( - self, non_default_params: dict, optional_params: dict + self, + model: str, + non_default_params: dict, + optional_params: dict, + drop_params: bool, ) -> dict: for param, value in non_default_params.items(): if param == "max_tokens": @@ -1144,6 +1193,14 @@ class AmazonConverseConfig: optional_params["temperature"] = value if param == "top_p": optional_params["topP"] = value + if param == "tools": + optional_params["tools"] = value + if param == "tool_choice": + _tool_choice_value = self.map_tool_choice_values( + model=model, tool_choice=value, drop_params=drop_params + ) + if _tool_choice_value is not None: + optional_params["tool_choice"] = _tool_choice_value return optional_params @@ -1151,6 +1208,124 @@ class BedrockConverseLLM(BaseLLM): def __init__(self) -> None: super().__init__() + def process_response( + self, + model: str, + response: Union[requests.Response, httpx.Response], + model_response: ModelResponse, + stream: bool, + logging_obj: Logging, + optional_params: dict, + api_key: str, + data: Union[dict, str], + messages: List, + print_verbose, + encoding, + ) -> Union[ModelResponse, CustomStreamWrapper]: + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key=api_key, + original_response=response.text, + additional_args={"complete_input_dict": data}, + ) + print_verbose(f"raw model_response: {response.text}") + + ## RESPONSE OBJECT + try: + completion_response = ConverseResponseBlock(**response.json()) # type: ignore + except Exception as e: + raise BedrockError( + message="Received={}, Error converting to valid response block={}. File an issue if litellm error - https://github.com/BerriAI/litellm/issues".format( + response.text, str(e) + ), + status_code=422, + ) + + """ + Bedrock Response Object has optional message block + + completion_response["output"].get("message", None) + + A message block looks like this (Example 1): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "text": "Is there anything else you'd like to talk about? Perhaps I can help with some economic questions or provide some information about economic concepts?" 
+ } + ] + } + }, + (Example 2): + "output": { + "message": { + "role": "assistant", + "content": [ + { + "toolUse": { + "toolUseId": "tooluse_hbTgdi0CSLq_hM4P8csZJA", + "name": "top_song", + "input": { + "sign": "WZPZ" + } + } + } + ] + } + } + + """ + message: Optional[MessageBlock] = completion_response["output"]["message"] + chat_completion_message: ChatCompletionResponseMessage = {"role": "assistant"} + content_str = "" + tools: List[ChatCompletionToolCallChunk] = [] + if message is not None: + for content in message["content"]: + """ + - Content is either a tool response or text + """ + if "text" in content: + content_str += content["text"] + if "toolUse" in content: + _function_chunk = ChatCompletionToolCallFunctionChunk( + name=content["toolUse"]["name"], + arguments=json.dumps(content["toolUse"]["input"]), + ) + _tool_response_chunk = ChatCompletionToolCallChunk( + id=content["toolUse"]["toolUseId"], + type="function", + function=_function_chunk, + ) + tools.append(_tool_response_chunk) + chat_completion_message["content"] = content_str + chat_completion_message["tool_calls"] = tools + + ## CALCULATING USAGE - bedrock returns usage in the headers + input_tokens = completion_response["usage"]["inputTokens"] + output_tokens = completion_response["usage"]["outputTokens"] + total_tokens = completion_response["usage"]["totalTokens"] + + model_response.choices = [ + litellm.Choices( + finish_reason=map_finish_reason(completion_response["stopReason"]), + index=0, + message=litellm.Message(**chat_completion_message), + ) + ] + model_response["created"] = int(time.time()) + model_response["model"] = model + usage = Usage( + prompt_tokens=input_tokens, + completion_tokens=output_tokens, + total_tokens=total_tokens, + ) + setattr(model_response, "usage", usage) + + return model_response + def encode_model_id(self, model_id: str) -> str: """ Double encode the model ID to ensure it matches the expected double-encoded format. 
@@ -1387,11 +1562,14 @@ class BedrockConverseLLM(BaseLLM): additional_request_keys = [] additional_request_params = {} supported_converse_params = AmazonConverseConfig().get_config().keys() - + supported_tool_call_params = ["tools", "tool_choice"] ## TRANSFORMATION ## # send all model-specific params in 'additional_request_params' for k, v in inference_params.items(): - if k not in supported_converse_params: + if ( + k not in supported_converse_params + and k not in supported_tool_call_params + ): additional_request_params[k] = v additional_request_keys.append(k) for key in additional_request_keys: @@ -1401,23 +1579,27 @@ class BedrockConverseLLM(BaseLLM): messages=messages ) bedrock_tools: List[ToolBlock] = _bedrock_tools_pt( - inference_params.get("tools", []) + inference_params.pop("tools", []) ) bedrock_tool_config: Optional[ToolConfigBlock] = None if len(bedrock_tools) > 0: + tool_choice_values: ToolChoiceValuesBlock = inference_params.pop( + "tool_choice", None + ) bedrock_tool_config = ToolConfigBlock( tools=bedrock_tools, - toolChoice=inference_params.get("tool_choice", None), ) + if tool_choice_values is not None: + bedrock_tool_config["toolChoice"] = tool_choice_values - data: RequestObject = { + _data: RequestObject = { "messages": bedrock_messages, "additionalModelRequestFields": additional_request_params, "system": system_content_blocks, } if bedrock_tool_config is not None: - data["toolConfig"] = bedrock_tool_config - + _data["toolConfig"] = bedrock_tool_config + data = json.dumps(_data) ## COMPLETION CALL headers = {"Content-Type": "application/json"} @@ -1441,8 +1623,18 @@ class BedrockConverseLLM(BaseLLM): ) ### ROUTING (ASYNC, STREAMING, SYNC) + ### COMPLETION + if client is None or isinstance(client, AsyncHTTPHandler): + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = HTTPHandler(**_params) # type: ignore + else: + client = client try: - response = self.client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore + response = client.post(url=prepped.url, headers=prepped.headers, data=data) # type: ignore response.raise_for_status() except httpx.HTTPStatusError as err: error_code = err.response.status_code @@ -1450,6 +1642,20 @@ class BedrockConverseLLM(BaseLLM): except httpx.TimeoutException as e: raise BedrockError(status_code=408, message="Timeout error occurred.") + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream, + logging_obj=logging_obj, + optional_params=optional_params, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + encoding=encoding, + ) + def get_response_stream_shape(): from botocore.model import ServiceModel diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index b91aaee2ae..5ec9c79bb2 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -156,12 +156,13 @@ class HTTPHandler: self, url: str, data: Optional[Union[dict, str]] = None, + json: Optional[Union[dict, str]] = None, params: Optional[dict] = None, headers: Optional[dict] = None, stream: bool = False, ): req = self.client.build_request( - "POST", url, data=data, params=params, headers=headers # type: ignore + "POST", url, data=data, json=json, params=params, headers=headers # type: ignore ) response = self.client.send(req, stream=stream) return response diff 
--git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py index d5ef696879..ddd0e1909f 100644 --- a/litellm/llms/prompt_templates/factory.py +++ b/litellm/llms/prompt_templates/factory.py @@ -1617,6 +1617,7 @@ from litellm.types.llms.bedrock import ( ToolInputSchemaBlock as BedrockToolInputSchemaBlock, ToolSpecBlock as BedrockToolSpecBlock, ToolBlock as BedrockToolBlock, + ToolChoiceValuesBlock as BedrockToolChoiceValuesBlock, ) @@ -1814,7 +1815,7 @@ def _convert_to_bedrock_tool_call_result( tool_result_content_block = BedrockToolResultContentBlock(text=content) tool_result = BedrockToolResultBlock( - content=tool_result_content_block, + content=[tool_result_content_block], toolUseId=id, ) content_block = BedrockContentBlock(toolResult=tool_result) diff --git a/litellm/main.py b/litellm/main.py index f76d6c5213..c95b419ba2 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -79,7 +79,7 @@ from .llms.anthropic import AnthropicChatCompletion from .llms.anthropic_text import AnthropicTextCompletion from .llms.huggingface_restapi import Huggingface from .llms.predibase import PredibaseChatCompletion -from .llms.bedrock_httpx import BedrockLLM +from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM from .llms.vertex_httpx import VertexLLM from .llms.triton import TritonChatCompletion from .llms.prompt_templates.factory import ( @@ -121,7 +121,7 @@ azure_text_completions = AzureTextCompletion() huggingface = Huggingface() predibase_chat_completions = PredibaseChatCompletion() triton_chat_completions = TritonChatCompletion() -bedrock_chat_completion = BedrockLLM() +bedrock_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() ####### COMPLETION ENDPOINTS ################ diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 1befa1392e..bcbe4944c5 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -302,10 +302,7 @@ def test_completion_claude_3(): @pytest.mark.parametrize( "model", - [ - # "anthropic/claude-3-opus-20240229", - "cohere.command-r-plus-v1:0" - ], + ["anthropic/claude-3-opus-20240229", "anthropic.claude-3-sonnet-20240229-v1:0"], ) def test_completion_claude_3_function_call(model): litellm.set_verbose = True @@ -345,6 +342,7 @@ def test_completion_claude_3_function_call(model): "type": "function", "function": {"name": "get_current_weather"}, }, + drop_params=True, ) # Add any assertions, here to check response args @@ -375,6 +373,7 @@ def test_completion_claude_3_function_call(model): messages=messages, tools=tools, tool_choice="auto", + drop_params=True, ) print(second_response) except Exception as e: diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index 2fc04ec528..9f112a0b1b 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -15,6 +15,7 @@ from litellm.llms.prompt_templates.factory import ( claude_2_1_pt, llama_2_chat_pt, prompt_factory, + _bedrock_tools_pt, ) @@ -128,3 +129,29 @@ def test_anthropic_messages_pt(): # codellama_prompt_format() +def test_bedrock_tool_calling_pt(): + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. 
San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + }, + } + ] + converted_tools = _bedrock_tools_pt(tools=tools) + + print(converted_tools) + + assert False diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 9333ea1b9e..647dc1d7b0 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -31,7 +31,7 @@ class ToolResultContentBlock(TypedDict, total=False): class ToolResultBlock(TypedDict, total=False): - content: Required[ToolResultContentBlock] + content: Required[List[ToolResultContentBlock]] toolUseId: Required[str] status: Literal["success", "error"] @@ -54,6 +54,30 @@ class MessageBlock(TypedDict): role: Literal["user", "assistant"] +class ConverseMetricsBlock(TypedDict): + latencyMs: float # time in ms + + +class ConverseResponseOutputBlock(TypedDict): + message: Optional[MessageBlock] + + +class ConverseTokenUsageBlock(TypedDict): + inputTokens: int + outputTokens: int + totalTokens: int + + +class ConverseResponseBlock(TypedDict): + additionalModelResponseFields: dict + metrics: ConverseMetricsBlock + output: ConverseResponseOutputBlock + stopReason: ( + str # end_turn | tool_use | max_tokens | stop_sequence | content_filtered + ) + usage: ConverseTokenUsageBlock + + class ToolInputSchemaBlock(TypedDict): json: Optional[dict] @@ -72,9 +96,15 @@ class SpecificToolChoiceBlock(TypedDict): name: str +class ToolChoiceValuesBlock(TypedDict, total=False): + any: dict + auto: dict + tool: SpecificToolChoiceBlock + + class ToolConfigBlock(TypedDict, total=False): tools: Required[List[ToolBlock]] - toolChoice: Union[str, SpecificToolChoiceBlock] + toolChoice: Union[str, ToolChoiceValuesBlock] class RequestObject(TypedDict, total=False): diff --git a/litellm/types/llms/openai.py b/litellm/types/llms/openai.py index bc0c82434f..7861e394cd 100644 --- a/litellm/types/llms/openai.py +++ b/litellm/types/llms/openai.py @@ -293,3 +293,20 @@ class ListBatchRequest(TypedDict, total=False): extra_headers: Optional[Dict[str, str]] extra_body: Optional[Dict[str, str]] timeout: Optional[float] + + +class ChatCompletionToolCallFunctionChunk(TypedDict): + name: str + arguments: str + + +class ChatCompletionToolCallChunk(TypedDict): + id: str + type: Literal["function"] + function: ChatCompletionToolCallFunctionChunk + + +class ChatCompletionResponseMessage(TypedDict, total=False): + content: Optional[str] + tool_calls: List[ChatCompletionToolCallChunk] + role: Literal["assistant"] diff --git a/litellm/utils.py b/litellm/utils.py index 65a34058b2..6db5f540c0 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5618,84 +5618,13 @@ def get_optional_params( supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider ) - if "ai21" in model: - _check_valid_arg(supported_params=supported_params) - # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[], - # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra - if max_tokens is not None: - optional_params["maxTokens"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["topP"] = top_p - if stream: - optional_params["stream"] = stream - elif "anthropic" in model: - _check_valid_arg(supported_params=supported_params) - # anthropic params on bedrock - # \"max_tokens_to_sample\":300,\"temperature\":0.5,\"top_p\":1,\"stop_sequences\":[\"\\\\n\\\\nHuman:\"]}" 
- if model.startswith("anthropic.claude-3"): - optional_params = ( - litellm.AmazonAnthropicClaude3Config().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - ) - ) - else: - optional_params = litellm.AmazonAnthropicConfig().map_openai_params( - non_default_params=non_default_params, - optional_params=optional_params, - ) - elif "amazon" in model: # amazon titan llms - _check_valid_arg(supported_params=supported_params) - # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large - if max_tokens is not None: - optional_params["maxTokenCount"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if stop is not None: - filtered_stop = _map_and_modify_arg( - {"stop": stop}, provider="bedrock", model=model - ) - optional_params["stopSequences"] = filtered_stop["stop"] - if top_p is not None: - optional_params["topP"] = top_p - if stream: - optional_params["stream"] = stream - elif "meta" in model: # amazon / meta llms - _check_valid_arg(supported_params=supported_params) - # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large - if max_tokens is not None: - optional_params["max_gen_len"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if stream: - optional_params["stream"] = stream - elif "cohere" in model: # cohere models on bedrock - _check_valid_arg(supported_params=supported_params) - # handle cohere params - if stream: - optional_params["stream"] = stream - if temperature is not None: - optional_params["temperature"] = temperature - if max_tokens is not None: - optional_params["max_tokens"] = max_tokens - elif "mistral" in model: - _check_valid_arg(supported_params=supported_params) - # mistral params on bedrock - # \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}" - if max_tokens is not None: - optional_params["max_tokens"] = max_tokens - if temperature is not None: - optional_params["temperature"] = temperature - if top_p is not None: - optional_params["top_p"] = top_p - if stop is not None: - optional_params["stop"] = stop - if stream is not None: - optional_params["stream"] = stream + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.AmazonConverseConfig().map_openai_params( + model=model, + non_default_params=non_default_params, + optional_params=optional_params, + drop_params=drop_params, + ) elif custom_llm_provider == "aleph_alpha": supported_params = [ "max_tokens", diff --git a/ruff.toml b/ruff.toml new file mode 100644 index 0000000000..dfb323c1b3 --- /dev/null +++ b/ruff.toml @@ -0,0 +1,3 @@ +ignore = ["F405"] +extend-select = ["E501"] +line-length = 120 From 92a3c062a70054e26630d66356879b9e9486408a Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 21:29:40 -0700 Subject: [PATCH 13/52] fix log_to_opentelemetry --- litellm/proxy/auth/auth_checks.py | 41 ++++++++++++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index f184589f4f..fdd01f17e2 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -18,14 +18,45 @@ from litellm.proxy._types import ( LitellmUserRoles, ) from typing import Optional, Literal, Union -from litellm.proxy.utils import PrismaClient +from litellm.proxy.utils 
import PrismaClient, ProxyLogging from litellm.caching import DualCache import litellm from opentelemetry.trace import Span +from functools import wraps +from litellm.types.services import ServiceLoggerPayload, ServiceTypes +from datetime import datetime all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value +def log_to_opentelemetry(func): + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = datetime.now() + result = await func(*args, **kwargs) + end_time = datetime.now() + + # Log to OTEL only if "parent_otel_span" is in kwargs and is not None + if ( + "parent_otel_span" in kwargs + and kwargs["parent_otel_span"] is not None + and "proxy_logging_obj" in kwargs + and kwargs["proxy_logging_obj"] is not None + ): + proxy_logging_obj = kwargs["proxy_logging_obj"] + await proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.DB, + call_type=func.__name__, + parent_otel_span=kwargs["parent_otel_span"], + start_time=start_time, + end_time=end_time, + ) + # end of logging to otel + return result + + return wrapper + + def common_checks( request_body: dict, team_object: Optional[LiteLLM_TeamTable], @@ -183,11 +214,13 @@ def get_actual_routes(allowed_routes: list) -> list: return actual_routes +@log_to_opentelemetry async def get_end_user_object( end_user_id: Optional[str], prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ) -> Optional[LiteLLM_EndUserTable]: """ Returns end user object, if in db. @@ -247,12 +280,14 @@ async def get_end_user_object( return None +@log_to_opentelemetry async def get_user_object( user_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, user_id_upsert: bool, parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ) -> Optional[LiteLLM_UserTable]: """ - Check if user id in proxy User Table @@ -299,11 +334,13 @@ async def get_user_object( ) +@log_to_opentelemetry async def get_team_object( team_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ) -> LiteLLM_TeamTable: """ - Check if team id in proxy Team Table @@ -342,11 +379,13 @@ async def get_team_object( ) +@log_to_opentelemetry async def get_org_object( org_id: str, prisma_client: Optional[PrismaClient], user_api_key_cache: DualCache, parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ): """ - Check if org id in proxy Org Table From 87df3a4a048da95b6a0f2b733d78f3f8cea98d0f Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 21:30:55 -0700 Subject: [PATCH 14/52] add _to_ns to utils --- litellm/proxy/utils.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 776d700962..549b0f2abc 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -2837,3 +2837,7 @@ missing_keys_html_form = """ """ + + +def _to_ns(dt): + return int(dt.timestamp() * 1e9) From c867f88c5703dc00a1839334bb0c9fe7fb855aa7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 22:06:28 -0700 Subject: [PATCH 15/52] fix - add new types for ServiceLoggerPayload --- litellm/types/services.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/litellm/types/services.py b/litellm/types/services.py index b694ca8078..9c3c2120eb 100644 --- 
a/litellm/types/services.py +++ b/litellm/types/services.py @@ -3,7 +3,7 @@ from pydantic import BaseModel, Field from typing import Optional -class ServiceTypes(enum.Enum): +class ServiceTypes(str, enum.Enum): """ Enum for litellm + litellm-adjacent services (redis/postgres/etc.) """ From 312521a0b396ef7a16a9c9441f9b70d777906b7b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 22:12:45 -0700 Subject: [PATCH 16/52] fix service logger for OTEL --- litellm/_service_logger.py | 22 ++++++++++++++++++++-- litellm/integrations/opentelemetry.py | 27 +++++++++++++++++++++++++++ litellm/proxy/proxy_server.py | 12 +++++++++--- 3 files changed, 56 insertions(+), 5 deletions(-) diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index dc6f35642b..dcc2fc1dd3 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -4,7 +4,9 @@ from .types.services import ServiceTypes, ServiceLoggerPayload from .integrations.prometheus_services import PrometheusServicesLogger from .integrations.custom_logger import CustomLogger from datetime import timedelta -from typing import Union +from typing import Union, Optional +from opentelemetry.trace import Span +from datetime import datetime class ServiceLogging(CustomLogger): @@ -40,7 +42,13 @@ class ServiceLogging(CustomLogger): self.mock_testing_sync_failure_hook += 1 async def async_service_success_hook( - self, service: ServiceTypes, duration: float, call_type: str + self, + service: ServiceTypes, + call_type: str, + duration: float, + parent_otel_span: Optional[Span] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, ): """ - For counting if the redis, postgres call is successful @@ -61,6 +69,16 @@ class ServiceLogging(CustomLogger): payload=payload ) + from litellm.proxy.proxy_server import open_telemetry_logger + + if parent_otel_span is not None and open_telemetry_logger is not None: + await open_telemetry_logger.async_service_success_hook( + payload=payload, + parent_otel_span=parent_otel_span, + start_time=start_time, + end_time=end_time, + ) + async def async_service_failure_hook( self, service: ServiceTypes, diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index ac92d5ddd7..90b0626002 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -4,6 +4,9 @@ from dataclasses import dataclass from litellm.integrations.custom_logger import CustomLogger from litellm._logging import verbose_logger +from litellm.types.services import ServiceLoggerPayload, ServiceTypes +from opentelemetry.trace import Span +from datetime import datetime LITELLM_TRACER_NAME = "litellm" LITELLM_RESOURCE = {"service.name": "litellm"} @@ -74,6 +77,30 @@ class OpenTelemetry(CustomLogger): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): self._handle_failure(kwargs, response_obj, start_time, end_time) + async def async_service_success_hook( + self, + payload: ServiceLoggerPayload, + parent_otel_span: Optional[Span] = None, + start_time: Optional[datetime] = None, + end_time: Optional[datetime] = None, + ): + from opentelemetry import trace + from datetime import datetime + + if parent_otel_span is not None: + _span_name = payload.service + service_logging_span = self.tracer.start_span( + name=_span_name, + context=trace.set_span_in_context(parent_otel_span), + start_time=self._to_ns(start_time), + ) + service_logging_span.set_attribute(key="call_type", value=payload.call_type) + 
service_logging_span.set_attribute( + key="service", value=payload.service.value + ) + service_logging_span.end(end_time=self._to_ns(end_time)) + parent_otel_span.end(end_time=self._to_ns(datetime.now())) + def _handle_sucess(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 05eb515739..ef60a9b444 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -104,6 +104,7 @@ from litellm.proxy.utils import ( update_spend, encrypt_value, decrypt_value, + _to_ns, ) from litellm import ( CreateBatchRequest, @@ -399,7 +400,7 @@ disable_spend_logs = False jwt_handler = JWTHandler() prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None store_model_in_db: bool = False -open_telemetry_logger = None +open_telemetry_logger: Optional[litellm.integrations.opentelemetry.OpenTelemetry] = None ### INITIALIZE GLOBAL LOGGING OBJECT ### proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) ### REDIS QUEUE ### @@ -495,7 +496,7 @@ async def check_request_disconnection(request: Request, llm_api_call_task): async def user_api_key_auth( request: Request, api_key: str = fastapi.Security(api_key_header) ) -> UserAPIKeyAuth: - global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings + global master_key, prisma_client, llm_model_list, user_custom_auth, custom_db_client, general_settings, proxy_logging_obj try: if isinstance(api_key, str): passed_in_key = api_key @@ -504,7 +505,7 @@ async def user_api_key_auth( if open_telemetry_logger is not None: parent_otel_span = open_telemetry_logger.tracer.start_span( name="Received Proxy Server Request", - start_time=time.time(), + start_time=_to_ns(datetime.now()), ) ### USER-DEFINED AUTH FUNCTION ### if user_custom_auth is not None: @@ -588,6 +589,7 @@ async def user_api_key_auth( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) # [OPTIONAL] track spend for an org id - `LiteLLM_OrganizationTable` @@ -600,6 +602,7 @@ async def user_api_key_auth( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) # [OPTIONAL] track spend against an internal employee - `LiteLLM_UserTable` user_object = None @@ -614,6 +617,7 @@ async def user_api_key_auth( user_api_key_cache=user_api_key_cache, user_id_upsert=jwt_handler.is_upsert_user_id(), parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) # [OPTIONAL] track spend against an external user - `LiteLLM_EndUserTable` @@ -628,6 +632,7 @@ async def user_api_key_auth( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) global_proxy_spend = None @@ -727,6 +732,7 @@ async def user_api_key_auth( prisma_client=prisma_client, user_api_key_cache=user_api_key_cache, parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) if _end_user_object is not None: end_user_params["allowed_model_region"] = ( From cd125e6309a0a80066282e5b636f8f0a549a4dcb Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 22:13:13 -0700 Subject: [PATCH 17/52] fix auth checks --- litellm/proxy/auth/auth_checks.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/proxy/auth/auth_checks.py 
b/litellm/proxy/auth/auth_checks.py index fdd01f17e2..49bce0855e 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -48,6 +48,7 @@ def log_to_opentelemetry(func): service=ServiceTypes.DB, call_type=func.__name__, parent_otel_span=kwargs["parent_otel_span"], + duration=0.0, start_time=start_time, end_time=end_time, ) From c41b60f6bfd8b936955d8b79cd5f66472872d809 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Thu, 6 Jun 2024 22:13:21 -0700 Subject: [PATCH 18/52] feat(bedrock_httpx.py): working bedrock converse api streaming --- litellm/__init__.py | 3 +- litellm/llms/bedrock_httpx.py | 125 ++++++++++++++++++---- litellm/llms/custom_httpx/http_handler.py | 1 + litellm/tests/test_streaming.py | 4 +- litellm/types/llms/bedrock.py | 24 ++++- litellm/utils.py | 33 +++++- 6 files changed, 165 insertions(+), 25 deletions(-) diff --git a/litellm/__init__.py b/litellm/__init__.py index 2fc47a9926..fe0dd2a56b 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -5,7 +5,7 @@ warnings.filterwarnings("ignore", message=".*conflict with protected namespace.* ### INIT VARIABLES ### import threading, requests, os from typing import Callable, List, Optional, Dict, Union, Any, Literal -from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.caching import Cache from litellm._logging import ( set_verbose, @@ -232,6 +232,7 @@ max_end_user_budget: Optional[float] = None #### RELIABILITY #### request_timeout: float = 6000 module_level_aclient = AsyncHTTPHandler(timeout=request_timeout) +module_level_client = HTTPHandler(timeout=request_timeout) num_retries: Optional[int] = None # per model endpoint default_fallbacks: Optional[List] = None fallbacks: Optional[List] = None diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index ce6a931747..6329f165eb 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -185,6 +185,37 @@ async def make_call( return completion_stream +def make_sync_call( + client: Optional[HTTPHandler], + api_base: str, + headers: dict, + data: str, + model: str, + messages: list, + logging_obj, +): + if client is None: + client = HTTPHandler() # Create a new client if none provided + + response = client.post(api_base, headers=headers, data=data, stream=True) + + if response.status_code != 200: + raise BedrockError(status_code=response.status_code, message=response.read()) + + decoder = AWSEventStreamDecoder(model=model) + completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024)) + + # LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=completion_stream, # Pass the completion stream for logging + additional_args={"complete_input_dict": data}, + ) + + return completion_stream + + class BedrockLLM(BaseLLM): """ Example call @@ -1081,6 +1112,7 @@ class BedrockLLM(BaseLLM): class AmazonConverseConfig: """ Reference - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html + #2 - https://docs.aws.amazon.com/bedrock/latest/userguide/conversation-inference.html#conversation-inference-supported-models-features """ maxTokens: Optional[int] @@ -1118,30 +1150,32 @@ class AmazonConverseConfig: and v is not None } - def get_supported_openai_params(self) -> List[str]: - return [ + def get_supported_openai_params(self, model: str) -> List[str]: + supported_params = [ "max_tokens", "stream", "stream_options", "stop", 
"temperature", "top_p", - "tools", - "tool_choice", ] + if ( + model.startswith("anthropic") + or model.startswith("mistral") + or model.startswith("cohere") + ): + supported_params.append("tools") + + if model.startswith("anthropic") or model.startswith("mistral"): + # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html + supported_params.append("tool_choice") + + return supported_params + def map_tool_choice_values( self, model: str, tool_choice: Union[str, dict], drop_params: bool ) -> Optional[ToolChoiceValuesBlock]: - if not model.startswith("anthropic") and not model.startswith("mistral"): - # only anthropic and mistral support tool choice config. otherwise (E.g. cohere) will fail the call - https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_ToolChoice.html - if drop_params == True or litellm.drop_params == True: - return None - else: - raise litellm.utils.UnsupportedParamsError( - message="Only Anthropic and Mistral on Bedrock support 'tool_choice'. To drop it from the call, set `litellm.drop_params = True.`", - status_code=400, - ) if tool_choice == "none": if litellm.drop_params is True or drop_params is True: return None @@ -1197,7 +1231,7 @@ class AmazonConverseConfig: optional_params["tools"] = value if param == "tool_choice": _tool_choice_value = self.map_tool_choice_values( - model=model, tool_choice=value, drop_params=drop_params + model=model, tool_choice=value, drop_params=drop_params # type: ignore ) if _tool_choice_value is not None: optional_params["tool_choice"] = _tool_choice_value @@ -1539,7 +1573,7 @@ class BedrockConverseLLM(BaseLLM): else: endpoint_url = f"https://bedrock-runtime.{aws_region_name}.amazonaws.com" - if (stream is not None and stream == True) and provider != "ai21": + if (stream is not None and stream is True) and provider != "ai21": endpoint_url = f"{endpoint_url}/model/{modelId}/converse-stream" else: endpoint_url = f"{endpoint_url}/model/{modelId}/converse" @@ -1561,7 +1595,7 @@ class BedrockConverseLLM(BaseLLM): inference_params = copy.deepcopy(optional_params) additional_request_keys = [] additional_request_params = {} - supported_converse_params = AmazonConverseConfig().get_config().keys() + supported_converse_params = AmazonConverseConfig.__annotations__.keys() supported_tool_call_params = ["tools", "tool_choice"] ## TRANSFORMATION ## # send all model-specific params in 'additional_request_params' @@ -1596,6 +1630,7 @@ class BedrockConverseLLM(BaseLLM): "messages": bedrock_messages, "additionalModelRequestFields": additional_request_params, "system": system_content_blocks, + "inferenceConfig": InferenceConfig(**inference_params), } if bedrock_tool_config is not None: _data["toolConfig"] = bedrock_tool_config @@ -1623,7 +1658,35 @@ class BedrockConverseLLM(BaseLLM): ) ### ROUTING (ASYNC, STREAMING, SYNC) + if (stream is not None and stream is True) and provider != "ai21": + + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_sync_call, + client=None, + api_base=prepped.url, + headers=prepped.headers, + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + + ## LOGGING + logging_obj.post_call( + input=messages, + api_key="", + original_response=streaming_response, + additional_args={"complete_input_dict": data}, + ) + return streaming_response ### 
COMPLETION + if client is None or isinstance(client, AsyncHTTPHandler): _params = {} if timeout is not None: @@ -1675,6 +1738,31 @@ class AWSEventStreamDecoder: self.parser = EventStreamJSONParser() def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: + text = "" + tool_str = "" + is_finished = False + finish_reason = "" + usage: Optional[ConverseTokenUsageBlock] = None + if "delta" in chunk_data: + delta_obj = ContentBlockDeltaEvent(**chunk_data["delta"]) + if "text" in delta_obj: + text = delta_obj["text"] + elif "toolUse" in delta_obj: + tool_str = delta_obj["toolUse"]["input"] + elif "stopReason" in chunk_data: + finish_reason = map_finish_reason(chunk_data.get("stopReason", "stop")) + elif "usage" in chunk_data: + usage = ConverseTokenUsageBlock(**chunk_data["usage"]) + response = GenericStreamingChunk( + text=text, + tool_str=tool_str, + is_finished=is_finished, + finish_reason=finish_reason, + usage=usage, + ) + return response + + def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: text = "" is_finished = False finish_reason = "" @@ -1763,12 +1851,11 @@ class AWSEventStreamDecoder: def _parse_message_from_event(self, event) -> Optional[str]: response_dict = event.to_response_dict() - parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) if response_dict["status_code"] != 200: raise ValueError(f"Bad response code, expected 200: {response_dict}") - chunk = parsed_response.get("chunk") + chunk = response_dict.get("body") if not chunk: return None - return chunk.get("bytes").decode() # type: ignore[no-any-return] + return chunk.decode() # type: ignore[no-any-return] diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index 5ec9c79bb2..d8dd4f01e4 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -168,6 +168,7 @@ class HTTPHandler: return response def __del__(self) -> None: + traceback.print_stack() try: self.close() except Exception: diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index c24de601f5..1113adc40c 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1284,7 +1284,7 @@ async def test_completion_replicate_llama3_streaming(sync_mode): # pytest.fail(f"Error occurred: {e}") -@pytest.mark.parametrize("sync_mode", [True, False]) +@pytest.mark.parametrize("sync_mode", [True]) # False @pytest.mark.parametrize( "model", [ @@ -1324,6 +1324,8 @@ async def test_bedrock_httpx_streaming(sync_mode, model): raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") + + assert False else: response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore model=model, diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 647dc1d7b0..757ece516f 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -107,10 +107,30 @@ class ToolConfigBlock(TypedDict, total=False): toolChoice: Union[str, ToolChoiceValuesBlock] +class InferenceConfig(TypedDict, total=False): + maxTokens: int + stopSequences: List[str] + temperature: float + topP: float + + +class ToolBlockDeltaEvent(TypedDict): + input: str + + +class ContentBlockDeltaEvent(TypedDict, total=False): + """ + Either 'text' or 'toolUse' will be specified for Converse API streaming response. 
+ """ + + text: str + toolUse: ToolBlockDeltaEvent + + class RequestObject(TypedDict, total=False): additionalModelRequestFields: dict additionalModelResponseFieldPaths: List[str] - inferenceConfig: dict + inferenceConfig: InferenceConfig messages: Required[List[MessageBlock]] system: List[SystemContentBlock] toolConfig: ToolConfigBlock @@ -118,8 +138,10 @@ class RequestObject(TypedDict, total=False): class GenericStreamingChunk(TypedDict): text: Required[str] + tool_str: Required[str] is_finished: Required[bool] finish_reason: Required[str] + usage: Optional[ConverseTokenUsageBlock] class Document(TypedDict): diff --git a/litellm/utils.py b/litellm/utils.py index 6db5f540c0..75dd853286 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -239,6 +239,8 @@ def map_finish_reason( return "length" elif finish_reason == "tool_use": # anthropic return "tool_calls" + elif finish_reason == "content_filtered": + return "content_filter" return finish_reason @@ -6330,7 +6332,7 @@ def get_supported_openai_params( - None if unmapped """ if custom_llm_provider == "bedrock": - return litellm.AmazonConverseConfig().get_supported_openai_params() + return litellm.AmazonConverseConfig().get_supported_openai_params(model=model) elif custom_llm_provider == "ollama": return litellm.OllamaConfig().get_supported_openai_params() elif custom_llm_provider == "ollama_chat": @@ -11242,12 +11244,27 @@ class CustomStreamWrapper: if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] elif self.custom_llm_provider == "bedrock": + from litellm.types.llms.bedrock import GenericStreamingChunk + if self.received_finish_reason is not None: raise StopIteration - response_obj = self.handle_bedrock_stream(chunk) + response_obj: GenericStreamingChunk = chunk completion_obj["content"] = response_obj["text"] + if response_obj["is_finished"]: self.received_finish_reason = response_obj["finish_reason"] + + if ( + self.stream_options + and self.stream_options.get("include_usage", False) is True + and response_obj["usage"] is not None + ): + self.sent_stream_usage = True + model_response.usage = litellm.Usage( + prompt_tokens=response_obj["usage"]["inputTokens"], + completion_tokens=response_obj["usage"]["outputTokens"], + total_tokens=response_obj["usage"]["totalTokens"], + ) elif self.custom_llm_provider == "sagemaker": print_verbose(f"ENTERS SAGEMAKER STREAMING for chunk {chunk}") response_obj = self.handle_sagemaker_stream(chunk) @@ -11509,7 +11526,7 @@ class CustomStreamWrapper: and hasattr(model_response, "usage") and hasattr(model_response.usage, "prompt_tokens") ): - if self.sent_first_chunk == False: + if self.sent_first_chunk is False: completion_obj["role"] = "assistant" self.sent_first_chunk = True model_response.choices[0].delta = Delta(**completion_obj) @@ -11677,6 +11694,8 @@ class CustomStreamWrapper: def __next__(self): try: + if self.completion_stream is None: + self.fetch_sync_stream() while True: if ( isinstance(self.completion_stream, str) @@ -11751,6 +11770,14 @@ class CustomStreamWrapper: custom_llm_provider=self.custom_llm_provider, ) + def fetch_sync_stream(self): + if self.completion_stream is None and self.make_call is not None: + # Call make_call to get the completion stream + self.completion_stream = self.make_call(client=litellm.module_level_client) + self._stream_iter = self.completion_stream.__iter__() + + return self.completion_stream + async def fetch_stream(self): if self.completion_stream is None and self.make_call is not None: # Call make_call to get the 
completion stream From 193e71642cd8a0014109d711dbe5c0b2344add07 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 22:28:01 -0700 Subject: [PATCH 19/52] fix - log_to_opentelemetry --- litellm/proxy/auth/auth_checks.py | 32 +---------------------------- litellm/proxy/utils.py | 34 +++++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 31 deletions(-) diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 49bce0855e..10037e60fd 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -18,46 +18,16 @@ from litellm.proxy._types import ( LitellmUserRoles, ) from typing import Optional, Literal, Union -from litellm.proxy.utils import PrismaClient, ProxyLogging +from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry from litellm.caching import DualCache import litellm from opentelemetry.trace import Span -from functools import wraps from litellm.types.services import ServiceLoggerPayload, ServiceTypes from datetime import datetime all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value -def log_to_opentelemetry(func): - @wraps(func) - async def wrapper(*args, **kwargs): - start_time = datetime.now() - result = await func(*args, **kwargs) - end_time = datetime.now() - - # Log to OTEL only if "parent_otel_span" is in kwargs and is not None - if ( - "parent_otel_span" in kwargs - and kwargs["parent_otel_span"] is not None - and "proxy_logging_obj" in kwargs - and kwargs["proxy_logging_obj"] is not None - ): - proxy_logging_obj = kwargs["proxy_logging_obj"] - await proxy_logging_obj.service_logging_obj.async_service_success_hook( - service=ServiceTypes.DB, - call_type=func.__name__, - parent_otel_span=kwargs["parent_otel_span"], - duration=0.0, - start_time=start_time, - end_time=end_time, - ) - # end of logging to otel - return result - - return wrapper - - def common_checks( request_body: dict, team_object: Optional[LiteLLM_TeamTable], diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 549b0f2abc..e5efb93d05 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -47,6 +47,8 @@ from email.mime.multipart import MIMEMultipart from datetime import datetime, timedelta from litellm.integrations.slack_alerting import SlackAlerting from typing_extensions import overload +from opentelemetry.trace import Span +from functools import wraps def print_verbose(print_statement): @@ -62,6 +64,35 @@ def print_verbose(print_statement): print(f"LiteLLM Proxy: {print_statement}") # noqa +def log_to_opentelemetry(func): + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = datetime.now() + result = await func(*args, **kwargs) + end_time = datetime.now() + + # Log to OTEL only if "parent_otel_span" is in kwargs and is not None + if ( + "parent_otel_span" in kwargs + and kwargs["parent_otel_span"] is not None + and "proxy_logging_obj" in kwargs + and kwargs["proxy_logging_obj"] is not None + ): + proxy_logging_obj = kwargs["proxy_logging_obj"] + await proxy_logging_obj.service_logging_obj.async_service_success_hook( + service=ServiceTypes.DB, + call_type=func.__name__, + parent_otel_span=kwargs["parent_otel_span"], + duration=0.0, + start_time=start_time, + end_time=end_time, + ) + # end of logging to otel + return result + + return wrapper + + ### LOGGING ### class ProxyLogging: """ @@ -831,6 +862,7 @@ class PrismaClient: max_time=10, # maximum total time to retry for on_backoff=on_backoff, # specifying the function to call on 
backoff ) + @log_to_opentelemetry async def get_data( self, token: Optional[Union[str, list]] = None, @@ -857,6 +889,8 @@ class PrismaClient: limit: Optional[ int ] = None, # pagination, number of rows to getch when find_all==True + parent_otel_span: Optional[Span] = None, + proxy_logging_obj: Optional[ProxyLogging] = None, ): args_passed_in = locals() start_time = time.time() From f48d8fd6aef54a25010244c4e2f95502aa82daa7 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Thu, 6 Jun 2024 22:30:49 -0700 Subject: [PATCH 20/52] fix open_telemetry_logger --- litellm/proxy/proxy_server.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index ef60a9b444..d21553c2d9 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -400,7 +400,7 @@ disable_spend_logs = False jwt_handler = JWTHandler() prompt_injection_detection_obj: Optional[_OPTIONAL_PromptInjectionDetection] = None store_model_in_db: bool = False -open_telemetry_logger: Optional[litellm.integrations.opentelemetry.OpenTelemetry] = None +open_telemetry_logger = None ### INITIALIZE GLOBAL LOGGING OBJECT ### proxy_logging_obj = ProxyLogging(user_api_key_cache=user_api_key_cache) ### REDIS QUEUE ### @@ -845,7 +845,10 @@ async def user_api_key_auth( verbose_proxy_logger.debug("api key: %s", api_key) if prisma_client is not None: _valid_token: Optional[BaseModel] = await prisma_client.get_data( - token=api_key, table_name="combined_view" + token=api_key, + table_name="combined_view", + parent_otel_span=parent_otel_span, + proxy_logging_obj=proxy_logging_obj, ) if _valid_token is not None: valid_token = UserAPIKeyAuth( From c4613565238224f9bb11d972ed3e29e0cfb8aa25 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 08:25:26 -0700 Subject: [PATCH 21/52] feat - add parent_otel_span to UserAPIKeyAuth --- litellm/proxy/_types.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 8a95f4e1d8..e089c0429d 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -6,6 +6,8 @@ from datetime import datetime import uuid, json, sys, os from litellm.types.router import UpdateRouterConfig from litellm.types.utils import ProviderField +from typing_extensions import Annotated +from opentelemetry.trace import Span class LitellmUserRoles(str, enum.Enum): @@ -1194,6 +1196,7 @@ class UserAPIKeyAuth( ] ] = None allowed_model_region: Optional[Literal["eu"]] = None + parent_otel_span: Optional[Span] = None @model_validator(mode="before") @classmethod @@ -1206,6 +1209,9 @@ class UserAPIKeyAuth( values.update({"api_key": hash_token(values.get("api_key"))}) return values + class Config: + arbitrary_types_allowed = True + class LiteLLM_Config(LiteLLMBase): param_name: str From 0ccf1bff521d24ae7c96332bf2372cdcf9c89c0b Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 08:27:47 -0700 Subject: [PATCH 22/52] feat - use async_service_success_hook with litellm proxy --- litellm/integrations/opentelemetry.py | 23 +++++++++++++++++++---- 1 file changed, 19 insertions(+), 4 deletions(-) diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 90b0626002..5c6312c05f 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -86,6 +86,7 @@ class OpenTelemetry(CustomLogger): ): from opentelemetry import trace from datetime import datetime + from opentelemetry.trace import Status, StatusCode if 
parent_otel_span is not None: _span_name = payload.service @@ -98,8 +99,8 @@ class OpenTelemetry(CustomLogger): service_logging_span.set_attribute( key="service", value=payload.service.value ) + service_logging_span.set_status(Status(StatusCode.OK)) service_logging_span.end(end_time=self._to_ns(end_time)) - parent_otel_span.end(end_time=self._to_ns(datetime.now())) def _handle_sucess(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode @@ -109,15 +110,18 @@ class OpenTelemetry(CustomLogger): kwargs, self.config, ) + _parent_context, parent_otel_span = self._get_span_context(kwargs) span = self.tracer.start_span( name=self._get_span_name(kwargs), start_time=self._to_ns(start_time), - context=self._get_span_context(kwargs), + context=_parent_context, ) span.set_status(Status(StatusCode.OK)) self.set_attributes(span, kwargs, response_obj) span.end(end_time=self._to_ns(end_time)) + if parent_otel_span is not None: + parent_otel_span.end(end_time=self._to_ns(datetime.now())) def _handle_failure(self, kwargs, response_obj, start_time, end_time): from opentelemetry.trace import Status, StatusCode @@ -146,17 +150,28 @@ class OpenTelemetry(CustomLogger): from opentelemetry.trace.propagation.tracecontext import ( TraceContextTextMapPropagator, ) + from opentelemetry import trace litellm_params = kwargs.get("litellm_params", {}) or {} proxy_server_request = litellm_params.get("proxy_server_request", {}) or {} headers = proxy_server_request.get("headers", {}) or {} traceparent = headers.get("traceparent", None) + _metadata = litellm_params.get("metadata", {}) or {} + parent_otel_span = _metadata.get("litellm_parent_otel_span", None) + + """ + Two way to use parents in opentelemetry + - using the traceparent header + - using the parent_otel_span in the [metadata][parent_otel_span] + """ + if parent_otel_span is not None: + return trace.set_span_in_context(parent_otel_span), parent_otel_span if traceparent is None: - return None + return None, None else: carrier = {"traceparent": traceparent} - return TraceContextTextMapPropagator().extract(carrier=carrier) + return TraceContextTextMapPropagator().extract(carrier=carrier), None def _get_span_processor(self): from opentelemetry.sdk.trace.export import ( From 12ed3dc9117de19cef87c3df39c51e1dc99dbd6f Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 7 Jun 2024 08:47:51 -0700 Subject: [PATCH 23/52] refactor(main.py): only route anthropic calls through converse api v0 scope let's move function calling to converse api --- litellm/llms/bedrock_httpx.py | 168 +- litellm/llms/custom_httpx/http_handler.py | 1 - litellm/main.py | 53 +- litellm/tests/log.txt | 4274 --------------------- litellm/tests/test_streaming.py | 18 +- litellm/utils.py | 81 +- 6 files changed, 263 insertions(+), 4332 deletions(-) delete mode 100644 litellm/tests/log.txt diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 6329f165eb..7aba78d7ce 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -1125,7 +1125,7 @@ class AmazonConverseConfig: maxTokens: Optional[int] = None, stopSequences: Optional[List[str]] = None, temperature: Optional[int] = None, - top_p: Optional[int] = None, + topP: Optional[int] = None, ) -> None: locals_ = locals() for key, value in locals_.items(): @@ -1481,6 +1481,93 @@ class BedrockConverseLLM(BaseLLM): return session.get_credentials() + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: 
From 12ed3dc9117de19cef87c3df39c51e1dc99dbd6f Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 7 Jun 2024 08:47:51 -0700
Subject: [PATCH 23/52] refactor(main.py): only route anthropic calls through converse api

v0 scope let's move function calling to converse api
---
 litellm/llms/bedrock_httpx.py | 168 +-
 litellm/llms/custom_httpx/http_handler.py | 1 -
 litellm/main.py | 53 +-
 litellm/tests/log.txt | 4274 ---------------------
 litellm/tests/test_streaming.py | 18 +-
 litellm/utils.py | 81 +-
 6 files changed, 263 insertions(+), 4332 deletions(-)
 delete mode 100644 litellm/tests/log.txt

diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 6329f165eb..7aba78d7ce 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1125,7 +1125,7 @@ class AmazonConverseConfig:
         maxTokens: Optional[int] = None,
         stopSequences: Optional[List[str]] = None,
         temperature: Optional[int] = None,
-        top_p: Optional[int] = None,
+        topP: Optional[int] = None,
     ) -> None:
         locals_ = locals()
         for key, value in locals_.items():
@@ -1481,6 +1481,93 @@ class BedrockConverseLLM(BaseLLM):
 
         return session.get_credentials()
 
+    async def async_streaming(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> CustomStreamWrapper:
+        streaming_response = CustomStreamWrapper(
+            completion_stream=None,
+            make_call=partial(
+                make_call,
+                client=client,
+                api_base=api_base,
+                headers=headers,
+                data=data,
+                model=model,
+                messages=messages,
+                logging_obj=logging_obj,
+            ),
+            model=model,
+            custom_llm_provider="bedrock",
+            logging_obj=logging_obj,
+        )
+        return streaming_response
+
+    async def async_completion(
+        self,
+        model: str,
+        messages: list,
+        api_base: str,
+        model_response: ModelResponse,
+        print_verbose: Callable,
+        data: str,
+        timeout: Optional[Union[float, httpx.Timeout]],
+        encoding,
+        logging_obj,
+        stream,
+        optional_params: dict,
+        litellm_params=None,
+        logger_fn=None,
+        headers={},
+        client: Optional[AsyncHTTPHandler] = None,
+    ) -> Union[ModelResponse, CustomStreamWrapper]:
+        if client is None:
+            _params = {}
+            if timeout is not None:
+                if isinstance(timeout, float) or isinstance(timeout, int):
+                    timeout = httpx.Timeout(timeout)
+                _params["timeout"] = timeout
+            client = AsyncHTTPHandler(**_params)  # type: ignore
+        else:
+            client = client  # type: ignore
+
+        try:
+            response = await client.post(api_base, headers=headers, data=data)  # type: ignore
+            response.raise_for_status()
+        except httpx.HTTPStatusError as err:
+            error_code = err.response.status_code
+            raise BedrockError(status_code=error_code, message=err.response.text)
+        except httpx.TimeoutException as e:
+            raise BedrockError(status_code=408, message="Timeout error occurred.")
+
+        return self.process_response(
+            model=model,
+            response=response,
+            model_response=model_response,
+            stream=stream if isinstance(stream, bool) else False,
+            logging_obj=logging_obj,
+            api_key="",
+            data=data,
+            messages=messages,
+            print_verbose=print_verbose,
+            optional_params=optional_params,
+            encoding=encoding,
+        )
+
     def completion(
         self,
         model: str,
@@ -1504,7 +1591,7 @@ class BedrockConverseLLM(BaseLLM):
             from botocore.auth import SigV4Auth
             from botocore.awsrequest import AWSRequest
             from botocore.credentials import Credentials
-        except ImportError as e:
+        except ImportError:
             raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
 
         ## SETUP ##
@@ -1658,6 +1745,46 @@ class BedrockConverseLLM(BaseLLM):
             )
 
         ### ROUTING (ASYNC, STREAMING, SYNC)
+        if acompletion:
+            if isinstance(client, HTTPHandler):
+                client = None
+            if stream is True and provider != "ai21":
+                return self.async_streaming(
+                    model=model,
+                    messages=messages,
+                    data=data,
+                    api_base=prepped.url,
+                    model_response=model_response,
+                    print_verbose=print_verbose,
+                    encoding=encoding,
+                    logging_obj=logging_obj,
+                    optional_params=optional_params,
+                    stream=True,
+                    litellm_params=litellm_params,
+                    logger_fn=logger_fn,
+                    headers=prepped.headers,
+                    timeout=timeout,
+                    client=client,
+                )  # type: ignore
+            ### ASYNC COMPLETION
+            return self.async_completion(
+                model=model,
+                messages=messages,
+                data=data,
+                api_base=prepped.url,
+                model_response=model_response,
+                print_verbose=print_verbose,
+                encoding=encoding,
+                logging_obj=logging_obj,
+                optional_params=optional_params,
+                stream=stream,  # type: ignore
+                litellm_params=litellm_params,
+                logger_fn=logger_fn,
+                headers=prepped.headers,
+                timeout=timeout,
+                client=client,
+            )  # type: ignore
+
         if (stream is not None and stream is True) and provider != "ai21":
 
             streaming_response = CustomStreamWrapper(
@@ -1666,7 +1793,7 @@ class BedrockConverseLLM(BaseLLM):
                 make_sync_call,
                 client=None,
                 api_base=prepped.url,
-                headers=prepped.headers,
+                headers=prepped.headers,  # type: ignore
                 data=data,
                 model=model,
                 messages=messages,
@@ -1702,7 +1829,7 @@ class BedrockConverseLLM(BaseLLM):
         except httpx.HTTPStatusError as err:
             error_code = err.response.status_code
             raise BedrockError(status_code=error_code, message=response.text)
-        except httpx.TimeoutException as e:
+        except httpx.TimeoutException:
             raise BedrockError(status_code=408, message="Timeout error occurred.")
 
         return self.process_response(
@@ -1737,7 +1864,7 @@ class AWSEventStreamDecoder:
         self.model = model
         self.parser = EventStreamJSONParser()
 
-    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
         tool_str = ""
         is_finished = False
@@ -1762,7 +1889,7 @@ class AWSEventStreamDecoder:
         )
         return response
 
-    def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
+    def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk:
         text = ""
         is_finished = False
         finish_reason = ""
@@ -1774,19 +1901,8 @@ class AWSEventStreamDecoder:
                 is_finished = True
                 finish_reason = "stop"
         ######## bedrock.anthropic mappings ###############
-        elif "completion" in chunk_data:  # not claude-3
-            text = chunk_data["completion"]  # bedrock.anthropic
-            stop_reason = chunk_data.get("stop_reason", None)
-            if stop_reason != None:
-                is_finished = True
-                finish_reason = stop_reason
         elif "delta" in chunk_data:
-            if chunk_data["delta"].get("text", None) is not None:
-                text = chunk_data["delta"]["text"]
-            stop_reason = chunk_data["delta"].get("stop_reason", None)
-            if stop_reason != None:
-                is_finished = True
-                finish_reason = stop_reason
+            return self.converse_chunk_parser(chunk_data=chunk_data)
         ######## bedrock.mistral mappings ###############
         elif "outputs" in chunk_data:
             if (
@@ -1851,11 +1967,17 @@ class AWSEventStreamDecoder:
 
     def _parse_message_from_event(self, event) -> Optional[str]:
         response_dict = event.to_response_dict()
+        parsed_response = self.parser.parse(response_dict, get_response_stream_shape())
         if response_dict["status_code"] != 200:
             raise ValueError(f"Bad response code, expected 200: {response_dict}")
+        if "chunk" in parsed_response:
+            chunk = parsed_response.get("chunk")
+            if not chunk:
+                return None
+            return chunk.get("bytes").decode()  # type: ignore[no-any-return]
+        else:
+            chunk = response_dict.get("body")
+            if not chunk:
+                return None
-
-        chunk = response_dict.get("body")
-        if not chunk:
-            return None
-
-        return chunk.decode()  # type: ignore[no-any-return]
+            return chunk.decode()  # type: ignore[no-any-return]
diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py
index d8dd4f01e4..5ec9c79bb2 100644
--- a/litellm/llms/custom_httpx/http_handler.py
+++ b/litellm/llms/custom_httpx/http_handler.py
@@ -168,7 +168,6 @@ class HTTPHandler:
         return response
 
     def __del__(self) -> None:
-        traceback.print_stack()
         try:
             self.close()
         except Exception:
diff --git a/litellm/main.py b/litellm/main.py
index c95b419ba2..15334d0414 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -121,7 +121,8 @@ azure_text_completions = AzureTextCompletion()
 huggingface = Huggingface()
 predibase_chat_completions = PredibaseChatCompletion()
 triton_chat_completions = TritonChatCompletion()
-bedrock_chat_completion = BedrockConverseLLM()
+bedrock_chat_completion = BedrockLLM()
+bedrock_converse_chat_completion = BedrockConverseLLM()
 vertex_chat_completion = VertexLLM()
 
 ####### COMPLETION ENDPOINTS ################
@@ -2097,22 +2098,40 @@ def completion(
                     logging_obj=logging,
                 )
             else:
-                response = bedrock_chat_completion.completion(
-                    model=model,
-                    messages=messages,
-                    custom_prompt_dict=custom_prompt_dict,
-                    model_response=model_response,
-                    print_verbose=print_verbose,
-                    optional_params=optional_params,
-                    litellm_params=litellm_params,
-                    logger_fn=logger_fn,
-                    encoding=encoding,
-                    logging_obj=logging,
-                    extra_headers=extra_headers,
-                    timeout=timeout,
-                    acompletion=acompletion,
-                    client=client,
-                )
+                if model.startswith("anthropic"):
+                    response = bedrock_converse_chat_completion.completion(
+                        model=model,
+                        messages=messages,
+                        custom_prompt_dict=custom_prompt_dict,
+                        model_response=model_response,
+                        print_verbose=print_verbose,
+                        optional_params=optional_params,
+                        litellm_params=litellm_params,
+                        logger_fn=logger_fn,
+                        encoding=encoding,
+                        logging_obj=logging,
+                        extra_headers=extra_headers,
+                        timeout=timeout,
+                        acompletion=acompletion,
+                        client=client,
+                    )
+                else:
+                    response = bedrock_chat_completion.completion(
+                        model=model,
+                        messages=messages,
+                        custom_prompt_dict=custom_prompt_dict,
+                        model_response=model_response,
+                        print_verbose=print_verbose,
+                        optional_params=optional_params,
+                        litellm_params=litellm_params,
+                        logger_fn=logger_fn,
+                        encoding=encoding,
+                        logging_obj=logging,
+                        extra_headers=extra_headers,
+                        timeout=timeout,
+                        acompletion=acompletion,
+                        client=client,
+                    )
             if optional_params.get("stream", False):
                 ## LOGGING
                 logging.post_call(
diff --git a/litellm/tests/log.txt b/litellm/tests/log.txt
deleted file mode 100644
index ea07ca7e12..0000000000
--- a/litellm/tests/log.txt
+++ /dev/null
@@ -1,4274 +0,0 @@
-============================= test session starts ==============================
-platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0
-rootdir: /Users/krrishdholakia/Documents/litellm
-configfile: pyproject.toml
-plugins: asyncio-0.23.6, mock-3.14.0, anyio-4.2.0
-asyncio: mode=Mode.STRICT
-collected 1 item
-
-test_amazing_vertex_completion.py F [100%]
-
-=================================== FAILURES ===================================
-____________________________ test_gemini_pro_vision ____________________________
-
-model = 'gemini-1.5-flash-preview-0514'
-messages = [{'content': [{'text':
'Whats in this image?', 'type': 'text'}, {'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}, 'type': 'image_url'}], 'role': 'user'}] -model_response = ModelResponse(id='chatcmpl-722df0e7-4e2d-44e6-9e2c-49823faa0189', choices=[Choices(finish_reason='stop', index=0, mess... role='assistant'))], created=1716145725, model=None, object='chat.completion', system_fingerprint=None, usage=Usage()) -print_verbose = -encoding = -logging_obj = -vertex_project = None, vertex_location = None, vertex_credentials = None -optional_params = {} -litellm_params = {'acompletion': False, 'api_base': '', 'api_key': None, 'completion_call_id': None, ...} -logger_fn = None, acompletion = False - - def completion( - model: str, - messages: list, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - acompletion: bool = False, - ): - try: - import vertexai - except: - raise VertexAIError( - status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", - ) - - if not ( - hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") - ): - raise VertexAIError( - status_code=400, - message="""Upgrade vertex ai. Run `pip install "google-cloud-aiplatform>=1.38"`""", - ) - try: - from vertexai.preview.language_models import ( - ChatModel, - CodeChatModel, - InputOutputTextPair, - ) - from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import ( - GenerativeModel, - Part, - GenerationConfig, - ) - from google.cloud import aiplatform # type: ignore - from google.protobuf import json_format # type: ignore - from google.protobuf.struct_pb2 import Value # type: ignore - from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore - import google.auth # type: ignore - import proto # type: ignore - - ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 - print_verbose( - f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" - ) - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - json_obj = json.loads(vertex_credentials) - - creds = google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - else: - creds, _ = google.auth.default(quota_project_id=vertex_project) - print_verbose( - f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" - ) - vertexai.init( - project=vertex_project, location=vertex_location, credentials=creds - ) - - ## Load Config - config = litellm.VertexAIConfig.get_config() - for k, v in config.items(): - if k not in optional_params: - optional_params[k] = v - - ## Process safety settings into format expected by vertex AI - safety_settings = None - if "safety_settings" in optional_params: - safety_settings = optional_params.pop("safety_settings") - if not isinstance(safety_settings, list): - raise ValueError("safety_settings must be a list") - if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): - raise ValueError("safety_settings must be a list of dicts") - safety_settings = [ - gapic_content_types.SafetySetting(x) for x 
in safety_settings - ] - - # vertexai does not use an API key, it looks for credentials.json in the environment - - prompt = " ".join( - [ - message["content"] - for message in messages - if isinstance(message["content"], str) - ] - ) - - mode = "" - - request_str = "" - response_obj = None - async_client = None - instances = None - client_options = { - "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com" - } - if ( - model in litellm.vertex_language_models - or model in litellm.vertex_vision_models - ): - llm_model = GenerativeModel(model) - mode = "vision" - request_str += f"llm_model = GenerativeModel({model})\n" - elif model in litellm.vertex_chat_models: - llm_model = ChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = ChatModel.from_pretrained({model})\n" - elif model in litellm.vertex_text_models: - llm_model = TextGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_text_models: - llm_model = CodeGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models - llm_model = CodeChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" - elif model == "private": - mode = "private" - model = optional_params.pop("model_id", None) - # private endpoint requires a dict instead of JSON - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - llm_model = aiplatform.PrivateEndpoint( - endpoint_name=model, - project=vertex_project, - location=vertex_location, - ) - request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n" - else: # assume vertex model garden on public endpoint - mode = "custom" - - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - instances = [ - json_format.ParseDict(instance_dict, Value()) - for instance_dict in instances - ] - # Will determine the API used based on async parameter - llm_model = None - - # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now - if acompletion == True: - data = { - "llm_model": llm_model, - "mode": mode, - "prompt": prompt, - "logging_obj": logging_obj, - "request_str": request_str, - "model": model, - "model_response": model_response, - "encoding": encoding, - "messages": messages, - "print_verbose": print_verbose, - "client_options": client_options, - "instances": instances, - "vertex_location": vertex_location, - "vertex_project": vertex_project, - "safety_settings": safety_settings, - **optional_params, - } - if optional_params.get("stream", False) is True: - # async streaming - return async_streaming(**data) - - return async_completion(**data) - - if mode == "vision": - print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - tools = optional_params.pop("tools", None) - content = _gemini_convert_messages_text(messages=messages) - stream = optional_params.pop("stream", False) - if stream == True: - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - 
"complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - model_response = llm_model.generate_content( - contents={"content": content}, - generation_config=optional_params, - safety_settings=safety_settings, - stream=True, - tools=tools, - ) - - return model_response - - request_str += f"response = llm_model.generate_content({content})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - ## LLM Call -> response = llm_model.generate_content( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - -../llms/vertex_ai.py:740: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:405: in generate_content - return self._generate_content( -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:487: in _generate_content - request = self._prepare_request( -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:274: in _prepare_request - contents = [ -../proxy/myenv/lib/python3.11/site-packages/vertexai/generative_models/_generative_models.py:275: in - gapic_content_types.Content(content_dict) for content_dict in contents -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -self = <[AttributeError('Unknown field for Content: _pb') raised in repr()] Content object at 0x1646aaa90> -mapping = {'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -], 'role': 'user'} -ignore_unknown_fields = False, kwargs = {} -params = {'parts': [text: "Whats in this image?" -, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -], 'role': 'user'} -marshal = , key = 'parts' -value = [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -] -pb_value = [text: "Whats in this image?" -, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -] - - def __init__( - self, - mapping=None, - *, - ignore_unknown_fields=False, - **kwargs, - ): - # We accept several things for `mapping`: - # * An instance of this class. - # * An instance of the underlying protobuf descriptor class. - # * A dict - # * Nothing (keyword arguments only). - if mapping is None: - if not kwargs: - # Special fast path for empty construction. - super().__setattr__("_pb", self._meta.pb()) - return - - mapping = kwargs - elif isinstance(mapping, self._meta.pb): - # Make a copy of the mapping. - # This is a constructor for a new object, so users will assume - # that it will not have side effects on the arguments being - # passed in. - # - # The `wrap` method on the metaclass is the public API for taking - # ownership of the passed in protobuf object. - mapping = copy.deepcopy(mapping) - if kwargs: - mapping.MergeFrom(self._meta.pb(**kwargs)) - - super().__setattr__("_pb", mapping) - return - elif isinstance(mapping, type(self)): - # Just use the above logic on mapping's underlying pb. - self.__init__(mapping=mapping._pb, **kwargs) - return - elif isinstance(mapping, collections.abc.Mapping): - # Can't have side effects on mapping. 
- mapping = copy.copy(mapping) - # kwargs entries take priority for duplicate keys. - mapping.update(kwargs) - else: - # Sanity check: Did we get something not a map? Error if so. - raise TypeError( - "Invalid constructor input for %s: %r" - % ( - self.__class__.__name__, - mapping, - ) - ) - - params = {} - # Update the mapping to address any values that need to be - # coerced. - marshal = self._meta.marshal - for key, value in mapping.items(): - (key, pb_type) = self._get_pb_type_from_key(key) - if pb_type is None: - if ignore_unknown_fields: - continue - - raise ValueError( - "Unknown field for {}: {}".format(self.__class__.__name__, key) - ) - - try: - pb_value = marshal.to_proto(pb_type, value) - except ValueError: - # Underscores may be appended to field names - # that collide with python or proto-plus keywords. - # In case a key only exists with a `_` suffix, coerce the key - # to include the `_` suffix. It's not possible to - # natively define the same field with a trailing underscore in protobuf. - # See related issue - # https://github.com/googleapis/python-api-core/issues/227 - if isinstance(value, dict): - if _upb: - # In UPB, pb_type is MessageMeta which doesn't expose attrs like it used to in Python/CPP. - keys_to_update = [ - item - for item in value - if item not in pb_type.DESCRIPTOR.fields_by_name - and f"{item}_" in pb_type.DESCRIPTOR.fields_by_name - ] - else: - keys_to_update = [ - item - for item in value - if not hasattr(pb_type, item) - and hasattr(pb_type, f"{item}_") - ] - for item in keys_to_update: - value[f"{item}_"] = value.pop(item) - - pb_value = marshal.to_proto(pb_type, value) - - if pb_value is not None: - params[key] = pb_value - - # Create the internal protocol buffer. -> super().__setattr__("_pb", self._meta.pb(**params)) -E TypeError: Parameter to MergeFrom() must be instance of same class: expected got . 
- -../proxy/myenv/lib/python3.11/site-packages/proto/message.py:615: TypeError - -During handling of the above exception, another exception occurred: - -model = 'gemini-1.5-flash-preview-0514' -messages = [{'content': [{'text': 'Whats in this image?', 'type': 'text'}, {'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}, 'type': 'image_url'}], 'role': 'user'}] -timeout = 600.0, temperature = None, top_p = None, n = None, stream = None -stream_options = None, stop = None, max_tokens = None, presence_penalty = None -frequency_penalty = None, logit_bias = None, user = None, response_format = None -seed = None, tools = None, tool_choice = None, logprobs = None -top_logprobs = None, deployment_id = None, extra_headers = None -functions = None, function_call = None, base_url = None, api_version = None -api_key = None, model_list = None -kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': } -args = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...} -api_base = None, mock_response = None, force_timeout = 600, logger_fn = None -verbose = False, custom_llm_provider = 'vertex_ai' - - @client - def completion( - model: str, - # Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create - messages: List = [], - timeout: Optional[Union[float, str, httpx.Timeout]] = None, - temperature: Optional[float] = None, - top_p: Optional[float] = None, - n: Optional[int] = None, - stream: Optional[bool] = None, - stream_options: Optional[dict] = None, - stop=None, - max_tokens: Optional[int] = None, - presence_penalty: Optional[float] = None, - frequency_penalty: Optional[float] = None, - logit_bias: Optional[dict] = None, - user: Optional[str] = None, - # openai v1.0+ new params - response_format: Optional[dict] = None, - seed: Optional[int] = None, - tools: Optional[List] = None, - tool_choice: Optional[str] = None, - logprobs: Optional[bool] = None, - top_logprobs: Optional[int] = None, - deployment_id=None, - extra_headers: Optional[dict] = None, - # soon to be deprecated params by OpenAI - functions: Optional[List] = None, - function_call: Optional[str] = None, - # set api_base, api_version, api_key - base_url: Optional[str] = None, - api_version: Optional[str] = None, - api_key: Optional[str] = None, - model_list: Optional[list] = None, # pass in a list of api_base,keys, etc. - # Optional liteLLM function params - **kwargs, - ) -> Union[ModelResponse, CustomStreamWrapper]: - """ - Perform a completion() using any of litellm supported llms (example gpt-4, gpt-3.5-turbo, claude-2, command-nightly) - Parameters: - model (str): The name of the language model to use for text completion. see all supported LLMs: https://docs.litellm.ai/docs/providers/ - messages (List): A list of message objects representing the conversation context (default is an empty list). - - OPTIONAL PARAMS - functions (List, optional): A list of functions to apply to the conversation messages (default is an empty list). - function_call (str, optional): The name of the function to call within the conversation (default is an empty string). - temperature (float, optional): The temperature parameter for controlling the randomness of the output (default is 1.0). - top_p (float, optional): The top-p parameter for nucleus sampling (default is 1.0). - n (int, optional): The number of completions to generate (default is 1). - stream (bool, optional): If True, return a streaming response (default is False). 
- stream_options (dict, optional): A dictionary containing options for the streaming response. Only set this when you set stream: true. - stop(string/list, optional): - Up to 4 sequences where the LLM API will stop generating further tokens. - max_tokens (integer, optional): The maximum number of tokens in the generated completion (default is infinity). - presence_penalty (float, optional): It is used to penalize new tokens based on their existence in the text so far. - frequency_penalty: It is used to penalize new tokens based on their frequency in the text so far. - logit_bias (dict, optional): Used to modify the probability of specific tokens appearing in the completion. - user (str, optional): A unique identifier representing your end-user. This can help the LLM provider to monitor and detect abuse. - logprobs (bool, optional): Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the content of message - top_logprobs (int, optional): An integer between 0 and 5 specifying the number of most likely tokens to return at each token position, each with an associated log probability. logprobs must be set to true if this parameter is used. - metadata (dict, optional): Pass in additional metadata to tag your completion calls - eg. prompt version, details, etc. - api_base (str, optional): Base URL for the API (default is None). - api_version (str, optional): API version (default is None). - api_key (str, optional): API key (default is None). - model_list (list, optional): List of api base, version, keys - extra_headers (dict, optional): Additional headers to include in the request. - - LITELLM Specific Params - mock_response (str, optional): If provided, return a mock completion response for testing or debugging purposes (default is None). - custom_llm_provider (str, optional): Used for Non-OpenAI LLMs, Example usage for bedrock, set model="amazon.titan-tg1-large" and custom_llm_provider="bedrock" - max_retries (int, optional): The number of retries to attempt (default is 0). - Returns: - ModelResponse: A response object containing the generated completion and associated metadata. - - Note: - - This function is used to perform completions() using the specified language model. - - It supports various optional parameters for customizing the completion behavior. - - If 'mock_response' is provided, a mock completion response is returned for testing or debugging. 
- """ - ######### unpacking kwargs ##################### - args = locals() - api_base = kwargs.get("api_base", None) - mock_response = kwargs.get("mock_response", None) - force_timeout = kwargs.get("force_timeout", 600) ## deprecated - logger_fn = kwargs.get("logger_fn", None) - verbose = kwargs.get("verbose", False) - custom_llm_provider = kwargs.get("custom_llm_provider", None) - litellm_logging_obj = kwargs.get("litellm_logging_obj", None) - id = kwargs.get("id", None) - metadata = kwargs.get("metadata", None) - model_info = kwargs.get("model_info", None) - proxy_server_request = kwargs.get("proxy_server_request", None) - fallbacks = kwargs.get("fallbacks", None) - headers = kwargs.get("headers", None) or extra_headers - num_retries = kwargs.get("num_retries", None) ## deprecated - max_retries = kwargs.get("max_retries", None) - context_window_fallback_dict = kwargs.get("context_window_fallback_dict", None) - organization = kwargs.get("organization", None) - ### CUSTOM MODEL COST ### - input_cost_per_token = kwargs.get("input_cost_per_token", None) - output_cost_per_token = kwargs.get("output_cost_per_token", None) - input_cost_per_second = kwargs.get("input_cost_per_second", None) - output_cost_per_second = kwargs.get("output_cost_per_second", None) - ### CUSTOM PROMPT TEMPLATE ### - initial_prompt_value = kwargs.get("initial_prompt_value", None) - roles = kwargs.get("roles", None) - final_prompt_value = kwargs.get("final_prompt_value", None) - bos_token = kwargs.get("bos_token", None) - eos_token = kwargs.get("eos_token", None) - preset_cache_key = kwargs.get("preset_cache_key", None) - hf_model_name = kwargs.get("hf_model_name", None) - supports_system_message = kwargs.get("supports_system_message", None) - ### TEXT COMPLETION CALLS ### - text_completion = kwargs.get("text_completion", False) - atext_completion = kwargs.get("atext_completion", False) - ### ASYNC CALLS ### - acompletion = kwargs.get("acompletion", False) - client = kwargs.get("client", None) - ### Admin Controls ### - no_log = kwargs.get("no-log", False) - ######## end of unpacking kwargs ########### - openai_params = [ - "functions", - "function_call", - "temperature", - "temperature", - "top_p", - "n", - "stream", - "stream_options", - "stop", - "max_tokens", - "presence_penalty", - "frequency_penalty", - "logit_bias", - "user", - "request_timeout", - "api_base", - "api_version", - "api_key", - "deployment_id", - "organization", - "base_url", - "default_headers", - "timeout", - "response_format", - "seed", - "tools", - "tool_choice", - "max_retries", - "logprobs", - "top_logprobs", - "extra_headers", - ] - litellm_params = [ - "metadata", - "acompletion", - "atext_completion", - "text_completion", - "caching", - "mock_response", - "api_key", - "api_version", - "api_base", - "force_timeout", - "logger_fn", - "verbose", - "custom_llm_provider", - "litellm_logging_obj", - "litellm_call_id", - "use_client", - "id", - "fallbacks", - "azure", - "headers", - "model_list", - "num_retries", - "context_window_fallback_dict", - "retry_policy", - "roles", - "final_prompt_value", - "bos_token", - "eos_token", - "request_timeout", - "complete_response", - "self", - "client", - "rpm", - "tpm", - "max_parallel_requests", - "input_cost_per_token", - "output_cost_per_token", - "input_cost_per_second", - "output_cost_per_second", - "hf_model_name", - "model_info", - "proxy_server_request", - "preset_cache_key", - "caching_groups", - "ttl", - "cache", - "no-log", - "base_model", - "stream_timeout", - "supports_system_message", - 
"region_name", - "allowed_model_region", - "model_config", - ] - - default_params = openai_params + litellm_params - non_default_params = { - k: v for k, v in kwargs.items() if k not in default_params - } # model-specific params - pass them straight to the model/provider - - try: - if base_url is not None: - api_base = base_url - if max_retries is not None: # openai allows openai.OpenAI(max_retries=3) - num_retries = max_retries - logging = litellm_logging_obj - fallbacks = fallbacks or litellm.model_fallbacks - if fallbacks is not None: - return completion_with_fallbacks(**args) - if model_list is not None: - deployments = [ - m["litellm_params"] for m in model_list if m["model_name"] == model - ] - return batch_completion_models(deployments=deployments, **args) - if litellm.model_alias_map and model in litellm.model_alias_map: - model = litellm.model_alias_map[ - model - ] # update the model to the actual value if an alias has been passed in - model_response = ModelResponse() - setattr(model_response, "usage", litellm.Usage()) - if ( - kwargs.get("azure", False) == True - ): # don't remove flag check, to remain backwards compatible for repos like Codium - custom_llm_provider = "azure" - if deployment_id != None: # azure llms - model = deployment_id - custom_llm_provider = "azure" - model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( - model=model, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - api_key=api_key, - ) - if model_response is not None and hasattr(model_response, "_hidden_params"): - model_response._hidden_params["custom_llm_provider"] = custom_llm_provider - model_response._hidden_params["region_name"] = kwargs.get( - "aws_region_name", None - ) # support region-based pricing for bedrock - - ### TIMEOUT LOGIC ### - timeout = timeout or kwargs.get("request_timeout", 600) or 600 - # set timeout for 10 minutes by default - if isinstance(timeout, httpx.Timeout) and not supports_httpx_timeout( - custom_llm_provider - ): - timeout = timeout.read or 600 # default 10 min timeout - elif not isinstance(timeout, httpx.Timeout): - timeout = float(timeout) # type: ignore - - ### REGISTER CUSTOM MODEL PRICING -- IF GIVEN ### - if input_cost_per_token is not None and output_cost_per_token is not None: - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_token": input_cost_per_token, - "output_cost_per_token": output_cost_per_token, - "litellm_provider": custom_llm_provider, - }, - } - ) - elif ( - input_cost_per_second is not None - ): # time based pricing just needs cost in place - output_cost_per_second = output_cost_per_second - litellm.register_model( - { - f"{custom_llm_provider}/{model}": { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - model: { - "input_cost_per_second": input_cost_per_second, - "output_cost_per_second": output_cost_per_second, - "litellm_provider": custom_llm_provider, - }, - } - ) - ### BUILD CUSTOM PROMPT TEMPLATE -- IF GIVEN ### - custom_prompt_dict = {} # type: ignore - if ( - initial_prompt_value - or roles - or final_prompt_value - or bos_token - or eos_token - ): - custom_prompt_dict = {model: {}} - if initial_prompt_value: - custom_prompt_dict[model]["initial_prompt_value"] = initial_prompt_value - if roles: - 
custom_prompt_dict[model]["roles"] = roles - if final_prompt_value: - custom_prompt_dict[model]["final_prompt_value"] = final_prompt_value - if bos_token: - custom_prompt_dict[model]["bos_token"] = bos_token - if eos_token: - custom_prompt_dict[model]["eos_token"] = eos_token - - if ( - supports_system_message is not None - and isinstance(supports_system_message, bool) - and supports_system_message == False - ): - messages = map_system_message_pt(messages=messages) - model_api_key = get_api_key( - llm_provider=custom_llm_provider, dynamic_api_key=api_key - ) # get the api key from the environment if required for the model - - if dynamic_api_key is not None: - api_key = dynamic_api_key - # check if user passed in any of the OpenAI optional params - optional_params = get_optional_params( - functions=functions, - function_call=function_call, - temperature=temperature, - top_p=top_p, - n=n, - stream=stream, - stream_options=stream_options, - stop=stop, - max_tokens=max_tokens, - presence_penalty=presence_penalty, - frequency_penalty=frequency_penalty, - logit_bias=logit_bias, - user=user, - # params to identify the model - model=model, - custom_llm_provider=custom_llm_provider, - response_format=response_format, - seed=seed, - tools=tools, - tool_choice=tool_choice, - max_retries=max_retries, - logprobs=logprobs, - top_logprobs=top_logprobs, - extra_headers=extra_headers, - **non_default_params, - ) - - if litellm.add_function_to_prompt and optional_params.get( - "functions_unsupported_model", None - ): # if user opts to add it to prompt, when API doesn't support function calling - functions_unsupported_model = optional_params.pop( - "functions_unsupported_model" - ) - messages = function_call_prompt( - messages=messages, functions=functions_unsupported_model - ) - - # For logging - save the values of the litellm-specific params passed in - litellm_params = get_litellm_params( - acompletion=acompletion, - api_key=api_key, - force_timeout=force_timeout, - logger_fn=logger_fn, - verbose=verbose, - custom_llm_provider=custom_llm_provider, - api_base=api_base, - litellm_call_id=kwargs.get("litellm_call_id", None), - model_alias_map=litellm.model_alias_map, - completion_call_id=id, - metadata=metadata, - model_info=model_info, - proxy_server_request=proxy_server_request, - preset_cache_key=preset_cache_key, - no_log=no_log, - input_cost_per_second=input_cost_per_second, - input_cost_per_token=input_cost_per_token, - output_cost_per_second=output_cost_per_second, - output_cost_per_token=output_cost_per_token, - ) - logging.update_environment_variables( - model=model, - user=user, - optional_params=optional_params, - litellm_params=litellm_params, - ) - if mock_response: - return mock_completion( - model, - messages, - stream=stream, - mock_response=mock_response, - logging=logging, - acompletion=acompletion, - ) - if custom_llm_provider == "azure": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): 
- if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) - elif custom_llm_provider == "azure_text": - # azure configs - api_type = get_secret("AZURE_API_TYPE") or "azure" - - api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE") - - api_version = ( - api_version or litellm.api_version or get_secret("AZURE_API_VERSION") - ) - - api_key = ( - api_key - or litellm.api_key - or litellm.azure_key - or get_secret("AZURE_OPENAI_API_KEY") - or get_secret("AZURE_API_KEY") - ) - - azure_ad_token = optional_params.get("extra_body", {}).pop( - "azure_ad_token", None - ) or get_secret("AZURE_AD_TOKEN") - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.AzureOpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > azure_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - response = azure_text_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - api_version=api_version, - api_type=api_type, - azure_ad_token=azure_ad_token, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, - client=client, # pass AsyncAzureOpenAI, AzureOpenAI client - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={ - "headers": headers, - "api_version": api_version, - "api_base": api_base, - }, - ) - elif ( - model in litellm.open_ai_chat_completion_models - or custom_llm_provider == "custom_openai" - or custom_llm_provider == "deepinfra" - or custom_llm_provider == "perplexity" - or custom_llm_provider == "groq" - or custom_llm_provider == "deepseek" - or custom_llm_provider == "anyscale" - or custom_llm_provider == "mistral" - or custom_llm_provider == "openai" - or custom_llm_provider == "together_ai" - or custom_llm_provider in litellm.openai_compatible_providers - or "ft:gpt-3.5-turbo" in model # finetune gpt-3.5-turbo - ): # allow user to make an openai call with a custom base - # note: if a user sets a custom base - we should ensure this works - # allow for the setting of dynamic and stateful api-bases - api_base = ( - api_base # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there - or litellm.api_base - or 
get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - openai.organization = ( - organization - or litellm.organization - or get_secret("OPENAI_ORGANIZATION") - or None # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105 - ) - # set API KEY - api_key = ( - api_key - or litellm.api_key # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAIConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - - ## COMPLETION CALL - try: - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - custom_prompt_dict=custom_prompt_dict, - client=client, # pass AsyncOpenAI, OpenAI client - organization=organization, - custom_llm_provider=custom_llm_provider, - ) - except Exception as e: - ## LOGGING - log the original exception returned - logging.post_call( - input=messages, - api_key=api_key, - original_response=str(e), - additional_args={"headers": headers}, - ) - raise e - - if optional_params.get("stream", False): - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - additional_args={"headers": headers}, - ) - elif ( - custom_llm_provider == "text-completion-openai" - or "ft:babbage-002" in model - or "ft:davinci-002" in model # support for finetuned completion models - ): - openai.api_type = "openai" - - api_base = ( - api_base - or litellm.api_base - or get_secret("OPENAI_API_BASE") - or "https://api.openai.com/v1" - ) - - openai.api_version = None - # set API KEY - - api_key = ( - api_key - or litellm.api_key - or litellm.openai_key - or get_secret("OPENAI_API_KEY") - ) - - headers = headers or litellm.headers - - ## LOAD CONFIG - if set - config = litellm.OpenAITextCompletionConfig.get_config() - for k, v in config.items(): - if ( - k not in optional_params - ): # completion(top_k=3) > openai_text_config(top_k=3) <- allows for dynamic variables to be passed in - optional_params[k] = v - if litellm.organization: - openai.organization = litellm.organization - - if ( - len(messages) > 0 - and "content" in messages[0] - and type(messages[0]["content"]) == list - ): - # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] - # https://platform.openai.com/docs/api-reference/completions/create - prompt = messages[0]["content"] - else: - prompt = " ".join([message["content"] for message in messages]) # type: ignore - - ## COMPLETION CALL - _response = openai_text_completions.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - api_key=api_key, - api_base=api_base, - acompletion=acompletion, - client=client, # pass AsyncOpenAI, OpenAI client - logging_obj=logging, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - timeout=timeout, # type: ignore - ) - - 
if ( - optional_params.get("stream", False) == False - and acompletion == False - and text_completion == False - ): - # convert to chat completion response - _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( - response_object=_response, model_response_object=model_response - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=_response, - additional_args={"headers": headers}, - ) - response = _response - elif ( - "replicate" in model - or custom_llm_provider == "replicate" - or model in litellm.replicate_models - ): - # Setting the relevant API KEY for replicate, replicate defaults to using os.environ.get("REPLICATE_API_TOKEN") - replicate_key = None - replicate_key = ( - api_key - or litellm.replicate_key - or litellm.api_key - or get_secret("REPLICATE_API_KEY") - or get_secret("REPLICATE_API_TOKEN") - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("REPLICATE_API_BASE") - or "https://api.replicate.com/v1" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = replicate.completion( # type: ignore - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=replicate_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - acompletion=acompletion, - ) - - if optional_params.get("stream", False) == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=replicate_key, - original_response=model_response, - ) - - response = model_response - elif ( - "clarifai" in model - or custom_llm_provider == "clarifai" - or model in litellm.clarifai_models - ): - clarifai_key = None - clarifai_key = ( - api_key - or litellm.clarifai_key - or litellm.api_key - or get_secret("CLARIFAI_API_KEY") - or get_secret("CLARIFAI_API_TOKEN") - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("CLARIFAI_API_BASE") - or "https://api.clarifai.com/v2" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - model_response = clarifai.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - acompletion=acompletion, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=clarifai_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=model_response, - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=clarifai_key, - original_response=model_response, - ) - response = model_response - - elif custom_llm_provider == "anthropic": - api_key = ( - api_key - or litellm.anthropic_key - or litellm.api_key - or os.environ.get("ANTHROPIC_API_KEY") - ) - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - if (model == "claude-2") or (model == "claude-instant-1"): - # call anthropic /completion, only use this route for claude-2, claude-instant-1 - 
api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/complete" - ) - response = anthropic_text_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) - else: - # call /messages - # default route for all anthropic models - api_base = ( - api_base - or litellm.api_base - or get_secret("ANTHROPIC_API_BASE") - or "https://api.anthropic.com/v1/messages" - ) - response = anthropic_chat_completions.completion( - model=model, - messages=messages, - api_base=api_base, - acompletion=acompletion, - custom_prompt_dict=litellm.custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, # for calculating input/output tokens - api_key=api_key, - logging_obj=logging, - headers=headers, - ) - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - response = response - elif custom_llm_provider == "nlp_cloud": - nlp_cloud_key = ( - api_key - or litellm.nlp_cloud_key - or get_secret("NLP_CLOUD_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("NLP_CLOUD_API_BASE") - or "https://api.nlpcloud.io/v1/gpu/" - ) - - response = nlp_cloud.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=nlp_cloud_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - response, - model, - custom_llm_provider="nlp_cloud", - logging_obj=logging, - ) - - if optional_params.get("stream", False) or acompletion == True: - ## LOGGING - logging.post_call( - input=messages, - api_key=api_key, - original_response=response, - ) - - response = response - elif custom_llm_provider == "aleph_alpha": - aleph_alpha_key = ( - api_key - or litellm.aleph_alpha_key - or get_secret("ALEPH_ALPHA_API_KEY") - or get_secret("ALEPHALPHA_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("ALEPH_ALPHA_API_BASE") - or "https://api.aleph-alpha.com/complete" - ) - - model_response = aleph_alpha.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - default_max_tokens_to_sample=litellm.max_tokens, - api_key=aleph_alpha_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="aleph_alpha", - 
logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/generate" - ) - - model_response = cohere.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "cohere_chat": - cohere_key = ( - api_key - or litellm.cohere_key - or get_secret("COHERE_API_KEY") - or get_secret("CO_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("COHERE_API_BASE") - or "https://api.cohere.ai/v1/chat" - ) - - model_response = cohere_chat.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=cohere_key, - logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="cohere_chat", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "maritalk": - maritalk_key = ( - api_key - or litellm.maritalk_key - or get_secret("MARITALK_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("MARITALK_API_BASE") - or "https://chat.maritaca.ai/api/chat/inference" - ) - - model_response = maritalk.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=maritalk_key, - logging_obj=logging, - ) - - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="maritalk", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "huggingface": - custom_llm_provider = "huggingface" - huggingface_key = ( - api_key - or litellm.huggingface_key - or os.environ.get("HF_TOKEN") - or os.environ.get("HUGGINGFACE_API_KEY") - or litellm.api_key - ) - hf_headers = headers or litellm.headers - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - model_response = huggingface.completion( - model=model, - messages=messages, - api_base=api_base, # type: ignore - headers=hf_headers, - model_response=model_response, 
- print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=huggingface_key, - acompletion=acompletion, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - timeout=timeout, # type: ignore - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion is False - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="huggingface", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "oobabooga": - custom_llm_provider = "oobabooga" - model_response = oobabooga.completion( - model=model, - messages=messages, - model_response=model_response, - api_base=api_base, # type: ignore - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - api_key=None, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - ) - if "stream" in optional_params and optional_params["stream"] == True: - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="oobabooga", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "openrouter": - api_base = api_base or litellm.api_base or "https://openrouter.ai/api/v1" - - api_key = ( - api_key - or litellm.api_key - or litellm.openrouter_key - or get_secret("OPENROUTER_API_KEY") - or get_secret("OR_API_KEY") - ) - - openrouter_site_url = get_secret("OR_SITE_URL") or "https://litellm.ai" - - openrouter_app_name = get_secret("OR_APP_NAME") or "liteLLM" - - headers = ( - headers - or litellm.headers - or { - "HTTP-Referer": openrouter_site_url, - "X-Title": openrouter_app_name, - } - ) - - ## Load Config - config = openrouter.OpenrouterConfig.get_config() - for k, v in config.items(): - if k == "extra_body": - # we use openai 'extra_body' to pass openrouter specific params - transforms, route, models - if "extra_body" in optional_params: - optional_params[k].update(v) - else: - optional_params[k] = v - elif k not in optional_params: - optional_params[k] = v - - data = {"model": model, "messages": messages, **optional_params} - - ## COMPLETION CALL - response = openai_chat_completions.completion( - model=model, - messages=messages, - headers=headers, - api_key=api_key, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - logging_obj=logging, - acompletion=acompletion, - timeout=timeout, # type: ignore - ) - ## LOGGING - logging.post_call( - input=messages, api_key=openai.api_key, original_response=response - ) - elif ( - custom_llm_provider == "together_ai" - or ("togethercomputer" in model) - or (model in litellm.together_ai_models) - ): - """ - Deprecated. 
We now do together ai calls via the openai client - https://docs.together.ai/docs/openai-api-compatibility - """ - custom_llm_provider = "together_ai" - together_ai_key = ( - api_key - or litellm.togetherai_api_key - or get_secret("TOGETHER_AI_TOKEN") - or get_secret("TOGETHERAI_API_KEY") - or litellm.api_key - ) - - api_base = ( - api_base - or litellm.api_base - or get_secret("TOGETHERAI_API_BASE") - or "https://api.together.xyz/inference" - ) - - custom_prompt_dict = custom_prompt_dict or litellm.custom_prompt_dict - - model_response = together_ai.completion( - model=model, - messages=messages, - api_base=api_base, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=together_ai_key, - logging_obj=logging, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream_tokens" in optional_params - and optional_params["stream_tokens"] == True - ): - # don't try to access stream object, - response = CustomStreamWrapper( - model_response, - model, - custom_llm_provider="together_ai", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "palm": - palm_api_key = api_key or get_secret("PALM_API_KEY") or litellm.api_key - - # palm does not support streaming as yet :( - model_response = palm.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=palm_api_key, - logging_obj=logging, - ) - # fake palm streaming - if "stream" in optional_params and optional_params["stream"] == True: - # fake streaming for palm - resp_string = model_response["choices"][0]["message"]["content"] - response = CustomStreamWrapper( - resp_string, model, custom_llm_provider="palm", logging_obj=logging - ) - return response - response = model_response - elif custom_llm_provider == "gemini": - gemini_api_key = ( - api_key - or get_secret("GEMINI_API_KEY") - or get_secret("PALM_API_KEY") # older palm api key should also work - or litellm.api_key - ) - - # palm does not support streaming as yet :( - model_response = gemini.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - api_key=gemini_api_key, - logging_obj=logging, - acompletion=acompletion, - custom_prompt_dict=custom_prompt_dict, - ) - if ( - "stream" in optional_params - and optional_params["stream"] == True - and acompletion == False - ): - response = CustomStreamWrapper( - iter(model_response), - model, - custom_llm_provider="gemini", - logging_obj=logging, - ) - return response - response = model_response - elif custom_llm_provider == "vertex_ai": - vertex_ai_project = ( - optional_params.pop("vertex_project", None) - or optional_params.pop("vertex_ai_project", None) - or litellm.vertex_project - or get_secret("VERTEXAI_PROJECT") - ) - vertex_ai_location = ( - optional_params.pop("vertex_location", None) - or optional_params.pop("vertex_ai_location", None) - or litellm.vertex_location - or get_secret("VERTEXAI_LOCATION") - ) - vertex_credentials = ( - optional_params.pop("vertex_credentials", None) - or optional_params.pop("vertex_ai_credentials", None) - or get_secret("VERTEXAI_CREDENTIALS") - ) - new_params = deepcopy(optional_params) - 
if "claude-3" in model: - model_response = vertex_ai_anthropic.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - else: -> model_response = vertex_ai.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - -../main.py:1824: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -messages = [{'content': [{'text': 'Whats in this image?', 'type': 'text'}, {'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}, 'type': 'image_url'}], 'role': 'user'}] -model_response = ModelResponse(id='chatcmpl-722df0e7-4e2d-44e6-9e2c-49823faa0189', choices=[Choices(finish_reason='stop', index=0, mess... role='assistant'))], created=1716145725, model=None, object='chat.completion', system_fingerprint=None, usage=Usage()) -print_verbose = -encoding = -logging_obj = -vertex_project = None, vertex_location = None, vertex_credentials = None -optional_params = {} -litellm_params = {'acompletion': False, 'api_base': '', 'api_key': None, 'completion_call_id': None, ...} -logger_fn = None, acompletion = False - - def completion( - model: str, - messages: list, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - acompletion: bool = False, - ): - try: - import vertexai - except: - raise VertexAIError( - status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", - ) - - if not ( - hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") - ): - raise VertexAIError( - status_code=400, - message="""Upgrade vertex ai. 
Run `pip install "google-cloud-aiplatform>=1.38"`""", - ) - try: - from vertexai.preview.language_models import ( - ChatModel, - CodeChatModel, - InputOutputTextPair, - ) - from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import ( - GenerativeModel, - Part, - GenerationConfig, - ) - from google.cloud import aiplatform # type: ignore - from google.protobuf import json_format # type: ignore - from google.protobuf.struct_pb2 import Value # type: ignore - from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore - import google.auth # type: ignore - import proto # type: ignore - - ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 - print_verbose( - f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" - ) - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - json_obj = json.loads(vertex_credentials) - - creds = google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - else: - creds, _ = google.auth.default(quota_project_id=vertex_project) - print_verbose( - f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" - ) - vertexai.init( - project=vertex_project, location=vertex_location, credentials=creds - ) - - ## Load Config - config = litellm.VertexAIConfig.get_config() - for k, v in config.items(): - if k not in optional_params: - optional_params[k] = v - - ## Process safety settings into format expected by vertex AI - safety_settings = None - if "safety_settings" in optional_params: - safety_settings = optional_params.pop("safety_settings") - if not isinstance(safety_settings, list): - raise ValueError("safety_settings must be a list") - if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): - raise ValueError("safety_settings must be a list of dicts") - safety_settings = [ - gapic_content_types.SafetySetting(x) for x in safety_settings - ] - - # vertexai does not use an API key, it looks for credentials.json in the environment - - prompt = " ".join( - [ - message["content"] - for message in messages - if isinstance(message["content"], str) - ] - ) - - mode = "" - - request_str = "" - response_obj = None - async_client = None - instances = None - client_options = { - "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com" - } - if ( - model in litellm.vertex_language_models - or model in litellm.vertex_vision_models - ): - llm_model = GenerativeModel(model) - mode = "vision" - request_str += f"llm_model = GenerativeModel({model})\n" - elif model in litellm.vertex_chat_models: - llm_model = ChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = ChatModel.from_pretrained({model})\n" - elif model in litellm.vertex_text_models: - llm_model = TextGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_text_models: - llm_model = CodeGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models - llm_model = CodeChatModel.from_pretrained(model) - mode = "chat" - 
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" - elif model == "private": - mode = "private" - model = optional_params.pop("model_id", None) - # private endpoint requires a dict instead of JSON - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - llm_model = aiplatform.PrivateEndpoint( - endpoint_name=model, - project=vertex_project, - location=vertex_location, - ) - request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n" - else: # assume vertex model garden on public endpoint - mode = "custom" - - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - instances = [ - json_format.ParseDict(instance_dict, Value()) - for instance_dict in instances - ] - # Will determine the API used based on async parameter - llm_model = None - - # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now - if acompletion == True: - data = { - "llm_model": llm_model, - "mode": mode, - "prompt": prompt, - "logging_obj": logging_obj, - "request_str": request_str, - "model": model, - "model_response": model_response, - "encoding": encoding, - "messages": messages, - "print_verbose": print_verbose, - "client_options": client_options, - "instances": instances, - "vertex_location": vertex_location, - "vertex_project": vertex_project, - "safety_settings": safety_settings, - **optional_params, - } - if optional_params.get("stream", False) is True: - # async streaming - return async_streaming(**data) - - return async_completion(**data) - - if mode == "vision": - print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - tools = optional_params.pop("tools", None) - content = _gemini_convert_messages_text(messages=messages) - stream = optional_params.pop("stream", False) - if stream == True: - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - model_response = llm_model.generate_content( - contents={"content": content}, - generation_config=optional_params, - safety_settings=safety_settings, - stream=True, - tools=tools, - ) - - return model_response - - request_str += f"response = llm_model.generate_content({content})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - ## LLM Call - response = llm_model.generate_content( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - - if tools is not None and bool( - getattr(response.candidates[0].content.parts[0], "function_call", None) - ): - function_call = response.candidates[0].content.parts[0].function_call - args_dict = {} - - # Check if it's a RepeatedComposite instance - for key, val in function_call.args.items(): - if isinstance( - val, proto.marshal.collections.repeated.RepeatedComposite - ): - # If so, convert to list - args_dict[key] = [v for v in val] - else: - args_dict[key] = val - - try: - args_str = json.dumps(args_dict) - except Exception as e: - raise VertexAIError(status_code=422, message=str(e)) - message = litellm.Message( - content=None, - 
tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "arguments": args_str, - "name": function_call.name, - }, - "type": "function", - } - ], - ) - completion_response = message - else: - completion_response = response.text - response_obj = response._raw_response - optional_params["tools"] = tools - elif mode == "chat": - chat = llm_model.start_chat() - request_str += f"chat = llm_model.start_chat()\n" - - if "stream" in optional_params and optional_params["stream"] == True: - # NOTE: VertexAI does not accept stream=True as a param and raises an error, - # we handle this by removing 'stream' from optional params and sending the request - # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format - optional_params.pop( - "stream", None - ) # vertex ai raises an error when passing stream in optional params - request_str += ( - f"chat.send_message_streaming({prompt}, **{optional_params})\n" - ) - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - model_response = chat.send_message_streaming(prompt, **optional_params) - - return model_response - - request_str += f"chat.send_message({prompt}, **{optional_params}).text\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - completion_response = chat.send_message(prompt, **optional_params).text - elif mode == "text": - if "stream" in optional_params and optional_params["stream"] == True: - optional_params.pop( - "stream", None - ) # See note above on handling streaming for vertex ai - request_str += ( - f"llm_model.predict_streaming({prompt}, **{optional_params})\n" - ) - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - model_response = llm_model.predict_streaming(prompt, **optional_params) - - return model_response - - request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - completion_response = llm_model.predict(prompt, **optional_params).text - elif mode == "custom": - """ - Vertex AI Model Garden - """ - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - llm_model = aiplatform.gapic.PredictionServiceClient( - client_options=client_options - ) - request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n" - endpoint_path = llm_model.endpoint_path( - project=vertex_project, location=vertex_location, endpoint=model - ) - request_str += ( - f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n" - ) - response = llm_model.predict( - endpoint=endpoint_path, instances=instances - ).predictions - - completion_response = response[0] - if ( - isinstance(completion_response, str) - and "\nOutput:\n" in completion_response - ): - completion_response = completion_response.split("\nOutput:\n", 1)[1] - if "stream" in optional_params and optional_params["stream"] == True: - response = TextStreamer(completion_response) 
- return response - elif mode == "private": - """ - Vertex AI Model Garden deployed on private endpoint - """ - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - request_str += f"llm_model.predict(instances={instances})\n" - response = llm_model.predict(instances=instances).predictions - - completion_response = response[0] - if ( - isinstance(completion_response, str) - and "\nOutput:\n" in completion_response - ): - completion_response = completion_response.split("\nOutput:\n", 1)[1] - if "stream" in optional_params and optional_params["stream"] == True: - response = TextStreamer(completion_response) - return response - - ## LOGGING - logging_obj.post_call( - input=prompt, api_key=None, original_response=completion_response - ) - - ## RESPONSE OBJECT - if isinstance(completion_response, litellm.Message): - model_response["choices"][0]["message"] = completion_response - elif len(str(completion_response)) > 0: - model_response["choices"][0]["message"]["content"] = str( - completion_response - ) - model_response["created"] = int(time.time()) - model_response["model"] = model - ## CALCULATING USAGE - if model in litellm.vertex_language_models and response_obj is not None: - model_response["choices"][0].finish_reason = map_finish_reason( - response_obj.candidates[0].finish_reason.name - ) - usage = Usage( - prompt_tokens=response_obj.usage_metadata.prompt_token_count, - completion_tokens=response_obj.usage_metadata.candidates_token_count, - total_tokens=response_obj.usage_metadata.total_token_count, - ) - else: - # init prompt tokens - # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter - prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 - if response_obj is not None: - if hasattr(response_obj, "usage_metadata") and hasattr( - response_obj.usage_metadata, "prompt_token_count" - ): - prompt_tokens = response_obj.usage_metadata.prompt_token_count - completion_tokens = ( - response_obj.usage_metadata.candidates_token_count - ) - else: - prompt_tokens = len(encoding.encode(prompt)) - completion_tokens = len( - encoding.encode( - model_response["choices"][0]["message"].get("content", "") - ) - ) - - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - setattr(model_response, "usage", usage) - return model_response - except Exception as e: - if isinstance(e, VertexAIError): - raise e -> raise VertexAIError(status_code=500, message=str(e)) -E litellm.llms.vertex_ai.VertexAIError: Parameter to MergeFrom() must be instance of same class: expected got . - -../llms/vertex_ai.py:971: VertexAIError - -During handling of the above exception, another exception occurred: - -args = () -kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': -call_type = 'completion', model = 'vertex_ai/gemini-1.5-flash-preview-0514' -k = 'litellm_logging_obj' - - @wraps(original_function) - def wrapper(*args, **kwargs): - # DO NOT MOVE THIS. It always needs to run first - # Check if this is an async function. 
If so only execute the async function - if ( - kwargs.get("acompletion", False) == True - or kwargs.get("aembedding", False) == True - or kwargs.get("aimg_generation", False) == True - or kwargs.get("amoderation", False) == True - or kwargs.get("atext_completion", False) == True - or kwargs.get("atranscription", False) == True - ): - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # MODEL CALL - result = original_function(*args, **kwargs) - if "stream" in kwargs and kwargs["stream"] == True: - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): - chunks = [] - for idx, chunk in enumerate(result): - chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: - return result - - return result - - # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print - print_args_passed_to_litellm(original_function, args, kwargs) - start_time = datetime.datetime.now() - result = None - logging_obj = kwargs.get("litellm_logging_obj", None) - - # only set litellm_call_id if its not in kwargs - call_type = original_function.__name__ - if "litellm_call_id" not in kwargs: - kwargs["litellm_call_id"] = str(uuid.uuid4()) - try: - model = args[0] if len(args) > 0 else kwargs["model"] - except: - model = None - if ( - call_type != CallTypes.image_generation.value - and call_type != CallTypes.text_completion.value - ): - raise ValueError("model param not passed in.") - - try: - if logging_obj is None: - logging_obj, kwargs = function_setup( - original_function.__name__, rules_obj, start_time, *args, **kwargs - ) - kwargs["litellm_logging_obj"] = logging_obj - - # CHECK FOR 'os.environ/' in kwargs - for k, v in kwargs.items(): - if v is not None and isinstance(v, str) and v.startswith("os.environ/"): - kwargs[k] = litellm.get_secret(v) - # [OPTIONAL] CHECK BUDGET - if litellm.max_budget: - if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError( - current_cost=litellm._current_cost, - max_budget=litellm.max_budget, - ) - - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # [OPTIONAL] CHECK CACHE - print_verbose( - f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}" - ) - # if caching is false or cache["no-cache"]==True, don't run this - if ( - ( - ( - ( - kwargs.get("caching", None) is None - and litellm.cache is not None - ) - or kwargs.get("caching", False) == True - ) - and kwargs.get("cache", {}).get("no-cache", False) != True - ) - and kwargs.get("aembedding", False) != True - and kwargs.get("atext_completion", False) != True - and kwargs.get("acompletion", False) != True - and kwargs.get("aimg_generation", False) != 
True - and kwargs.get("atranscription", False) != True - ): # allow users to control returning cached responses from the completion function - # checking cache - print_verbose(f"INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - print_verbose(f"Checking Cache") - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result != None: - if "detail" in cached_result: - # implies an error occurred - pass - else: - call_type = original_function.__name__ - print_verbose( - f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" - ) - if call_type == CallTypes.completion.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - stream=kwargs.get("stream", False), - ) - - if kwargs.get("stream", False) == True: - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - elif call_type == CallTypes.embedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - response_type="embedding", - ) - - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get( - "custom_llm_provider", None - ), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": False, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get( - "stream_response", {} - ), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() - return cached_result - - # CHECK MAX TOKENS - if ( - kwargs.get("max_tokens", None) is not None - and model is not None - and litellm.modify_params - == True # user is okay with params being modified - and ( - call_type == CallTypes.acompletion.value - or call_type == CallTypes.completion.value - ) - ): - try: - base_model = model - if kwargs.get("hf_model_name", None) is not None: - base_model = f"huggingface/{kwargs.get('hf_model_name')}" - max_output_tokens = ( - get_max_tokens(model=base_model) or 4096 - ) # assume min context window is 4k tokens - user_max_tokens = kwargs.get("max_tokens") - ## Scenario 1: User limit + prompt > model limit - messages = None - if len(args) > 1: - messages = args[1] - elif kwargs.get("messages", 
None): - messages = kwargs["messages"] - input_tokens = token_counter(model=base_model, messages=messages) - input_tokens += max( - 0.1 * input_tokens, 10 - ) # give at least a 10 token buffer. token counting can be imprecise. - if input_tokens > max_output_tokens: - pass # allow call to fail normally - elif user_max_tokens + input_tokens > max_output_tokens: - user_max_tokens = max_output_tokens - input_tokens - print_verbose(f"user_max_tokens: {user_max_tokens}") - kwargs["max_tokens"] = int( - round(user_max_tokens) - ) # make sure max tokens is always an int - except Exception as e: - print_verbose(f"Error while checking max token limit: {str(e)}") - # MODEL CALL -> result = original_function(*args, **kwargs) - -../utils.py:3211: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../main.py:2368: in completion - raise exception_type( -../utils.py:9709: in exception_type - raise e -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -original_exception = VertexAIError("Parameter to MergeFrom() must be instance of same class: expected got .") -custom_llm_provider = 'vertex_ai' -completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...} -extra_kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': } - - def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, - extra_kwargs={}, - ): - global user_logger_fn, liteDebuggerClient - exception_mapping_worked = False - if litellm.suppress_debug_info is False: - print() # noqa - print( # noqa - "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa - ) # noqa - print( # noqa - "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." 
# noqa - ) # noqa - print() # noqa - try: - if model: - error_str = str(original_exception) - if isinstance(original_exception, BaseException): - exception_type = type(original_exception).__name__ - else: - exception_type = "" - - ################################################################################ - # Common Extra information needed for all providers - # We pass num retries, api_base, vertex_deployment etc to the exception here - ################################################################################ - extra_information = "" - try: - _api_base = litellm.get_api_base( - model=model, optional_params=extra_kwargs - ) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" - - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) - except: - # DO NOT LET this Block raising the original exception - pass - - ################################################################################ - # End of Common Extra information Needed for all providers - ################################################################################ - - ################################################################################ - #################### Start of Provider Exception mapping #################### - ################################################################################ - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: - exception_mapping_worked = True - raise Timeout( - message=f"APITimeoutError - Request timed out. 
\nerror_str: {error_str}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "custom_openai" - or custom_llm_provider in litellm.openai_compatible_providers - ): - # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) - if message is not None and isinstance(message, str): - message = message.replace("OPENAI", custom_llm_provider.upper()) - message = message.replace("openai", custom_llm_provider) - message = message.replace("OpenAI", custom_llm_provider) - if custom_llm_provider == "openai": - exception_provider = "OpenAI" + "Exception" - else: - exception_provider = ( - custom_llm_provider[0].upper() - + custom_llm_provider[1:] - + "Exception" - ) - - if "This model's maximum context length is" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "model_not_found" in error_str - ): - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "content_policy_violation" in error_str - ): - exception_mapping_worked = True - raise ContentPolicyViolationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "Incorrect API key provided" not in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Request too large" in error_str: - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Mistral API raised a streaming error" in error_str: - exception_mapping_worked = True - _request = httpx.Request( - method="POST", url="https://api.openai.com/v1" - ) - raise APIError( - status_code=500, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=_request, - litellm_debug_info=extra_information, - ) - elif hasattr(original_exception, "status_code"): - exception_mapping_worked = True - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - 
{message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - else: - exception_mapping_worked = True - raise APIError( - status_code=original_exception.status_code, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=original_exception.request, - litellm_debug_info=extra_information, - ) - else: - # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors - raise APIConnectionError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - litellm_debug_info=extra_information, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif custom_llm_provider == "anthropic": # one of the anthropics - if hasattr(original_exception, "message"): - if ( - "prompt is too long" in original_exception.message - or "prompt: length" in original_exception.message - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if "Invalid API Key" in original_exception.message: - exception_mapping_worked = True - raise AuthenticationError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if hasattr(original_exception, "status_code"): - print_verbose(f"status_code: {original_exception.status_code}") - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or 
original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", - llm_provider="anthropic", - model=model, - request=original_exception.request, - ) - elif custom_llm_provider == "replicate": - if "Incorrect authentication token" in error_str: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif "input is too long" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif exception_type == "ModelError": - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif "Request was throttled" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 422 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"ReplicateException - {str(original_exception)}", - 
llm_provider="replicate", - model=model, - request=httpx.Request( - method="POST", - url="https://api.replicate.com/v1/deployments", - ), - ) - elif custom_llm_provider == "watsonx": - if "token_quota_reached" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"WatsonxException: Rate Limit Errror - {error_str}", - llm_provider="watsonx", - model=model, - response=original_exception.response, - ) - elif custom_llm_provider == "predibase": - if "authorization denied for" in error_str: - exception_mapping_worked = True - - # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception - if ( - error_str is not None - and isinstance(error_str, str) - and "bearer" in error_str.lower() - ): - # only keep the first 10 chars after the occurnence of "bearer" - _bearer_token_start_index = error_str.lower().find("bearer") - error_str = error_str[: _bearer_token_start_index + 14] - error_str += "XXXXXXX" + '"' - - raise AuthenticationError( - message=f"PredibaseException: Authentication Error - {error_str}", - llm_provider="predibase", - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "bedrock": - if ( - "too many tokens" in error_str - or "expected maxLength:" in error_str - or "Input is too long" in error_str - or "prompt: length: 1.." in error_str - or "Too many input tokens" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"BedrockException: Context Window Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "Malformed input request" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Unable to locate credentials" in error_str - or "The security token included in the request is invalid" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "AccessDeniedException" in error_str: - exception_mapping_worked = True - raise PermissionDeniedError( - message=f"BedrockException PermissionDeniedError - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "throttlingException" in error_str - or "ThrottlingException" in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Connect timeout on endpoint URL" in error_str - or "timed out" in error_str - ): - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException: Timeout Error - {error_str}", - model=model, - llm_provider="bedrock", - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=httpx.Response( - status_code=500, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ), - ) - elif original_exception.status_code == 401: - exception_mapping_worked = 
True - raise AuthenticationError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "sagemaker": - if "Unable to locate credentials" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "Input validation error: `best_of` must be > 0 and <= 2" - in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "`inputs` tokens + `max_new_tokens` must be <=" in error_str - or "instance type with more CPU capacity or memory" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif custom_llm_provider == "vertex_ai": - if ( - "Vertex AI API has not been used in project" in error_str - or "Unable to find your project" in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "None Unknown Error." 
in error_str - or "Content has no parts." in error_str - ): - exception_mapping_worked = True - raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - request=original_exception.request, - litellm_debug_info=extra_information, - ) - elif "403" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "The response was blocked." in error_str: - exception_mapping_worked = True - raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - elif ( - "429 Quota exceeded" in error_str - or "IndexError: list index out of range" in error_str - or "429 Unable to submit request because the service is temporarily out of capacity." - in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - if hasattr(original_exception, "status_code"): - if original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=original_exception.response, - ) - if original_exception.status_code == 500: - exception_mapping_worked = True -> raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - request=original_exception.request, -E litellm.exceptions.APIError: VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . 
- -../utils.py:8922: APIError - -During handling of the above exception, another exception occurred: - - def test_gemini_pro_vision(): - try: - load_vertex_ai_credentials() - litellm.set_verbose = True - litellm.num_retries = 3 -> resp = litellm.completion( - model="vertex_ai/gemini-1.5-flash-preview-0514", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "Whats in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" - }, - }, - ], - } - ], - ) - -test_amazing_vertex_completion.py:510: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../utils.py:3289: in wrapper - return litellm.completion_with_retries(*args, **kwargs) -../main.py:2401: in completion_with_retries - return retryer(original_function, *args, **kwargs) -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:379: in __call__ - do = self.iter(retry_state=retry_state) -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:325: in iter - raise retry_exc.reraise() -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:158: in reraise - raise self.last_attempt.result() -/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/concurrent/futures/_base.py:449: in result - return self.__get_result() -/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/concurrent/futures/_base.py:401: in __get_result - raise self._exception -../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:382: in __call__ - result = fn(*args, **kwargs) -../utils.py:3317: in wrapper - raise e -../utils.py:3211: in wrapper - result = original_function(*args, **kwargs) -../main.py:2368: in completion - raise exception_type( -../utils.py:9709: in exception_type - raise e -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -original_exception = VertexAIError("Parameter to MergeFrom() must be instance of same class: expected got .") -custom_llm_provider = 'vertex_ai' -completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...} -extra_kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': } - - def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, - extra_kwargs={}, - ): - global user_logger_fn, liteDebuggerClient - exception_mapping_worked = False - if litellm.suppress_debug_info is False: - print() # noqa - print( # noqa - "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa - ) # noqa - print( # noqa - "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." 
# noqa - ) # noqa - print() # noqa - try: - if model: - error_str = str(original_exception) - if isinstance(original_exception, BaseException): - exception_type = type(original_exception).__name__ - else: - exception_type = "" - - ################################################################################ - # Common Extra information needed for all providers - # We pass num retries, api_base, vertex_deployment etc to the exception here - ################################################################################ - extra_information = "" - try: - _api_base = litellm.get_api_base( - model=model, optional_params=extra_kwargs - ) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" - - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) - except: - # DO NOT LET this Block raising the original exception - pass - - ################################################################################ - # End of Common Extra information Needed for all providers - ################################################################################ - - ################################################################################ - #################### Start of Provider Exception mapping #################### - ################################################################################ - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: - exception_mapping_worked = True - raise Timeout( - message=f"APITimeoutError - Request timed out. 
\nerror_str: {error_str}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "custom_openai" - or custom_llm_provider in litellm.openai_compatible_providers - ): - # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) - if message is not None and isinstance(message, str): - message = message.replace("OPENAI", custom_llm_provider.upper()) - message = message.replace("openai", custom_llm_provider) - message = message.replace("OpenAI", custom_llm_provider) - if custom_llm_provider == "openai": - exception_provider = "OpenAI" + "Exception" - else: - exception_provider = ( - custom_llm_provider[0].upper() - + custom_llm_provider[1:] - + "Exception" - ) - - if "This model's maximum context length is" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "model_not_found" in error_str - ): - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "content_policy_violation" in error_str - ): - exception_mapping_worked = True - raise ContentPolicyViolationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "Incorrect API key provided" not in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Request too large" in error_str: - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Mistral API raised a streaming error" in error_str: - exception_mapping_worked = True - _request = httpx.Request( - method="POST", url="https://api.openai.com/v1" - ) - raise APIError( - status_code=500, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=_request, - litellm_debug_info=extra_information, - ) - elif hasattr(original_exception, "status_code"): - exception_mapping_worked = True - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - 
{message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - else: - exception_mapping_worked = True - raise APIError( - status_code=original_exception.status_code, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=original_exception.request, - litellm_debug_info=extra_information, - ) - else: - # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors - raise APIConnectionError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - litellm_debug_info=extra_information, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif custom_llm_provider == "anthropic": # one of the anthropics - if hasattr(original_exception, "message"): - if ( - "prompt is too long" in original_exception.message - or "prompt: length" in original_exception.message - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if "Invalid API Key" in original_exception.message: - exception_mapping_worked = True - raise AuthenticationError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if hasattr(original_exception, "status_code"): - print_verbose(f"status_code: {original_exception.status_code}") - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or 
original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", - llm_provider="anthropic", - model=model, - request=original_exception.request, - ) - elif custom_llm_provider == "replicate": - if "Incorrect authentication token" in error_str: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif "input is too long" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif exception_type == "ModelError": - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif "Request was throttled" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 422 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"ReplicateException - {str(original_exception)}", - 
llm_provider="replicate", - model=model, - request=httpx.Request( - method="POST", - url="https://api.replicate.com/v1/deployments", - ), - ) - elif custom_llm_provider == "watsonx": - if "token_quota_reached" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"WatsonxException: Rate Limit Errror - {error_str}", - llm_provider="watsonx", - model=model, - response=original_exception.response, - ) - elif custom_llm_provider == "predibase": - if "authorization denied for" in error_str: - exception_mapping_worked = True - - # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception - if ( - error_str is not None - and isinstance(error_str, str) - and "bearer" in error_str.lower() - ): - # only keep the first 10 chars after the occurnence of "bearer" - _bearer_token_start_index = error_str.lower().find("bearer") - error_str = error_str[: _bearer_token_start_index + 14] - error_str += "XXXXXXX" + '"' - - raise AuthenticationError( - message=f"PredibaseException: Authentication Error - {error_str}", - llm_provider="predibase", - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "bedrock": - if ( - "too many tokens" in error_str - or "expected maxLength:" in error_str - or "Input is too long" in error_str - or "prompt: length: 1.." in error_str - or "Too many input tokens" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"BedrockException: Context Window Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "Malformed input request" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Unable to locate credentials" in error_str - or "The security token included in the request is invalid" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "AccessDeniedException" in error_str: - exception_mapping_worked = True - raise PermissionDeniedError( - message=f"BedrockException PermissionDeniedError - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "throttlingException" in error_str - or "ThrottlingException" in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Connect timeout on endpoint URL" in error_str - or "timed out" in error_str - ): - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException: Timeout Error - {error_str}", - model=model, - llm_provider="bedrock", - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=httpx.Response( - status_code=500, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ), - ) - elif original_exception.status_code == 401: - exception_mapping_worked = 
True - raise AuthenticationError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "sagemaker": - if "Unable to locate credentials" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "Input validation error: `best_of` must be > 0 and <= 2" - in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "`inputs` tokens + `max_new_tokens` must be <=" in error_str - or "instance type with more CPU capacity or memory" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif custom_llm_provider == "vertex_ai": - if ( - "Vertex AI API has not been used in project" in error_str - or "Unable to find your project" in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "None Unknown Error." 
in error_str - or "Content has no parts." in error_str - ): - exception_mapping_worked = True - raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - request=original_exception.request, - litellm_debug_info=extra_information, - ) - elif "403" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "The response was blocked." in error_str: - exception_mapping_worked = True - raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - elif ( - "429 Quota exceeded" in error_str - or "IndexError: list index out of range" in error_str - or "429 Unable to submit request because the service is temporarily out of capacity." - in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - if hasattr(original_exception, "status_code"): - if original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=original_exception.response, - ) - if original_exception.status_code == 500: - exception_mapping_worked = True -> raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - request=original_exception.request, -E litellm.exceptions.APIError: VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . - -../utils.py:8922: APIError - -During handling of the above exception, another exception occurred: - - def test_gemini_pro_vision(): - try: - load_vertex_ai_credentials() - litellm.set_verbose = True - litellm.num_retries = 3 - resp = litellm.completion( - model="vertex_ai/gemini-1.5-flash-preview-0514", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "Whats in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" - }, - }, - ], - } - ], - ) - print(resp) - - prompt_tokens = resp.usage.prompt_tokens - - # DO Not DELETE this ASSERT - # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response - assert prompt_tokens == 263 # the gemini api returns 263 to us - except litellm.RateLimitError as e: - pass - except Exception as e: - if "500 Internal error encountered.'" in str(e): - pass - else: -> pytest.fail(f"An exception occurred - {str(e)}") -E Failed: An exception occurred - VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . 
- -test_amazing_vertex_completion.py:540: Failed ----------------------------- Captured stdout setup ----------------------------- - ------------------------------ Captured stdout call ----------------------------- -loading vertex ai credentials -Read vertexai file path - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}]) - - -self.optional_params: {} -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. - -Logging Details: logger_fn - None | callable(logger_fn) - False - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}], litellm_call_id='7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', litellm_logging_obj=) - - -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. 
- -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}], litellm_call_id='7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', litellm_logging_obj=) - - -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. - -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}], litellm_call_id='7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', litellm_logging_obj=) - - -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. 
- -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] -=============================== warnings summary =============================== -../proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings - /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) - -../proxy/_types.py:255 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:255: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:342 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:342: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - extra = Extra.allow # Allow extra fields - -../proxy/_types.py:345 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:345: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:374 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:374: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:421 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:421: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:490 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:490: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. 
See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:510 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:510: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:523 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:523: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:568 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:568: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:605 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:605: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:923 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:923: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:950 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:950: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:971 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:971: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../utils.py:60 - /Users/krrishdholakia/Documents/litellm/litellm/utils.py:60: DeprecationWarning: open_text is deprecated. Use files() instead. 
Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice. - with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f: - --- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -=========================== short test summary info ============================ -FAILED test_amazing_vertex_completion.py::test_gemini_pro_vision - Failed: An... -======================== 1 failed, 39 warnings in 2.09s ======================== diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 1113adc40c..a5e098b027 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1288,14 +1288,14 @@ async def test_completion_replicate_llama3_streaming(sync_mode): @pytest.mark.parametrize( "model", [ - # "bedrock/cohere.command-r-plus-v1:0", - # "anthropic.claude-3-sonnet-20240229-v1:0", - # "anthropic.claude-instant-v1", - # "bedrock/ai21.j2-mid", - # "mistral.mistral-7b-instruct-v0:2", - # "bedrock/amazon.titan-tg1-large", - # "meta.llama3-8b-instruct-v1:0", - "cohere.command-text-v14" + "bedrock/cohere.command-r-plus-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-instant-v1", + "bedrock/ai21.j2-mid", + "mistral.mistral-7b-instruct-v0:2", + "bedrock/amazon.titan-tg1-large", + "meta.llama3-8b-instruct-v1:0", + "cohere.command-text-v14", ], ) @pytest.mark.asyncio @@ -1324,8 +1324,6 @@ async def test_bedrock_httpx_streaming(sync_mode, model): raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") - - assert False else: response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore model=model, diff --git a/litellm/utils.py b/litellm/utils.py index 75dd853286..728173f383 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5620,13 +5620,80 @@ def get_optional_params( supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider ) - _check_valid_arg(supported_params=supported_params) - optional_params = litellm.AmazonConverseConfig().map_openai_params( - model=model, - non_default_params=non_default_params, - optional_params=optional_params, - drop_params=drop_params, - ) + if "ai21" in model: + _check_valid_arg(supported_params=supported_params) + # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[], + # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra + if max_tokens is not None: + optional_params["maxTokens"] = max_tokens + if temperature is not None: + optional_params["temperature"] = temperature + if top_p is not None: + optional_params["topP"] = top_p + if stream: + optional_params["stream"] = stream + elif "anthropic" in model: + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.AmazonConverseConfig().map_openai_params( + model=model, + non_default_params=non_default_params, + optional_params=optional_params, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) + elif "amazon" in model: # amazon titan llms + _check_valid_arg(supported_params=supported_params) + # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large + if max_tokens is not None: + optional_params["maxTokenCount"] = max_tokens + if temperature is not None: + optional_params["temperature"] = temperature + if stop is not None: + 
                filtered_stop = _map_and_modify_arg(
+                    {"stop": stop}, provider="bedrock", model=model
+                )
+                optional_params["stopSequences"] = filtered_stop["stop"]
+            if top_p is not None:
+                optional_params["topP"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "meta" in model:  # amazon / meta llms
+            _check_valid_arg(supported_params=supported_params)
+            # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large
+            if max_tokens is not None:
+                optional_params["max_gen_len"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["top_p"] = top_p
+            if stream:
+                optional_params["stream"] = stream
+        elif "cohere" in model:  # cohere models on bedrock
+            _check_valid_arg(supported_params=supported_params)
+            # handle cohere params
+            if stream:
+                optional_params["stream"] = stream
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if max_tokens is not None:
+                optional_params["max_tokens"] = max_tokens
+        elif "mistral" in model:
+            _check_valid_arg(supported_params=supported_params)
+            # mistral params on bedrock
+            # \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}"
+            if max_tokens is not None:
+                optional_params["max_tokens"] = max_tokens
+            if temperature is not None:
+                optional_params["temperature"] = temperature
+            if top_p is not None:
+                optional_params["top_p"] = top_p
+            if stop is not None:
+                optional_params["stop"] = stop
+            if stream is not None:
+                optional_params["stream"] = stream
     elif custom_llm_provider == "aleph_alpha":
         supported_params = [
             "max_tokens",

From 8370f81aa6ec5fd995aec61504270947970ff39c Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 7 Jun 2024 08:49:47 -0700
Subject: [PATCH 24/52] feat - use safe_deep_copy

---
 litellm/proxy/utils.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index e5efb93d05..c3d7313611 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -64,6 +64,14 @@ def print_verbose(print_statement):
         print(f"LiteLLM Proxy: {print_statement}")  # noqa
 
 
+def safe_deep_copy(data):
+    if isinstance(data, dict):
+        # remove litellm_parent_otel_span since this is not picklable
+        data.pop("litellm_parent_otel_span", None)
+    new_data = copy.deepcopy(data)
+    return new_data
+
+
 def log_to_opentelemetry(func):
     @wraps(func)
     async def wrapper(*args, **kwargs):
@@ -312,7 +320,7 @@ class ProxyLogging:
         """
         Runs the CustomLogger's async_moderation_hook()
         """
-        new_data = copy.deepcopy(data)
+        new_data = safe_deep_copy(data)
         for callback in litellm.callbacks:
             try:
                 if isinstance(callback, CustomLogger):

From 0f99d47d87efd4034928e22d7dcb7deff82e93a2 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 7 Jun 2024 08:54:28 -0700
Subject: [PATCH 25/52] use litellm_parent_otel_span as litellm_param

---
 litellm/main.py  | 3 +++
 litellm/utils.py | 6 ++++--
 2 files changed, 7 insertions(+), 2 deletions(-)

diff --git a/litellm/main.py b/litellm/main.py
index f76d6c5213..596f85f334 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -600,6 +600,7 @@ def completion(
     client = kwargs.get("client", None)
     ### Admin Controls ###
     no_log = kwargs.get("no-log", False)
+    litellm_parent_otel_span = kwargs.get("litellm_parent_otel_span", None)
     ######## end of unpacking kwargs ###########
     openai_params = [
         "functions",
@@ -689,6 +690,7 @@ def completion(
         "allowed_model_region",
         "model_config",
         "fastest_response",
+        "litellm_parent_otel_span",
     ]
 
     default_params = openai_params + litellm_params
@@ -873,6 +875,7 @@ def completion(
             input_cost_per_token=input_cost_per_token,
             output_cost_per_second=output_cost_per_second,
             output_cost_per_token=output_cost_per_token,
+            litellm_parent_otel_span=litellm_parent_otel_span,
         )
         logging.update_environment_variables(
             model=model,
diff --git a/litellm/utils.py b/litellm/utils.py
index ba6a374674..be7728dfef 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -4918,6 +4918,7 @@ def get_litellm_params(
     input_cost_per_token=None,
     output_cost_per_token=None,
     output_cost_per_second=None,
+    litellm_parent_otel_span=None,
 ):
     litellm_params = {
         "acompletion": acompletion,
@@ -4940,6 +4941,7 @@ def get_litellm_params(
         "input_cost_per_second": input_cost_per_second,
         "output_cost_per_token": output_cost_per_token,
         "output_cost_per_second": output_cost_per_second,
+        "litellm_parent_otel_span": litellm_parent_otel_span,
     }
     return litellm_params
 
@@ -7351,10 +7353,10 @@ def get_provider_fields(custom_llm_provider: str) -> List[ProviderField]:
 
     if custom_llm_provider == "databricks":
         return litellm.DatabricksConfig().get_required_params()
-    
+
     elif custom_llm_provider == "ollama":
         return litellm.OllamaConfig().get_required_params()
-    
+
     else:
         return []
 

From 98ed4533f7bd0930b8e723dea596fa8675c16ddb Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 7 Jun 2024 08:55:35 -0700
Subject: [PATCH 26/52] fix - otel _get_span_context

---
 litellm/integrations/opentelemetry.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py
index 5c6312c05f..97d57a55a0 100644
--- a/litellm/integrations/opentelemetry.py
+++ b/litellm/integrations/opentelemetry.py
@@ -156,8 +156,7 @@ class OpenTelemetry(CustomLogger):
         proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
         headers = proxy_server_request.get("headers", {}) or {}
         traceparent = headers.get("traceparent", None)
-        _metadata = litellm_params.get("metadata", {}) or {}
-        parent_otel_span = _metadata.get("litellm_parent_otel_span", None)
+        parent_otel_span = litellm_params.get("litellm_parent_otel_span", None)
 
         """
         Two way to use parents in opentelemetry

From 778eed31151f0305515aaf767cd3f98061ab7107 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 7 Jun 2024 09:01:07 -0700
Subject: [PATCH 27/52] fix(bedrock_httpx.py): fix linting errors

---
 .pre-commit-config.yaml       | 16 ++++++++--------
 litellm/llms/bedrock_httpx.py | 12 ++++++------
 2 files changed, 14 insertions(+), 14 deletions(-)

diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index e8bb1ff66a..cc41d85f14 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,11 +16,11 @@ repos:
       name: Check if files match
       entry: python3 ci_cd/check_files_match.py
       language: system
-# - repo: local
-#   hooks:
-#     - id: mypy
-#       name: mypy
-#       entry: python3 -m mypy --ignore-missing-imports
-#       language: system
-#       types: [python]
-#       files: ^litellm/
\ No newline at end of file
+- repo: local
+  hooks:
+    - id: mypy
+      name: mypy
+      entry: python3 -m mypy --ignore-missing-imports
+      language: system
+      types: [python]
+      files: ^litellm/
\ No newline at end of file
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 7aba78d7ce..9e44f16cfc 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1879,7 +1879,7 @@ class AWSEventStreamDecoder:
             elif "stopReason" in chunk_data:
                 finish_reason = 
map_finish_reason(chunk_data.get("stopReason", "stop")) elif "usage" in chunk_data: - usage = ConverseTokenUsageBlock(**chunk_data["usage"]) + usage = ConverseTokenUsageBlock(**chunk_data["usage"]) # type: ignore response = GenericStreamingChunk( text=text, tool_str=tool_str, @@ -1929,11 +1929,11 @@ class AWSEventStreamDecoder: is_finished = True finish_reason = chunk_data["completionReason"] return GenericStreamingChunk( - **{ - "text": text, - "is_finished": is_finished, - "finish_reason": finish_reason, - } + text=text, + is_finished=is_finished, + finish_reason=finish_reason, + tool_str="", + usage=None, ) def iter_bytes(self, iterator: Iterator[bytes]) -> Iterator[GenericStreamingChunk]: From 1577d5a79e3b5415e2aec4c2f82766dac93fdcfa Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 7 Jun 2024 09:13:38 -0700 Subject: [PATCH 28/52] feat(test_completion.py): deduplicate kwarg --- litellm/tests/test_completion.py | 1 - 1 file changed, 1 deletion(-) diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index f3e62338c7..f639653f46 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -2546,7 +2546,6 @@ def test_replicate_custom_prompt_dict(): } ], mock_response="Hello world", - mock_response="hello world", repetition_penalty=0.1, num_retries=3, ) From 74bf9e99725f5f06629cac9a0f99a099d52602ac Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 09:25:07 -0700 Subject: [PATCH 29/52] fix - include litellm_parent_otel_span --- litellm/proxy/proxy_server.py | 29 ++++++++++++++++++++++++----- 1 file changed, 24 insertions(+), 5 deletions(-) diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index d21553c2d9..f96d787883 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -549,7 +549,10 @@ async def user_api_key_auth( litellm_proxy_roles=jwt_handler.litellm_jwtauth, ) if is_allowed: - return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) else: allowed_routes = ( jwt_handler.litellm_jwtauth.admin_allowed_routes @@ -666,7 +669,6 @@ async def user_api_key_auth( user_info=user_info, ) ) - # get the request body request_data = await _read_request_body(request=request) @@ -695,15 +697,21 @@ async def user_api_key_auth( user_role=LitellmUserRoles.INTERNAL_USER, user_id=user_id, org_id=org_id, + parent_otel_span=parent_otel_span, ) #### ELSE #### if master_key is None: if isinstance(api_key, str): return UserAPIKeyAuth( - api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN + api_key=api_key, + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, ) else: - return UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN) + return UserAPIKeyAuth( + user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, + ) elif api_key is None: # only require api key if master key is set raise Exception("No api key passed in.") elif api_key == "": @@ -781,6 +789,7 @@ async def user_api_key_auth( valid_token.allowed_model_region = end_user_params.get( "allowed_model_region" ) + valid_token.parent_otel_span = parent_otel_span return valid_token @@ -807,6 +816,7 @@ async def user_api_key_auth( api_key=master_key, user_role=LitellmUserRoles.PROXY_ADMIN, user_id=litellm_proxy_admin_name, + parent_otel_span=parent_otel_span, **end_user_params, ) await user_api_key_cache.async_set_cache( @@ -1452,6 +1462,7 @@ async def user_api_key_auth( 
return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, **valid_token_dict, ) elif ( @@ -1459,7 +1470,10 @@ async def user_api_key_auth( and route in LiteLLMRoutes.sso_only_routes.value ): return UserAPIKeyAuth( - api_key=api_key, user_role="app_owner", **valid_token_dict + api_key=api_key, + user_role="app_owner", + parent_otel_span=parent_otel_span, + **valid_token_dict, ) else: raise Exception( @@ -1475,18 +1489,21 @@ async def user_api_key_auth( return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.PROXY_ADMIN, + parent_otel_span=parent_otel_span, **valid_token_dict, ) elif _has_user_setup_sso() and route in LiteLLMRoutes.sso_only_routes.value: return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.INTERNAL_USER, + parent_otel_span=parent_otel_span, **valid_token_dict, ) else: return UserAPIKeyAuth( api_key=api_key, user_role=LitellmUserRoles.INTERNAL_USER, + parent_otel_span=parent_otel_span, **valid_token_dict, ) else: @@ -4193,6 +4210,8 @@ async def chat_completion( ) # do not store the original `sk-..` api key in the db data["metadata"]["headers"] = _headers data["metadata"]["endpoint"] = str(request.url) + # Add the OTEL Parent Trace before sending it LiteLLM + data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span ### TEAM-SPECIFIC PARAMS ### if user_api_key_dict.team_id is not None: From a0f5b61bbc960fc90e93226dec2048205e6b23c2 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 09:38:37 -0700 Subject: [PATCH 30/52] fix import Span --- litellm/proxy/_types.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index e089c0429d..05efc7495d 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -1,13 +1,20 @@ from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict from dataclasses import fields import enum -from typing import Optional, List, Union, Dict, Literal, Any +from typing import Optional, List, Union, Dict, Literal, Any, TYPE_CHECKING from datetime import datetime import uuid, json, sys, os from litellm.types.router import UpdateRouterConfig from litellm.types.utils import ProviderField from typing_extensions import Annotated -from opentelemetry.trace import Span + + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any class LitellmUserRoles(str, enum.Enum): From ca99456e04f31208978204d3c2fd3cb412ba3e12 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 09:47:13 -0700 Subject: [PATCH 31/52] fix import OTEL span --- litellm/_service_logger.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/litellm/_service_logger.py b/litellm/_service_logger.py index dcc2fc1dd3..fd14b3cdeb 100644 --- a/litellm/_service_logger.py +++ b/litellm/_service_logger.py @@ -1,12 +1,18 @@ -import litellm, traceback +from datetime import datetime +import litellm from litellm.proxy._types import UserAPIKeyAuth from .types.services import ServiceTypes, ServiceLoggerPayload from .integrations.prometheus_services import PrometheusServicesLogger from .integrations.custom_logger import CustomLogger from datetime import timedelta -from typing import Union, Optional -from opentelemetry.trace import Span -from datetime import datetime +from typing import Union, Optional, TYPE_CHECKING, Any + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span 
= Any class ServiceLogging(CustomLogger): From 4828e2426fda4ced5f8e482a754c369ac74dd12c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 09:55:59 -0700 Subject: [PATCH 32/52] fix importing Span --- litellm/integrations/opentelemetry.py | 14 ++++++++++---- litellm/proxy/auth/auth_checks.py | 10 ++++++++-- litellm/proxy/proxy_server.py | 10 ++++++++-- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 97d57a55a0..3fc50848e8 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -1,12 +1,18 @@ import os -from typing import Optional from dataclasses import dataclass +from datetime import datetime from litellm.integrations.custom_logger import CustomLogger from litellm._logging import verbose_logger -from litellm.types.services import ServiceLoggerPayload, ServiceTypes -from opentelemetry.trace import Span -from datetime import datetime +from litellm.types.services import ServiceLoggerPayload +from typing import Union, Optional, TYPE_CHECKING, Any + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any LITELLM_TRACER_NAME = "litellm" LITELLM_RESOURCE = {"service.name": "litellm"} diff --git a/litellm/proxy/auth/auth_checks.py b/litellm/proxy/auth/auth_checks.py index 10037e60fd..b9813afc91 100644 --- a/litellm/proxy/auth/auth_checks.py +++ b/litellm/proxy/auth/auth_checks.py @@ -17,14 +17,20 @@ from litellm.proxy._types import ( LiteLLM_OrganizationTable, LitellmUserRoles, ) -from typing import Optional, Literal, Union +from typing import Optional, Literal, TYPE_CHECKING, Any from litellm.proxy.utils import PrismaClient, ProxyLogging, log_to_opentelemetry from litellm.caching import DualCache import litellm -from opentelemetry.trace import Span from litellm.types.services import ServiceLoggerPayload, ServiceTypes from datetime import datetime +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any + all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index f96d787883..25e46269e1 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -2,12 +2,18 @@ import sys, os, platform, time, copy, re, asyncio, inspect import threading, ast import shutil, random, traceback, requests from datetime import datetime, timedelta, timezone -from typing import Optional, List, Callable, get_args, Set +from typing import Optional, List, Callable, get_args, Set, Any, TYPE_CHECKING import secrets, subprocess import hashlib, uuid import warnings import importlib -from opentelemetry.trace import Span + +if TYPE_CHECKING: + from opentelemetry.trace import Span as _Span + + Span = _Span +else: + Span = Any def showwarning(message, category, filename, lineno, file=None, line=None): From 141cea5eb60296b170dc99c89382df10d4e9d0af Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 10:00:33 -0700 Subject: [PATCH 33/52] fix import Span --- litellm/proxy/utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index c3d7313611..8ae811f112 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -1,4 +1,4 @@ -from typing import Optional, List, Any, Literal, Union +from typing import Optional, List, Any, Literal, Union, TYPE_CHECKING import os, 
subprocess, hashlib, importlib, asyncio, copy, json, aiohttp, httpx, time
 import litellm, backoff, traceback
 from litellm.proxy._types import (
@@ -47,9 +47,15 @@ from email.mime.multipart import MIMEMultipart
 from datetime import datetime, timedelta
 from litellm.integrations.slack_alerting import SlackAlerting
 from typing_extensions import overload
-from opentelemetry.trace import Span
 from functools import wraps
 
+if TYPE_CHECKING:
+    from opentelemetry.trace import Span as _Span
+
+    Span = _Span
+else:
+    Span = Any
+
 
 def print_verbose(print_statement):
     """

From 672dcf0c6fe57f17b65909bba3ed144b75ca24bb Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 7 Jun 2024 10:04:03 -0700
Subject: [PATCH 34/52] fix(factory.py): handle bedrock claude image urls

---
 .gitignore | 1 +
 .pre-commit-config.yaml | 16 +++++-----
 litellm/llms/bedrock_httpx.py | 1 +
 litellm/llms/prompt_templates/factory.py | 16 +++++-----
 litellm/tests/test_bedrock_completion.py | 16 +++++---
 litellm/types/llms/bedrock.py | 2 +-
 litellm/utils.py | 38 +++++++++++++++++-------
 7 files changed, 58 insertions(+), 32 deletions(-)

diff --git a/.gitignore b/.gitignore
index 8d99ae8af8..69061d62d3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -59,3 +59,4 @@ myenv/*
 litellm/proxy/_experimental/out/404/index.html
 litellm/proxy/_experimental/out/model_hub/index.html
 litellm/proxy/_experimental/out/onboarding/index.html
+litellm/tests/log.txt
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index cc41d85f14..e8bb1ff66a 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -16,11 +16,11 @@ repos:
       name: Check if files match
      entry: python3 ci_cd/check_files_match.py
       language: system
-- repo: local
-  hooks:
-    - id: mypy
-      name: mypy
-      entry: python3 -m mypy --ignore-missing-imports
-      language: system
-      types: [python]
-      files: ^litellm/
\ No newline at end of file
+# - repo: local
+#   hooks:
+#     - id: mypy
+#       name: mypy
+#       entry: python3 -m mypy --ignore-missing-imports
+#       language: system
+#       types: [python]
+#       files: ^litellm/
\ No newline at end of file
diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py
index 9e44f16cfc..59945a5857 100644
--- a/litellm/llms/bedrock_httpx.py
+++ b/litellm/llms/bedrock_httpx.py
@@ -1158,6 +1158,7 @@ class AmazonConverseConfig:
             "stop",
             "temperature",
             "top_p",
+            "extra_headers",
         ]
 
         if (
diff --git a/litellm/llms/prompt_templates/factory.py b/litellm/llms/prompt_templates/factory.py
index e5c8a79958..6bf03b52d4 100644
--- a/litellm/llms/prompt_templates/factory.py
+++ b/litellm/llms/prompt_templates/factory.py
@@ -1621,7 +1621,7 @@ from litellm.types.llms.bedrock import (
 )
 
 
-def get_image_details(image_url) -> Tuple[bytes, str]:
+def get_image_details(image_url) -> Tuple[str, str]:
     try:
         import base64
 
@@ -1637,7 +1637,7 @@ def get_image_details(image_url) -> Tuple[bytes, str]:
         )
 
         # Convert the image content to base64 bytes
-        base64_bytes = base64.b64encode(response.content)
+        base64_bytes = base64.b64encode(response.content).decode("utf-8")
 
         # Get mime-type
         mime_type = content_type.split("/")[
@@ -1659,18 +1659,17 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock:
         # base 64 is passed as data:image/jpeg;base64,
         image_metadata, img_without_base_64 = image_url.split(",")
-        image_format = image_metadata.split("/")[1]
 
         # read mime_type from img_without_base_64=data:image/jpeg;base64
         # Extract MIME type using regular expression
         mime_type_match = re.match(r"data:(.*?);base64", image_metadata)
-
         if mime_type_match:
             mime_type = 
mime_type_match.group(1) + image_format = mime_type.split("/")[1] else: - mime_type = "jpeg" - decoded_img = base64.b64decode(img_without_base_64) - _blob = BedrockImageSourceBlock(bytes=decoded_img) + mime_type = "image/jpeg" + image_format = "jpeg" + _blob = BedrockImageSourceBlock(bytes=img_without_base_64) supported_image_formats = ( litellm.AmazonConverseConfig().get_supported_image_types() ) @@ -1701,7 +1700,8 @@ def _process_bedrock_converse_image_block(image_url: str) -> BedrockImageBlock: ) else: raise ValueError( - "Unsupported image type. Expected either image url or base64 encoded string" + "Unsupported image type. Expected either image url or base64 encoded string - \ + e.g. 'data:image/jpeg;base64,'" ) diff --git a/litellm/tests/test_bedrock_completion.py b/litellm/tests/test_bedrock_completion.py index 047f0cb2e2..64e7741e2a 100644 --- a/litellm/tests/test_bedrock_completion.py +++ b/litellm/tests/test_bedrock_completion.py @@ -243,6 +243,7 @@ def test_completion_bedrock_claude_sts_oidc_auth(): except Exception as e: pytest.fail(f"Error occurred: {e}") + @pytest.mark.skipif( os.environ.get("CIRCLE_OIDC_TOKEN_V2") is None, reason="Cannot run without being in CircleCI Runner", @@ -277,7 +278,15 @@ def test_completion_bedrock_httpx_command_r_sts_oidc_auth(): except Exception as e: pytest.fail(f"Error occurred: {e}") -def test_bedrock_claude_3(): + +@pytest.mark.parametrize( + "image_url", + [ + "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVvfxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8y
V1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC", + "https://avatars.githubusercontent.com/u/29436595?v=", + ], +) +def test_bedrock_claude_3(image_url): try: litellm.set_verbose = True data = { @@ -294,7 +303,7 @@ def test_bedrock_claude_3(): { "image_url": { "detail": "high", - "url": "data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAL0AAAC9CAMAAADRCYwCAAAAh1BMVEX///8AAAD8/Pz5+fkEBAT39/cJCQn09PRNTU3y8vIMDAwzMzPe3t7v7+8QEBCOjo7FxcXR0dHn5+elpaWGhoYYGBivr686OjocHBy0tLQtLS1TU1PY2Ni6urpaWlpERER3d3ecnJxoaGiUlJRiYmIlJSU4ODhBQUFycnKAgIDBwcFnZ2chISE7EjuwAAAI/UlEQVR4nO1caXfiOgz1bhJIyAJhX1JoSzv8/9/3LNlpYd4rhX6o4/N8Z2lKM2cURZau5JsQEhERERERERERERERERERERHx/wBjhDPC3OGN8+Cc5JeMuheaETSdO8vZFyCScHtmz2CsktoeMn7rLM1u3h0PMAEhyYX7v/Q9wQvoGdB0hlbzm45lEq/wd6y6G9aezvBk9AXwp1r3LHJIRsh6s2maxaJpmvqgvkC7WFS3loUnaFJtKRVUCEoV/RpCnHRvAsesVQ1hw+vd7Mpo+424tLs72NplkvQgcdrsvXkW/zJWqH/fA0FT84M/xnQJt4to3+ZLuanbM6X5lfXKHosO9COgREqpCR5i86pf2zPS7j9tTj+9nO7bQz3+xGEyGW9zqgQ1tyQ/VsxEDvce/4dcUPNb5OD9yXvR4Z2QisuP0xiGWPnemgugU5q/troHhGEjIF5sTOyW648aC0TssuaaCEsYEIkGzjWXOp3A0vVsf6kgRyqaDk+T7DIVWrb58b2tT5xpUucKwodOD/5LbrZC1ws6YSaBZJ/8xlh+XZSYXaMJ2ezNqjB3IPXuehPcx2U6b4t1dS/xNdFzguUt8ie7arnPeyCZroxLHzGgGdqVcspwafizPWEXBee+9G1OaufGdvNng/9C+gwgZ3PH3r87G6zXTZ5D5De2G2DeFoANXfbACkT+fxBQ22YFsTTJF9hjFVO6VbqxZXko4WJ8s52P4PnuxO5KRzu0/hlix1ySt8iXjgaQ+4IHPA9nVzNkdduM9LFT/Aacj4FtKrHA7iAw602Vnht6R8Vq1IOS+wNMKLYqayAYfRuufQPGeGb7sZogQQoLZrGPgZ6KoYn70Iw30O92BNEDpvwouCFn6wH2uS+EhRb3WF/HObZk3HuxfRQM3Y/Of/VH0n4MKNHZDiZvO9+m/ABALfkOcuar/7nOo7B95ACGVAFaz4jMiJwJhdaHBkySmzlGTu82gr6FSTik2kJvLnY9nOd/D90qcH268m3I/cgI1xg1maE5CuZYaWLH+UHANCIck0yt7Mx5zBm5vVHXHwChsZ35kKqUpmo5Svq5/fzfAI5g2vDtFPYo1HiEA85QrDeGm9g//LG7K0scO3sdpj2CBDgCa+0OFs0bkvVgnnM/QBDwllOMm+cN7vMSHlB7Uu4haHKaTwgGkv8tlK+hP8fzmFuK/RQTpaLPWvbd58yWIo66HHM0OsPoPhVqmtaEVL7N+wYcTLTbb0DLdgp23Eyy2VYJ2N7bkLFAAibtoLPe5sLt6Oa2bvU+zyeMa8wrixO0gRTn9tO9NCSThTLGqcqtsDvphlfmx/cPBZVvw24jg1LE2lPuEo35Mhi58U0I/Ga8n5w+NS8i34MAQLos5B1u0xL1ZvCVYVRw/Fs2q53KLaXJMWwOZZ/4MPYV19bAHmgGDKB6f01xoeJKFbl63q9J34KdaVNPJWztQyRkzA3KNs1AdAEDowMxh10emXTCx75CkurtbY/ZpdNDGdsn2UcHKHsQ8Ai3WZi48IfkvtjOhsLpuIRSKZTX9FA4o+0d6o/zOWqQzVJMynL9NsxhSJOaourq6nBVQBueMSyubsX2xHrmuABZN2Ns9jr5nwLFlLF/2R6atjW/67Yd11YQ1Z+kA9Zk9dPTM/o6dVo6HHVgC0JR8oUfmI93T9u3gvTG94bAH02Y5xeqRcjuwnKCK6Q2+ajl8KXJ3GSh22P3Zfx6S+n008ROhJn+JRIUVu6o7OXl8w1SeyhuqNDwNI7SjbK08QrqPxS95jy4G7nCXVq6G3HNu0LtK5J0e226CfC005WKK9sVv
fxI0eUbcnzutfhWe3rpZHM0nZ/ny/N8tanKYlQ6VEW5Xuym8yV1zZX58vwGhZp/5tFfhybZabdbrQYOs8F+xEhmPsb0/nki6kIyVvzZzUASiOrTfF+Sj9bXC7DoJxeiV8tjQL6loSd0yCx7YyB6rPdLx31U2qCG3F/oXIuDuqd6LFO+4DNIJuxFZqSsU0ea88avovFnWKRYFYRQDfCfcGaBCLn4M4A1ntJ5E57vicwqq2enaZEF5nokCYu9TbKqCC5yCDfL+GhLxT4w4xEJs+anqgou8DOY2q8FMryjb2MehC1dRJ9s4g9NXeTwPkWON4RH+FhIe0AWR/S9ekvQ+t70XHeimGF78LzuU7d7PwrswdIG2VpgF8C53qVQsTDtBJc4CdnkQPbnZY9mbPdDFra3PCXBBQ5QBn2aQqtyhvlyYM4Hb2/mdhsxCUen04GZVvIJZw5PAamMOmjzq8Q+dzAKLXDQ3RUZItWsg4t7W2DP+JDrJDymoMH7E5zQtuEpG03GTIjGCW3LQqOYEsXgFc78x76NeRwY6SNM+IfQoh6myJKRBIcLYxZcwscJ/gI2isTBty2Po9IkYzP0/SS4hGlxRjFAG5z1Jt1LckiB57yWvo35EaolbvA+6fBa24xodL2YjsPpTnj3JgJOqhcgOeLVsYYwoK0wjY+m1D3rGc40CukkaHnkEjarlXrF1B9M6ECQ6Ow0V7R7N4G3LfOHAXtymoyXOb4QhaYHJ/gNBJUkxclpSs7DNcgWWDDmM7Ke5MJpGuioe7w5EOvfTunUKRzOh7G2ylL+6ynHrD54oQO3//cN3yVO+5qMVsPZq0CZIOx4TlcJ8+Vz7V5waL+7WekzUpRFMTnnTlSCq3X5usi8qmIleW/rit1+oQZn1WGSU/sKBYEqMNh1mBOc6PhK8yCfKHdUNQk8o/G19ZPTs5MYfai+DLs5vmee37zEyyH48WW3XA6Xw6+Az8lMhci7N/KleToo7PtTKm+RA887Kqc6E9dyqL/QPTugzMHLbLZtJKqKLFfzVWRNJ63c+95uWT/F7R0U5dDVvuS409AJXhJvD0EwWaWdW8UN11u/7+umaYjT8mJtzZwP/MD4r57fihiHlC5fylHfaqnJdro+Dr7DajvO+vi2EwyD70s8nCH71nzIO1l5Zl+v1DMCb5ebvCMkGHvobXy/hPumGLyX0218/3RyD1GRLOuf9u/OGQyDmto32yMiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIiIv7GP8YjWPR/czH2AAAAAElFTkSuQmCC", + "url": image_url, }, "type": "image_url", }, @@ -313,7 +322,6 @@ def test_bedrock_claude_3(): # Add any assertions here to check the response assert len(response.choices) > 0 assert len(response.choices[0].message.content) > 0 - except RateLimitError: pass except Exception as e: @@ -552,7 +560,7 @@ def test_bedrock_ptu(): assert "url" in mock_client_post.call_args.kwargs assert ( mock_client_post.call_args.kwargs["url"] - == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/invoke" + == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-west-2%3A888602223428%3Aprovisioned-model%2F8fxff74qyhs3/converse" ) mock_client_post.assert_called_once() diff --git a/litellm/types/llms/bedrock.py b/litellm/types/llms/bedrock.py index 757ece516f..b06075092f 100644 --- a/litellm/types/llms/bedrock.py +++ b/litellm/types/llms/bedrock.py @@ -16,7 +16,7 @@ class SystemContentBlock(TypedDict): class ImageSourceBlock(TypedDict): - bytes: Optional[bytes] + bytes: Optional[str] # base 64 encoded string class ImageBlock(TypedDict): diff --git a/litellm/utils.py b/litellm/utils.py index dfee72ea7d..955bcd3e96 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4066,7 +4066,9 @@ def openai_token_counter( for c in value: if c["type"] == "text": text += c["text"] - num_tokens += len(encoding.encode(c["text"], disallowed_special=())) + num_tokens += len( + encoding.encode(c["text"], disallowed_special=()) + ) elif c["type"] == "image_url": if isinstance(c["image_url"], dict): image_url_dict = c["image_url"] @@ -5639,16 +5641,30 @@ def get_optional_params( optional_params["stream"] = stream elif "anthropic" in model: _check_valid_arg(supported_params=supported_params) - optional_params = litellm.AmazonConverseConfig().map_openai_params( - model=model, - non_default_params=non_default_params, - optional_params=optional_params, - drop_params=( - drop_params - if drop_params is not None and isinstance(drop_params, bool) - else False - ), - ) + if "aws_bedrock_client" in passed_params: # deprecated boto3.invoke route. 
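For reference, a caller-side sketch of the routing added in this utils.py hunk: Bedrock Anthropic models have their OpenAI-style params mapped through AmazonConverseConfig unless a boto3 client is passed in, which keeps the deprecated invoke route. The model id, region, and prompt below are illustrative placeholders, not values taken from this patch.

import boto3
import litellm

# Default path: no boto3 client is passed, so get_optional_params() maps the
# OpenAI params via AmazonConverseConfig (the Bedrock Converse HTTPX route).
litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
)

# Passing aws_bedrock_client keeps the deprecated boto3 invoke route, so the
# older AmazonAnthropicClaude3Config / AmazonAnthropicConfig mapping applies.
bedrock_client = boto3.client("bedrock-runtime", region_name="us-west-2")
litellm.completion(
    model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
    messages=[{"role": "user", "content": "Hey"}],
    max_tokens=10,
    aws_bedrock_client=bedrock_client,
)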
+ if model.startswith("anthropic.claude-3"): + optional_params = ( + litellm.AmazonAnthropicClaude3Config().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) + ) + else: + optional_params = litellm.AmazonAnthropicConfig().map_openai_params( + non_default_params=non_default_params, + optional_params=optional_params, + ) + else: # bedrock httpx route + optional_params = litellm.AmazonConverseConfig().map_openai_params( + model=model, + non_default_params=non_default_params, + optional_params=optional_params, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) elif "amazon" in model: # amazon titan llms _check_valid_arg(supported_params=supported_params) # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large From 53caf337fe6add39fb7b49a6a06ca66bfb8ec67b Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 7 Jun 2024 10:04:35 -0700 Subject: [PATCH 35/52] build: cleanup pre-commit hook --- .pre-commit-config.yaml | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index e8bb1ff66a..cc41d85f14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -16,11 +16,11 @@ repos: name: Check if files match entry: python3 ci_cd/check_files_match.py language: system -# - repo: local -# hooks: -# - id: mypy -# name: mypy -# entry: python3 -m mypy --ignore-missing-imports -# language: system -# types: [python] -# files: ^litellm/ \ No newline at end of file +- repo: local + hooks: + - id: mypy + name: mypy + entry: python3 -m mypy --ignore-missing-imports + language: system + types: [python] + files: ^litellm/ \ No newline at end of file From 90f345951dbcc8f8e64b2b392428efccdd32ec2e Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 10:07:48 -0700 Subject: [PATCH 36/52] (ci/cd) use ruff --- .pre-commit-config.yaml | 19 +++++++++---------- litellm/tests/test_completion.py | 4 ---- 2 files changed, 9 insertions(+), 14 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index cc41d85f14..41bff6d84b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,15 +1,14 @@ repos: -- repo: https://github.com/psf/black - rev: 24.2.0 +- repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.4.8 hooks: - - id: black -- repo: https://github.com/pycqa/flake8 - rev: 7.0.0 # The version of flake8 to use - hooks: - - id: flake8 - exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/ - additional_dependencies: [flake8-print] - files: litellm/.*\.py + # Run the linter. + - id: ruff + exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/ + # Run the formatter. 
+ - id: ruff-format + exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/ - repo: local hooks: - id: check-files-match diff --git a/litellm/tests/test_completion.py b/litellm/tests/test_completion.py index 47c55ca4f3..d143d1ab80 100644 --- a/litellm/tests/test_completion.py +++ b/litellm/tests/test_completion.py @@ -1398,7 +1398,6 @@ def test_hf_test_completion_tgi(): def mock_post(url, data=None, json=None, headers=None): - print(f"url={url}") if "text-classification" in url: raise Exception("Model not found") @@ -2240,9 +2239,6 @@ def test_re_use_openaiClient(): pytest.fail("got Exception", e) -# test_re_use_openaiClient() - - def test_completion_azure(): try: print("azure gpt-3.5 test\n\n") From 2eec379d9257f3b98f8e9ac10ccd3be7f6642a37 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 11:53:03 -0700 Subject: [PATCH 37/52] test fix - proxy server chat completion --- litellm/tests/test_proxy_server.py | 1 + 1 file changed, 1 insertion(+) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 114b96872f..6e6012199e 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -152,6 +152,7 @@ def test_chat_completion(mock_acompletion, client_no_auth): specific_deployment=True, metadata=mock.ANY, proxy_server_request=mock.ANY, + litellm_parent_otel_span=mock.ANY, ) print(f"response - {response.text}") assert response.status_code == 200 From aab0747c7f3ed0b6d1c366842b80743ebab3fbd7 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 7 Jun 2024 12:07:15 -0700 Subject: [PATCH 38/52] test(test_prompt_factory.py): cleanup test --- litellm/tests/test_prompt_factory.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/litellm/tests/test_prompt_factory.py b/litellm/tests/test_prompt_factory.py index 9f112a0b1b..b3aafab6e6 100644 --- a/litellm/tests/test_prompt_factory.py +++ b/litellm/tests/test_prompt_factory.py @@ -153,5 +153,3 @@ def test_bedrock_tool_calling_pt(): converted_tools = _bedrock_tools_pt(tools=tools) print(converted_tools) - - assert False From 522da0c4f8deba99b8068609442254ace45ce97c Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 12:18:53 -0700 Subject: [PATCH 39/52] feat - refactor /chat/completions to have a common helper --- .pre-commit-config.yaml | 19 ++--- litellm/proxy/litellm_pre_call_utils.py | 93 +++++++++++++++++++++ litellm/proxy/proxy_server.py | 102 ++---------------------- 3 files changed, 111 insertions(+), 103 deletions(-) create mode 100644 litellm/proxy/litellm_pre_call_utils.py diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 41bff6d84b..cc41d85f14 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,14 +1,15 @@ repos: -- repo: https://github.com/astral-sh/ruff-pre-commit - # Ruff version. - rev: v0.4.8 +- repo: https://github.com/psf/black + rev: 24.2.0 hooks: - # Run the linter. - - id: ruff - exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/ - # Run the formatter. 
- - id: ruff-format - exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/ + - id: black +- repo: https://github.com/pycqa/flake8 + rev: 7.0.0 # The version of flake8 to use + hooks: + - id: flake8 + exclude: ^litellm/tests/|^litellm/proxy/proxy_cli.py|^litellm/proxy/tests/ + additional_dependencies: [flake8-print] + files: litellm/.*\.py - repo: local hooks: - id: check-files-match diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py new file mode 100644 index 0000000000..d4736f933e --- /dev/null +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -0,0 +1,93 @@ +import copy +from fastapi import Request +from typing import Any, Dict, Optional +from litellm.proxy._types import UserAPIKeyAuth +from litellm._logging import verbose_proxy_logger, verbose_logger + + +def parse_cache_control(cache_control): + cache_dict = {} + directives = cache_control.split(", ") + + for directive in directives: + if "=" in directive: + key, value = directive.split("=") + cache_dict[key] = value + else: + cache_dict[directive] = True + + return cache_dict + + +async def add_litellm_data_to_request( + data: dict, + request: Request, + user_api_key_dict: UserAPIKeyAuth, + general_settings: Optional[Dict[str, Any]] = None, + version: Optional[str] = None, +): + # Azure OpenAI only: check if user passed api-version + query_params = dict(request.query_params) + if "api-version" in query_params: + data["api_version"] = query_params["api-version"] + + # Include original request and headers in the data + data["proxy_server_request"] = { + "url": str(request.url), + "method": request.method, + "headers": dict(request.headers), + "body": copy.copy(data), # use copy instead of deepcopy + } + + ## Cache Controls + headers = request.headers + verbose_proxy_logger.debug("Request Headers: %s", headers) + cache_control_header = headers.get("Cache-Control", None) + if cache_control_header: + cache_dict = parse_cache_control(cache_control_header) + data["ttl"] = cache_dict.get("s-maxage") + + verbose_proxy_logger.debug("receiving data: %s", data) + # users can pass in 'user' param to /chat/completions. 
Don't override it + if data.get("user", None) is None and user_api_key_dict.user_id is not None: + # if users are using user_api_key_auth, set `user` in `data` + data["user"] = user_api_key_dict.user_id + + if "metadata" not in data: + data["metadata"] = {} + data["metadata"]["user_api_key"] = user_api_key_dict.api_key + data["metadata"]["user_api_key_alias"] = getattr( + user_api_key_dict, "key_alias", None + ) + data["metadata"]["user_api_end_user_max_budget"] = getattr( + user_api_key_dict, "end_user_max_budget", None + ) + data["metadata"]["litellm_api_version"] = version + + if general_settings is not None: + data["metadata"]["global_max_parallel_requests"] = general_settings.get( + "global_max_parallel_requests", None + ) + + data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id + data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id + data["metadata"]["user_api_key_team_id"] = getattr( + user_api_key_dict, "team_id", None + ) + data["metadata"]["user_api_key_team_alias"] = getattr( + user_api_key_dict, "team_alias", None + ) + data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata + _headers = dict(request.headers) + _headers.pop( + "authorization", None + ) # do not store the original `sk-..` api key in the db + data["metadata"]["headers"] = _headers + data["metadata"]["endpoint"] = str(request.url) + # Add the OTEL Parent Trace before sending it LiteLLM + data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span + + ### END-USER SPECIFIC PARAMS ### + if user_api_key_dict.allowed_model_region is not None: + data["allowed_model_region"] = user_api_key_dict.allowed_model_region + return data diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 25e46269e1..15846ad4db 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -89,6 +89,7 @@ import litellm from litellm.types.llms.openai import ( HttpxBinaryResponseContent, ) +from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request from litellm.proxy.utils import ( PrismaClient, DBClient, @@ -3827,20 +3828,6 @@ def get_litellm_model_info(model: dict = {}): return {} -def parse_cache_control(cache_control): - cache_dict = {} - directives = cache_control.split(", ") - - for directive in directives: - if "=" in directive: - key, value = directive.split("=") - cache_dict[key] = value - else: - cache_dict[directive] = True - - return cache_dict - - def on_backoff(details): # The 'tries' key in the details dictionary contains the number of completed tries verbose_proxy_logger.debug("Backing off... 
this was attempt # %s", details["tries"]) @@ -4153,28 +4140,14 @@ async def chat_completion( except: data = json.loads(body_str) - # Azure OpenAI only: check if user passed api-version - query_params = dict(request.query_params) - if "api-version" in query_params: - data["api_version"] = query_params["api-version"] + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + ) - # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - ## Cache Controls - headers = request.headers - verbose_proxy_logger.debug("Request Headers: %s", headers) - cache_control_header = headers.get("Cache-Control", None) - if cache_control_header: - cache_dict = parse_cache_control(cache_control_header) - data["ttl"] = cache_dict.get("s-maxage") - - verbose_proxy_logger.debug("receiving data: %s", data) data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args @@ -4182,65 +4155,6 @@ async def chat_completion( or data["model"] # default passed in http request ) - # users can pass in 'user' param to /chat/completions. Don't override it - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - # if users are using user_api_key_auth, set `user` in `data` - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_end_user_max_budget"] = getattr( - user_api_key_dict, "end_user_max_budget", None - ) - data["metadata"]["litellm_api_version"] = version - - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_org_id"] = user_api_key_dict.org_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["endpoint"] = str(request.url) - # Add the OTEL Parent Trace before sending it LiteLLM - data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - _is_valid_team_configs( - team_id=team_id, team_config=team_config, request_data=data - ) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - - ### END-USER SPECIFIC PARAMS ### - if user_api_key_dict.allowed_model_region is not None: - data["allowed_model_region"] = user_api_key_dict.allowed_model_region - global user_temperature, user_request_timeout, 
user_max_tokens, user_api_base # override with user settings, these are params passed via cli if user_temperature: From 5df327aca28ce45191a19e1231fd8c81ab2364be Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 12:38:38 -0700 Subject: [PATCH 40/52] fix - refactor proxy server to use common func --- litellm/proxy/litellm_pre_call_utils.py | 33 +- litellm/proxy/proxy_server.py | 807 ++++-------------------- 2 files changed, 144 insertions(+), 696 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index d4736f933e..ddc2eb5528 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -3,6 +3,7 @@ from fastapi import Request from typing import Any, Dict, Optional from litellm.proxy._types import UserAPIKeyAuth from litellm._logging import verbose_proxy_logger, verbose_logger +from litellm.proxy.proxy_server import ProxyConfig def parse_cache_control(cache_control): @@ -23,10 +24,24 @@ async def add_litellm_data_to_request( data: dict, request: Request, user_api_key_dict: UserAPIKeyAuth, + proxy_config: ProxyConfig, general_settings: Optional[Dict[str, Any]] = None, version: Optional[str] = None, ): - # Azure OpenAI only: check if user passed api-version + """ + Adds LiteLLM-specific data to the request. + + Args: + data (dict): The data dictionary to be modified. + request (Request): The incoming request. + user_api_key_dict (UserAPIKeyAuth): The user API key dictionary. + general_settings (Optional[Dict[str, Any]], optional): General settings. Defaults to None. + version (Optional[str], optional): Version. Defaults to None. + + Returns: + dict: The modified data dictionary. + + """ query_params = dict(request.query_params) if "api-version" in query_params: data["api_version"] = query_params["api-version"] @@ -90,4 +105,20 @@ async def add_litellm_data_to_request( ### END-USER SPECIFIC PARAMS ### if user_api_key_dict.allowed_model_region is not None: data["allowed_model_region"] = user_api_key_dict.allowed_model_region + + ### TEAM-SPECIFIC PARAMS ### + if user_api_key_dict.team_id is not None: + team_config = await proxy_config.load_team_config( + team_id=user_api_key_dict.team_id + ) + if len(team_config) == 0: + pass + else: + team_id = team_config.pop("team_id", None) + data["metadata"]["team_id"] = team_id + data = { + **team_config, + **data, + } # add the team-specific configs to the completion call + return data diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 15846ad4db..2f8e26fe13 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -4146,6 +4146,7 @@ async def chat_completion( general_settings=general_settings, user_api_key_dict=user_api_key_dict, version=version, + proxy_config=proxy_config, ) data["model"] = ( @@ -4410,7 +4411,6 @@ async def completion( except: data = json.loads(body_str) - data["user"] = data.get("user", user_api_key_dict.user_id) data["model"] = ( general_settings.get("completion_model", None) # server default or user_model # model name passed via cli args @@ -4419,30 +4419,15 @@ async def completion( ) if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + + data = await add_litellm_data_to_request( + data=data, + 
request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["endpoint"] = str(request.url) # override with user settings, these are params passed via cli if user_temperature: @@ -4642,15 +4627,14 @@ async def embeddings( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) data["model"] = ( general_settings.get("embedding_model", None) # server default @@ -4660,45 +4644,6 @@ async def embeddings( ) if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call ### MODEL ALIAS MAPPING ### # check if model name in model alias map @@ -4853,15 +4798,14 @@ async def image_generation( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id + data 
= await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) data["model"] = ( general_settings.get("image_generation_model", None) # server default @@ -4871,46 +4815,6 @@ async def image_generation( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - ### MODEL ALIAS MAPPING ### # check if model name in model alias map # get the actual model name @@ -5035,12 +4939,14 @@ async def audio_speech( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) if data.get("user", None) is None and user_api_key_dict.user_id is not None: data["user"] = user_api_key_dict.user_id @@ -5048,46 +4954,6 @@ async def audio_speech( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if 
user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - router_model_names = llm_router.model_names if llm_router is not None else [] ### CALL HOOKS ### - modify incoming data / reject request before calling the model @@ -5200,12 +5066,14 @@ async def audio_transcriptions( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) if data.get("user", None) is None and user_api_key_dict.user_id is not None: data["user"] = user_api_key_dict.user_id @@ -5218,47 +5086,6 @@ async def audio_transcriptions( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - data["metadata"]["file_name"] = file.filename - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - router_model_names = llm_router.model_names if llm_router is not None else [] assert ( @@ -5410,55 +5237,14 @@ async def get_assistants( body = await request.body() # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - data["metadata"]["litellm_api_version"] = version - _headers = dict(request.headers) - _headers.pop( - 
"authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -5539,55 +5325,14 @@ async def create_threads( body = await request.body() # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "litellm_metadata" not in data: - data["litellm_metadata"] = {} - data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["litellm_metadata"]["headers"] = _headers - data["litellm_metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["litellm_metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["litellm_metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["litellm_metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["litellm_metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds 
more providers for acreate_batch if llm_router is None: @@ -5667,55 +5412,14 @@ async def get_thread( try: # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -5798,55 +5502,14 @@ async def add_messages( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "litellm_metadata" not in data: - data["litellm_metadata"] = {} - data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key - data["litellm_metadata"]["litellm_api_version"] = version - data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["litellm_metadata"]["headers"] = _headers - data["litellm_metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["litellm_metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", 
None - ) - data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["litellm_metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["litellm_metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["litellm_metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -5925,55 +5588,14 @@ async def get_messages( data: Dict = {} try: # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -6054,55 +5676,14 @@ async def run_thread( body = await request.body() data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "litellm_metadata" not in data: - data["litellm_metadata"] = {} - 
data["litellm_metadata"]["user_api_key"] = user_api_key_dict.api_key - data["litellm_metadata"]["litellm_api_version"] = version - data["litellm_metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["litellm_metadata"]["headers"] = _headers - data["litellm_metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["litellm_metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["litellm_metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["litellm_metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["litellm_metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["litellm_metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["litellm_metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch if llm_router is None: @@ -6214,55 +5795,14 @@ async def create_batch( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - 
pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call _create_batch_data = CreateBatchRequest(**data) @@ -6355,55 +5895,14 @@ async def retrieve_batch( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call _retrieve_batch_request = RetrieveBatchRequest( batch_id=batch_id, @@ -6510,55 +6009,14 @@ async def create_file( data = {key: value for key, value in form_data.items() if key != "file"} # Include original request and headers in the data - data["proxy_server_request"] = { # type: ignore - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id - - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + 
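These endpoints now delegate request enrichment to add_litellm_data_to_request() from litellm/proxy/litellm_pre_call_utils.py, introduced earlier in this series. A small self-contained illustration of the Cache-Control parsing that helper relies on; the header value is hypothetical.

from litellm.proxy.litellm_pre_call_utils import parse_cache_control

# "s-maxage" is the directive add_litellm_data_to_request() copies into data["ttl"];
# bare directives with no "=" are stored as True.
cache_dict = parse_cache_control("s-maxage=3600, no-cache")
assert cache_dict == {"s-maxage": "3600", "no-cache": True}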
proxy_config=proxy_config, ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call _create_file_request = CreateFileRequest() @@ -6648,15 +6106,14 @@ async def moderations( data = orjson.loads(body) # Include original request and headers in the data - data["proxy_server_request"] = { - "url": str(request.url), - "method": request.method, - "headers": dict(request.headers), - "body": copy.copy(data), # use copy instead of deepcopy - } - - if data.get("user", None) is None and user_api_key_dict.user_id is not None: - data["user"] = user_api_key_dict.user_id + data = await add_litellm_data_to_request( + data=data, + request=request, + general_settings=general_settings, + user_api_key_dict=user_api_key_dict, + version=version, + proxy_config=proxy_config, + ) data["model"] = ( general_settings.get("moderation_model", None) # server default @@ -6666,46 +6123,6 @@ async def moderations( if user_model: data["model"] = user_model - if "metadata" not in data: - data["metadata"] = {} - data["metadata"]["user_api_key"] = user_api_key_dict.api_key - data["metadata"]["litellm_api_version"] = version - data["metadata"]["user_api_key_metadata"] = user_api_key_dict.metadata - _headers = dict(request.headers) - _headers.pop( - "authorization", None - ) # do not store the original `sk-..` api key in the db - data["metadata"]["headers"] = _headers - data["metadata"]["global_max_parallel_requests"] = general_settings.get( - "global_max_parallel_requests", None - ) - data["metadata"]["user_api_key_alias"] = getattr( - user_api_key_dict, "key_alias", None - ) - data["metadata"]["user_api_key_user_id"] = user_api_key_dict.user_id - data["metadata"]["user_api_key_team_id"] = getattr( - user_api_key_dict, "team_id", None - ) - data["metadata"]["user_api_key_team_alias"] = getattr( - user_api_key_dict, "team_alias", None - ) - data["metadata"]["endpoint"] = str(request.url) - - ### TEAM-SPECIFIC PARAMS ### - if user_api_key_dict.team_id is not None: - team_config = await proxy_config.load_team_config( - team_id=user_api_key_dict.team_id - ) - if len(team_config) == 0: - pass - else: - team_id = team_config.pop("team_id", None) - data["metadata"]["team_id"] = team_id - data = { - **team_config, - **data, - } # add the team-specific configs to the completion call - router_model_names = llm_router.model_names if llm_router is not None else [] ### CALL HOOKS ### - modify incoming data / reject request before calling the model From 22e653d9226a1a2c14a30b4c2d46375518bcf0de Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 12:50:50 -0700 Subject: [PATCH 41/52] fix importing litellm --- litellm/proxy/litellm_pre_call_utils.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/litellm/proxy/litellm_pre_call_utils.py 
b/litellm/proxy/litellm_pre_call_utils.py index ddc2eb5528..1614a6dfe6 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -1,9 +1,15 @@ import copy from fastapi import Request -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, TYPE_CHECKING from litellm.proxy._types import UserAPIKeyAuth from litellm._logging import verbose_proxy_logger, verbose_logger -from litellm.proxy.proxy_server import ProxyConfig + +if TYPE_CHECKING: + from litellm.proxy.proxy_server import ProxyConfig as _ProxyConfig + + ProxyConfig = _ProxyConfig +else: + ProxyConfig = Any def parse_cache_control(cache_control): From 308c4b3b75876f3220468422df05faa66d0dd437 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 12:54:39 -0700 Subject: [PATCH 42/52] fix proxy server test --- litellm/tests/test_proxy_server.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 6e6012199e..2c643eff0b 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -191,6 +191,7 @@ def test_engines_model_chat_completions(mock_acompletion, client_no_auth): specific_deployment=True, metadata=mock.ANY, proxy_server_request=mock.ANY, + litellm_parent_otel_span=mock.ANY, ) print(f"response - {response.text}") assert response.status_code == 200 @@ -228,6 +229,7 @@ def test_chat_completion_azure(mock_acompletion, client_no_auth): specific_deployment=True, metadata=mock.ANY, proxy_server_request=mock.ANY, + litellm_parent_otel_span=mock.ANY, ) assert response.status_code == 200 result = response.json() @@ -272,6 +274,7 @@ def test_openai_deployments_model_chat_completions_azure( specific_deployment=True, metadata=mock.ANY, proxy_server_request=mock.ANY, + litellm_parent_otel_span=mock.ANY, ) assert response.status_code == 200 result = response.json() @@ -486,6 +489,7 @@ def test_chat_completion_optional_params(mock_acompletion, client_no_auth): specific_deployment=True, metadata=mock.ANY, proxy_server_request=mock.ANY, + litellm_parent_otel_span=mock.ANY, ) assert response.status_code == 200 result = response.json() From 6182718299541ce89a01c9bb9d9d685beaf736ee Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 13:31:18 -0700 Subject: [PATCH 43/52] test fix --- litellm/tests/test_proxy_exception_mapping.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index ccd071d01e..bd1cf4bb67 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -103,6 +103,7 @@ def test_chat_completion_exception_azure(mock_acompletion, client): request_timeout=mock.ANY, metadata=mock.ANY, proxy_server_request=mock.ANY, + litellm_parent_otel_span=mock.ANY, ) json_response = response.json() @@ -271,6 +272,7 @@ def test_chat_completion_exception_azure_context_window(mock_acompletion, client request_timeout=mock.ANY, metadata=mock.ANY, proxy_server_request=mock.ANY, + litellm_parent_otel_span=mock.ANY, ) json_response = response.json() From 8106a6dc9b3ce0c5b16be6133041bb5f8e2f1c53 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 13:48:21 -0700 Subject: [PATCH 44/52] fix simplify - pass litellm_parent_otel_span --- litellm/integrations/opentelemetry.py | 3 ++- litellm/main.py | 3 --- litellm/proxy/litellm_pre_call_utils.py | 2 +- litellm/proxy/utils.py | 3 ++- litellm/tests/test_proxy_server.py 
| 1 - litellm/utils.py | 2 -- 6 files changed, 5 insertions(+), 9 deletions(-) diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 3fc50848e8..5d308dc378 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -162,7 +162,8 @@ class OpenTelemetry(CustomLogger): proxy_server_request = litellm_params.get("proxy_server_request", {}) or {} headers = proxy_server_request.get("headers", {}) or {} traceparent = headers.get("traceparent", None) - parent_otel_span = litellm_params.get("litellm_parent_otel_span", None) + _metadata = litellm_params.get("metadata", {}) + parent_otel_span = _metadata.get("litellm_parent_otel_span", None) """ Two way to use parents in opentelemetry diff --git a/litellm/main.py b/litellm/main.py index 596f85f334..f76d6c5213 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -600,7 +600,6 @@ def completion( client = kwargs.get("client", None) ### Admin Controls ### no_log = kwargs.get("no-log", False) - litellm_parent_otel_span = kwargs.get("litellm_parent_otel_span", None) ######## end of unpacking kwargs ########### openai_params = [ "functions", @@ -690,7 +689,6 @@ def completion( "allowed_model_region", "model_config", "fastest_response", - "litellm_parent_otel_span", ] default_params = openai_params + litellm_params @@ -875,7 +873,6 @@ def completion( input_cost_per_token=input_cost_per_token, output_cost_per_second=output_cost_per_second, output_cost_per_token=output_cost_per_token, - litellm_parent_otel_span=litellm_parent_otel_span, ) logging.update_environment_variables( model=model, diff --git a/litellm/proxy/litellm_pre_call_utils.py b/litellm/proxy/litellm_pre_call_utils.py index 1614a6dfe6..945799b4cf 100644 --- a/litellm/proxy/litellm_pre_call_utils.py +++ b/litellm/proxy/litellm_pre_call_utils.py @@ -106,7 +106,7 @@ async def add_litellm_data_to_request( data["metadata"]["headers"] = _headers data["metadata"]["endpoint"] = str(request.url) # Add the OTEL Parent Trace before sending it LiteLLM - data["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span + data["metadata"]["litellm_parent_otel_span"] = user_api_key_dict.parent_otel_span ### END-USER SPECIFIC PARAMS ### if user_api_key_dict.allowed_model_region is not None: diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 8ae811f112..f3bbf69c4b 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -73,7 +73,8 @@ def print_verbose(print_statement): def safe_deep_copy(data): if isinstance(data, dict): # remove litellm_parent_otel_span since this is not picklable - data.pop("litellm_parent_otel_span", None) + if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]: + data["metadata"].pop("litellm_parent_otel_span") new_data = copy.deepcopy(data) return new_data diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 6e6012199e..114b96872f 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -152,7 +152,6 @@ def test_chat_completion(mock_acompletion, client_no_auth): specific_deployment=True, metadata=mock.ANY, proxy_server_request=mock.ANY, - litellm_parent_otel_span=mock.ANY, ) print(f"response - {response.text}") assert response.status_code == 200 diff --git a/litellm/utils.py b/litellm/utils.py index be7728dfef..822fea4815 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4918,7 +4918,6 @@ def get_litellm_params( input_cost_per_token=None, output_cost_per_token=None, 
output_cost_per_second=None, - litellm_parent_otel_span=None, ): litellm_params = { "acompletion": acompletion, @@ -4941,7 +4940,6 @@ def get_litellm_params( "input_cost_per_second": input_cost_per_second, "output_cost_per_token": output_cost_per_token, "output_cost_per_second": output_cost_per_second, - "litellm_parent_otel_span": litellm_parent_otel_span, } return litellm_params From 6ce970e7cdfc4e4e8aedd632bf9abe3e85947802 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 14:07:58 -0700 Subject: [PATCH 45/52] fix test litellm_parent_otel_span --- litellm/tests/test_proxy_exception_mapping.py | 10 ++++++---- litellm/tests/test_proxy_server.py | 4 ---- 2 files changed, 6 insertions(+), 8 deletions(-) diff --git a/litellm/tests/test_proxy_exception_mapping.py b/litellm/tests/test_proxy_exception_mapping.py index bd1cf4bb67..4988426616 100644 --- a/litellm/tests/test_proxy_exception_mapping.py +++ b/litellm/tests/test_proxy_exception_mapping.py @@ -103,7 +103,6 @@ def test_chat_completion_exception_azure(mock_acompletion, client): request_timeout=mock.ANY, metadata=mock.ANY, proxy_server_request=mock.ANY, - litellm_parent_otel_span=mock.ANY, ) json_response = response.json() @@ -211,7 +210,9 @@ def test_chat_completion_exception_any_model(client): ) assert isinstance(openai_exception, openai.BadRequestError) _error_message = openai_exception.message - assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str(_error_message) + assert "chat_completion: Invalid model name passed in model=Lite-GPT-12" in str( + _error_message + ) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. Exception {str(e)}") @@ -239,7 +240,9 @@ def test_embedding_exception_any_model(client): print("Exception raised=", openai_exception) assert isinstance(openai_exception, openai.BadRequestError) _error_message = openai_exception.message - assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str(_error_message) + assert "embeddings: Invalid model name passed in model=Lite-GPT-12" in str( + _error_message + ) except Exception as e: pytest.fail(f"LiteLLM Proxy test failed. 
Exception {str(e)}")
@@ -272,7 +275,6 @@ def test_chat_completion_exception_azure_context_window(mock_acompletion, client
         request_timeout=mock.ANY,
         metadata=mock.ANY,
         proxy_server_request=mock.ANY,
-        litellm_parent_otel_span=mock.ANY,
     )
 
     json_response = response.json()
diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py
index a37f8adbd1..114b96872f 100644
--- a/litellm/tests/test_proxy_server.py
+++ b/litellm/tests/test_proxy_server.py
@@ -190,7 +190,6 @@ def test_engines_model_chat_completions(mock_acompletion, client_no_auth):
         specific_deployment=True,
         metadata=mock.ANY,
         proxy_server_request=mock.ANY,
-        litellm_parent_otel_span=mock.ANY,
     )
     print(f"response - {response.text}")
     assert response.status_code == 200
@@ -228,7 +227,6 @@ def test_chat_completion_azure(mock_acompletion, client_no_auth):
         specific_deployment=True,
         metadata=mock.ANY,
         proxy_server_request=mock.ANY,
-        litellm_parent_otel_span=mock.ANY,
     )
     assert response.status_code == 200
     result = response.json()
@@ -273,7 +271,6 @@ def test_openai_deployments_model_chat_completions_azure(
         specific_deployment=True,
         metadata=mock.ANY,
         proxy_server_request=mock.ANY,
-        litellm_parent_otel_span=mock.ANY,
     )
     assert response.status_code == 200
     result = response.json()
@@ -488,7 +485,6 @@ def test_chat_completion_optional_params(mock_acompletion, client_no_auth):
         specific_deployment=True,
         metadata=mock.ANY,
         proxy_server_request=mock.ANY,
-        litellm_parent_otel_span=mock.ANY,
     )
     assert response.status_code == 200
     result = response.json()

From a6ec9cd54bfe65ac21c380a6ff9dccbaf8484dc2 Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 7 Jun 2024 14:11:55 -0700
Subject: [PATCH 46/52] fix - use safe_deep_copy from litellm

---
 litellm/proxy/utils.py | 17 +++++++++++++++++
 1 file changed, 17 insertions(+)

diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index afe059ce1e..db6eefc926 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -71,11 +71,27 @@ def print_verbose(print_statement):
 
 
 def safe_deep_copy(data):
+    """
+    Safe Deep Copy
+
+    The LiteLLM Request has some objects that cannot be pickled / deep copied
+
+    Use this function to safely deep copy the LiteLLM Request
+    """
+
+    # Step 1: Remove the litellm_parent_otel_span
     if isinstance(data, dict):
         # remove litellm_parent_otel_span since this is not picklable
         if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
             data["metadata"].pop("litellm_parent_otel_span")
     new_data = copy.deepcopy(data)
+
+    # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
+    if isinstance(data, dict):
+        if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
+            data["metadata"]["litellm_parent_otel_span"] = data["metadata"][
+                "litellm_parent_otel_span"
+            ]
     return new_data
 
 
@@ -2891,6 +2907,7 @@ missing_keys_html_form = """
 def _to_ns(dt):
     return int(dt.timestamp() * 1e9)
 
+
 def get_error_message_str(e: Exception) -> str:
     error_message = ""
     if isinstance(e, HTTPException):

From 0cd836cf46e374319f5c19e313ccb10fc7a1565f Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 7 Jun 2024 14:23:06 -0700
Subject: [PATCH 47/52] fix type error on python 3.8

---
 litellm/types/files.py | 209 ++++++++++++++++++++++-------------------
 1 file changed, 111 insertions(+), 98 deletions(-)

diff --git a/litellm/types/files.py b/litellm/types/files.py
index 0545567ece..5badcddcc1 100644
--- a/litellm/types/files.py
+++ b/litellm/types/files.py
@@ -1,10 +1,11 @@
 from enum import Enum
-from types import MappingProxyType
-from typing
import List, Set +from typing import List, Set, Dict """ Base Enums/Consts """ + + class FileType(Enum): AAC = "AAC" CSV = "CSV" @@ -49,99 +50,106 @@ class FileType(Enum): XLS = "XLS" XLSX = "XLSX" -FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType({ - FileType.AAC: ["aac"], - FileType.CSV: ["csv"], - FileType.DOC: ["doc"], - FileType.DOCX: ["docx"], - FileType.FLAC: ["flac"], - FileType.FLV: ["flv"], - FileType.GIF: ["gif"], - FileType.GOOGLE_DOC: ["gdoc"], - FileType.GOOGLE_DRAWINGS: ["gdraw"], - FileType.GOOGLE_SHEETS: ["gsheet"], - FileType.GOOGLE_SLIDES: ["gslides"], - FileType.HEIC: ["heic"], - FileType.HEIF: ["heif"], - FileType.HTML: ["html", "htm"], - FileType.JPEG: ["jpeg", "jpg"], - FileType.JSON: ["json"], - FileType.M4A: ["m4a"], - FileType.M4V: ["m4v"], - FileType.MOV: ["mov"], - FileType.MP3: ["mp3"], - FileType.MP4: ["mp4"], - FileType.MPEG: ["mpeg"], - FileType.MPEGPS: ["mpegps"], - FileType.MPG: ["mpg"], - FileType.MPA: ["mpa"], - FileType.MPGA: ["mpga"], - FileType.OGG: ["ogg"], - FileType.OPUS: ["opus"], - FileType.PDF: ["pdf"], - FileType.PCM: ["pcm"], - FileType.PNG: ["png"], - FileType.PPT: ["ppt"], - FileType.PPTX: ["pptx"], - FileType.RTF: ["rtf"], - FileType.THREE_GPP: ["3gpp"], - FileType.TXT: ["txt"], - FileType.WAV: ["wav"], - FileType.WEBM: ["webm"], - FileType.WEBP: ["webp"], - FileType.WMV: ["wmv"], - FileType.XLS: ["xls"], - FileType.XLSX: ["xlsx"], -}) -FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType({ - FileType.AAC: "audio/aac", - FileType.CSV: "text/csv", - FileType.DOC: "application/msword", - FileType.DOCX: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", - FileType.FLAC: "audio/flac", - FileType.FLV: "video/x-flv", - FileType.GIF: "image/gif", - FileType.GOOGLE_DOC: "application/vnd.google-apps.document", - FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing", - FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet", - FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation", - FileType.HEIC: "image/heic", - FileType.HEIF: "image/heif", - FileType.HTML: "text/html", - FileType.JPEG: "image/jpeg", - FileType.JSON: "application/json", - FileType.M4A: "audio/x-m4a", - FileType.M4V: "video/x-m4v", - FileType.MOV: "video/quicktime", - FileType.MP3: "audio/mpeg", - FileType.MP4: "video/mp4", - FileType.MPEG: "video/mpeg", - FileType.MPEGPS: "video/mpegps", - FileType.MPG: "video/mpg", - FileType.MPA: "audio/m4a", - FileType.MPGA: "audio/mpga", - FileType.OGG: "audio/ogg", - FileType.OPUS: "audio/opus", - FileType.PDF: "application/pdf", - FileType.PCM: "audio/pcm", - FileType.PNG: "image/png", - FileType.PPT: "application/vnd.ms-powerpoint", - FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation", - FileType.RTF: "application/rtf", - FileType.THREE_GPP: "video/3gpp", - FileType.TXT: "text/plain", - FileType.WAV: "audio/wav", - FileType.WEBM: "video/webm", - FileType.WEBP: "image/webp", - FileType.WMV: "video/wmv", - FileType.XLS: "application/vnd.ms-excel", - FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", -}) +FILE_EXTENSIONS: Dict[FileType, List[str]] = Dict( + { + FileType.AAC: ["aac"], + FileType.CSV: ["csv"], + FileType.DOC: ["doc"], + FileType.DOCX: ["docx"], + FileType.FLAC: ["flac"], + FileType.FLV: ["flv"], + FileType.GIF: ["gif"], + FileType.GOOGLE_DOC: ["gdoc"], + FileType.GOOGLE_DRAWINGS: ["gdraw"], + FileType.GOOGLE_SHEETS: ["gsheet"], + FileType.GOOGLE_SLIDES: 
["gslides"], + FileType.HEIC: ["heic"], + FileType.HEIF: ["heif"], + FileType.HTML: ["html", "htm"], + FileType.JPEG: ["jpeg", "jpg"], + FileType.JSON: ["json"], + FileType.M4A: ["m4a"], + FileType.M4V: ["m4v"], + FileType.MOV: ["mov"], + FileType.MP3: ["mp3"], + FileType.MP4: ["mp4"], + FileType.MPEG: ["mpeg"], + FileType.MPEGPS: ["mpegps"], + FileType.MPG: ["mpg"], + FileType.MPA: ["mpa"], + FileType.MPGA: ["mpga"], + FileType.OGG: ["ogg"], + FileType.OPUS: ["opus"], + FileType.PDF: ["pdf"], + FileType.PCM: ["pcm"], + FileType.PNG: ["png"], + FileType.PPT: ["ppt"], + FileType.PPTX: ["pptx"], + FileType.RTF: ["rtf"], + FileType.THREE_GPP: ["3gpp"], + FileType.TXT: ["txt"], + FileType.WAV: ["wav"], + FileType.WEBM: ["webm"], + FileType.WEBP: ["webp"], + FileType.WMV: ["wmv"], + FileType.XLS: ["xls"], + FileType.XLSX: ["xlsx"], + } +) + +FILE_MIME_TYPES: Dict[FileType, str] = Dict( + { + FileType.AAC: "audio/aac", + FileType.CSV: "text/csv", + FileType.DOC: "application/msword", + FileType.DOCX: "application/vnd.openxmlformats-officedocument.wordprocessingml.document", + FileType.FLAC: "audio/flac", + FileType.FLV: "video/x-flv", + FileType.GIF: "image/gif", + FileType.GOOGLE_DOC: "application/vnd.google-apps.document", + FileType.GOOGLE_DRAWINGS: "application/vnd.google-apps.drawing", + FileType.GOOGLE_SHEETS: "application/vnd.google-apps.spreadsheet", + FileType.GOOGLE_SLIDES: "application/vnd.google-apps.presentation", + FileType.HEIC: "image/heic", + FileType.HEIF: "image/heif", + FileType.HTML: "text/html", + FileType.JPEG: "image/jpeg", + FileType.JSON: "application/json", + FileType.M4A: "audio/x-m4a", + FileType.M4V: "video/x-m4v", + FileType.MOV: "video/quicktime", + FileType.MP3: "audio/mpeg", + FileType.MP4: "video/mp4", + FileType.MPEG: "video/mpeg", + FileType.MPEGPS: "video/mpegps", + FileType.MPG: "video/mpg", + FileType.MPA: "audio/m4a", + FileType.MPGA: "audio/mpga", + FileType.OGG: "audio/ogg", + FileType.OPUS: "audio/opus", + FileType.PDF: "application/pdf", + FileType.PCM: "audio/pcm", + FileType.PNG: "image/png", + FileType.PPT: "application/vnd.ms-powerpoint", + FileType.PPTX: "application/vnd.openxmlformats-officedocument.presentationml.presentation", + FileType.RTF: "application/rtf", + FileType.THREE_GPP: "video/3gpp", + FileType.TXT: "text/plain", + FileType.WAV: "audio/wav", + FileType.WEBM: "video/webm", + FileType.WEBP: "image/webp", + FileType.WMV: "video/wmv", + FileType.XLS: "application/vnd.ms-excel", + FileType.XLSX: "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", + } +) """ Util Functions """ + + def get_file_mime_type_from_extension(extension: str) -> str: for file_type, extensions in FILE_EXTENSIONS.items(): if extension in extensions: @@ -167,6 +175,7 @@ def get_file_type_from_extension(extension: str) -> FileType: def get_file_extension_for_file_type(file_type: FileType) -> str: return FILE_EXTENSIONS[file_type][0] + def get_file_mime_type_for_file_type(file_type: FileType) -> str: return FILE_MIME_TYPES[file_type] @@ -182,12 +191,14 @@ IMAGE_FILE_TYPES = { FileType.GIF, FileType.WEBP, FileType.HEIC, - FileType.HEIF + FileType.HEIF, } + def is_image_file_type(file_type): return file_type in IMAGE_FILE_TYPES + # Videos VIDEO_FILE_TYPES = { FileType.MOV, @@ -199,12 +210,14 @@ VIDEO_FILE_TYPES = { FileType.MPG, FileType.WEBM, FileType.WMV, - FileType.THREE_GPP + FileType.THREE_GPP, } + def is_video_file_type(file_type): return file_type in VIDEO_FILE_TYPES + # Audio AUDIO_FILE_TYPES = { FileType.AAC, @@ -217,20 +230,19 @@ 
AUDIO_FILE_TYPES = { FileType.WAV, } + def is_audio_file_type(file_type): return file_type in AUDIO_FILE_TYPES + # Text -TEXT_FILE_TYPES = { - FileType.CSV, - FileType.HTML, - FileType.RTF, - FileType.TXT -} +TEXT_FILE_TYPES = {FileType.CSV, FileType.HTML, FileType.RTF, FileType.TXT} + def is_text_file_type(file_type): return file_type in TEXT_FILE_TYPES + """ Other FileType Groupings """ @@ -263,5 +275,6 @@ GEMINI_1_5_ACCEPTED_FILE_TYPES: Set[FileType] = { FileType.PDF, } + def is_gemini_1_5_accepted_file_type(file_type: FileType) -> bool: return file_type in GEMINI_1_5_ACCEPTED_FILE_TYPES From 6ffe5e75ba1fcdf64401d1ff54695d63bd2a1e78 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 14:31:31 -0700 Subject: [PATCH 48/52] use MappingProxyType for now. will fix python 3.8 install later --- litellm/types/files.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/litellm/types/files.py b/litellm/types/files.py index 5badcddcc1..2da8fe4806 100644 --- a/litellm/types/files.py +++ b/litellm/types/files.py @@ -1,5 +1,6 @@ from enum import Enum -from typing import List, Set, Dict +from types import MappingProxyType +from typing import List, Set """ Base Enums/Consts @@ -51,7 +52,7 @@ class FileType(Enum): XLSX = "XLSX" -FILE_EXTENSIONS: Dict[FileType, List[str]] = Dict( +FILE_EXTENSIONS: MappingProxyType[FileType, List[str]] = MappingProxyType( { FileType.AAC: ["aac"], FileType.CSV: ["csv"], @@ -98,7 +99,7 @@ FILE_EXTENSIONS: Dict[FileType, List[str]] = Dict( } ) -FILE_MIME_TYPES: Dict[FileType, str] = Dict( +FILE_MIME_TYPES: MappingProxyType[FileType, str] = MappingProxyType( { FileType.AAC: "audio/aac", FileType.CSV: "text/csv", From cbd39d665db61706c81b0093528c1b0bf7f56958 Mon Sep 17 00:00:00 2001 From: Krrish Dholakia Date: Fri, 7 Jun 2024 14:47:01 -0700 Subject: [PATCH 49/52] fix(bedrock_httpx.py): fix post call success logging --- litellm/llms/bedrock_httpx.py | 11 ++--------- 1 file changed, 2 insertions(+), 9 deletions(-) diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 59945a5857..afc2657610 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -178,7 +178,7 @@ async def make_call( logging_obj.post_call( input=messages, api_key="", - original_response=completion_stream, # Pass the completion stream for logging + original_response="first stream response received", additional_args={"complete_input_dict": data}, ) @@ -209,7 +209,7 @@ def make_sync_call( logging_obj.post_call( input=messages, api_key="", - original_response=completion_stream, # Pass the completion stream for logging + original_response="first stream response received", additional_args={"complete_input_dict": data}, ) @@ -1805,13 +1805,6 @@ class BedrockConverseLLM(BaseLLM): logging_obj=logging_obj, ) - ## LOGGING - logging_obj.post_call( - input=messages, - api_key="", - original_response=streaming_response, - additional_args={"complete_input_dict": data}, - ) return streaming_response ### COMPLETION From cb4bdee18d94c0ce5fa51d605f4be34cf10027e5 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Fri, 7 Jun 2024 14:48:16 -0700 Subject: [PATCH 50/52] v0 - log proxy server exceptions on OTEL --- litellm/integrations/opentelemetry.py | 28 +++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/litellm/integrations/opentelemetry.py b/litellm/integrations/opentelemetry.py index 5a5b6d14dd..089b67368a 100644 --- a/litellm/integrations/opentelemetry.py +++ b/litellm/integrations/opentelemetry.py @@ -9,10 +9,13 @@ from 
typing import Union, Optional, TYPE_CHECKING, Any
 
 if TYPE_CHECKING:
     from opentelemetry.trace import Span as _Span
+    from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
 
     Span = _Span
+    UserAPIKeyAuth = _UserAPIKeyAuth
 else:
     Span = Any
+    UserAPIKeyAuth = Any
 
 LITELLM_TRACER_NAME = os.getenv("OTEL_TRACER_NAME", "litellm")
@@ -111,6 +114,31 @@ class OpenTelemetry(CustomLogger):
         service_logging_span.set_status(Status(StatusCode.OK))
         service_logging_span.end(end_time=self._to_ns(end_time))
 
+    async def async_post_call_failure_hook(
+        self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
+    ):
+        from opentelemetry.trace import Status, StatusCode
+        from opentelemetry import trace
+
+        parent_otel_span = user_api_key_dict.parent_otel_span
+        if parent_otel_span is not None:
+            parent_otel_span.set_status(Status(StatusCode.ERROR))
+            _span_name = "Failed Proxy Server Request"
+
+            # Exception Logging Child Span
+            exception_logging_span = self.tracer.start_span(
+                name=_span_name,
+                context=trace.set_span_in_context(parent_otel_span),
+            )
+            exception_logging_span.set_attribute(
+                key="exception", value=str(original_exception)
+            )
+            exception_logging_span.set_status(Status(StatusCode.ERROR))
+            exception_logging_span.end(end_time=self._to_ns(datetime.now()))
+
+            # End Parent OTEL Span
+            parent_otel_span.end(end_time=self._to_ns(datetime.now()))
+
     def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
         from opentelemetry.trace import Status, StatusCode

From a13699647eaebdc839d1bd7650ad262ca3bea56c Mon Sep 17 00:00:00 2001
From: Ishaan Jaff
Date: Fri, 7 Jun 2024 15:10:05 -0700
Subject: [PATCH 51/52] fix logic for deep copying otel spans / traces

---
 litellm/proxy/utils.py | 8 +++-----
 1 file changed, 3 insertions(+), 5 deletions(-)

diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py
index db6eefc926..e89aae6ad2 100644
--- a/litellm/proxy/utils.py
+++ b/litellm/proxy/utils.py
@@ -83,15 +83,13 @@ def safe_deep_copy(data):
     if isinstance(data, dict):
         # remove litellm_parent_otel_span since this is not picklable
         if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
-            data["metadata"].pop("litellm_parent_otel_span")
+            litellm_parent_otel_span = data["metadata"].pop("litellm_parent_otel_span")
     new_data = copy.deepcopy(data)
 
     # Step 2: re-add the litellm_parent_otel_span after doing a deep copy
     if isinstance(data, dict):
-        if "metadata" in data and "litellm_parent_otel_span" in data["metadata"]:
-            data["metadata"]["litellm_parent_otel_span"] = data["metadata"][
-                "litellm_parent_otel_span"
-            ]
+        if "metadata" in data:
+            data["metadata"]["litellm_parent_otel_span"] = litellm_parent_otel_span
     return new_data
 
 

From b16666b5dc42a2bb1c38c415c556ebe201645234 Mon Sep 17 00:00:00 2001
From: Krrish Dholakia
Date: Fri, 7 Jun 2024 16:05:53 -0700
Subject: [PATCH 52/52] fix(utils.py): fix vertex ai exception mapping

---
 litellm/utils.py | 45 +++++++++++++++++++++++++++++++++++----------
 1 file changed, 35 insertions(+), 10 deletions(-)

diff --git a/litellm/utils.py b/litellm/utils.py
index ae5879f07f..41668e3f21 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -9111,7 +9111,7 @@ def exception_type(
             if "Unable to locate credentials" in error_str:
                 exception_mapping_worked = True
                 raise BadRequestError(
-                    message=f"SagemakerException - {error_str}",
+                    message=f"litellm.BadRequestError: SagemakerException - {error_str}",
                     model=model,
                     llm_provider="sagemaker",
                     response=original_exception.response,
@@ -9145,10 +9145,16 @@
                 ):
exception_mapping_worked = True raise BadRequestError( - message=f"VertexAIException BadRequestError - {error_str}", + message=f"litellm.BadRequestError: VertexAIException - {error_str}", model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), litellm_debug_info=extra_information, ) elif ( @@ -9156,12 +9162,19 @@ def exception_type( or "Content has no parts." in error_str ): exception_mapping_worked = True - raise APIError( - message=f"VertexAIException APIError - {error_str}", + raise litellm.InternalServerError( + message=f"litellm.InternalServerError: VertexAIException - {error_str}", status_code=500, model=model, llm_provider="vertex_ai", - request=original_exception.request, + request=( + original_exception.request + if hasattr(original_exception, "request") + else httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ) + ), litellm_debug_info=extra_information, ) elif "403" in error_str: @@ -9170,7 +9183,13 @@ def exception_type( message=f"VertexAIException BadRequestError - {error_str}", model=model, llm_provider="vertex_ai", - response=original_exception.response, + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), litellm_debug_info=extra_information, ) elif "The response was blocked." in error_str: @@ -9217,12 +9236,18 @@ def exception_type( model=model, llm_provider="vertex_ai", litellm_debug_info=extra_information, - response=original_exception.response, + response=httpx.Response( + status_code=429, + request=httpx.Request( + method="POST", + url=" https://cloud.google.com/vertex-ai/", + ), + ), ) if original_exception.status_code == 500: exception_mapping_worked = True - raise APIError( - message=f"VertexAIException APIError - {error_str}", + raise litellm.InternalServerError( + message=f"VertexAIException InternalServerError - {error_str}", status_code=500, model=model, llm_provider="vertex_ai",
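
The safe_deep_copy related patches above (44, 46, and 51) all circle one pattern: the OTEL parent span rides inside the request's metadata dict, and any code that deep copies the request must pop that entry before calling copy.deepcopy, because an active span object cannot be pickled, and then put it back on the original afterwards. The snippet below is a minimal, self-contained sketch of that pattern for illustration only; it is not the litellm implementation, the function name strip_and_deep_copy is hypothetical, and the threading.Lock merely stands in for the non-picklable span object.

import copy
import threading


def strip_and_deep_copy(data):
    """Sketch: deep copy a request dict whose metadata may hold a non-picklable object."""
    non_picklable = None
    if isinstance(data, dict) and "litellm_parent_otel_span" in data.get("metadata", {}):
        # Step 1: remove the non-picklable entry so copy.deepcopy does not raise
        non_picklable = data["metadata"].pop("litellm_parent_otel_span")
    copied = copy.deepcopy(data)
    if non_picklable is not None:
        # Step 2: restore the original object on the source dict after copying
        data["metadata"]["litellm_parent_otel_span"] = non_picklable
    return copied


if __name__ == "__main__":
    request = {
        "model": "gpt-3.5-turbo",
        "metadata": {"litellm_parent_otel_span": threading.Lock()},  # stand-in for an OTEL span
    }
    copied = strip_and_deep_copy(request)
    assert "litellm_parent_otel_span" in request["metadata"]      # original still carries the span
    assert "litellm_parent_otel_span" not in copied["metadata"]   # the copy stays picklable
    print("deep copied without the non-picklable entry:", copied)

Keeping the popped value in a local variable and restoring it afterwards is the same simplification patch 51 makes to the proxy's safe_deep_copy.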