diff --git a/litellm/llms/bedrock_httpx.py b/litellm/llms/bedrock_httpx.py index 6329f165e..7aba78d7c 100644 --- a/litellm/llms/bedrock_httpx.py +++ b/litellm/llms/bedrock_httpx.py @@ -1125,7 +1125,7 @@ class AmazonConverseConfig: maxTokens: Optional[int] = None, stopSequences: Optional[List[str]] = None, temperature: Optional[int] = None, - top_p: Optional[int] = None, + topP: Optional[int] = None, ) -> None: locals_ = locals() for key, value in locals_.items(): @@ -1481,6 +1481,93 @@ class BedrockConverseLLM(BaseLLM): return session.get_credentials() + async def async_streaming( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> CustomStreamWrapper: + streaming_response = CustomStreamWrapper( + completion_stream=None, + make_call=partial( + make_call, + client=client, + api_base=api_base, + headers=headers, + data=data, + model=model, + messages=messages, + logging_obj=logging_obj, + ), + model=model, + custom_llm_provider="bedrock", + logging_obj=logging_obj, + ) + return streaming_response + + async def async_completion( + self, + model: str, + messages: list, + api_base: str, + model_response: ModelResponse, + print_verbose: Callable, + data: str, + timeout: Optional[Union[float, httpx.Timeout]], + encoding, + logging_obj, + stream, + optional_params: dict, + litellm_params=None, + logger_fn=None, + headers={}, + client: Optional[AsyncHTTPHandler] = None, + ) -> Union[ModelResponse, CustomStreamWrapper]: + if client is None: + _params = {} + if timeout is not None: + if isinstance(timeout, float) or isinstance(timeout, int): + timeout = httpx.Timeout(timeout) + _params["timeout"] = timeout + client = AsyncHTTPHandler(**_params) # type: ignore + else: + client = client # type: ignore + + try: + response = await client.post(api_base, headers=headers, data=data) # type: ignore + response.raise_for_status() + except httpx.HTTPStatusError as err: + error_code = err.response.status_code + raise BedrockError(status_code=error_code, message=err.response.text) + except httpx.TimeoutException as e: + raise BedrockError(status_code=408, message="Timeout error occurred.") + + return self.process_response( + model=model, + response=response, + model_response=model_response, + stream=stream if isinstance(stream, bool) else False, + logging_obj=logging_obj, + api_key="", + data=data, + messages=messages, + print_verbose=print_verbose, + optional_params=optional_params, + encoding=encoding, + ) + def completion( self, model: str, @@ -1504,7 +1591,7 @@ class BedrockConverseLLM(BaseLLM): from botocore.auth import SigV4Auth from botocore.awsrequest import AWSRequest from botocore.credentials import Credentials - except ImportError as e: + except ImportError: raise ImportError("Missing boto3 to call bedrock. 
Run 'pip install boto3'.") ## SETUP ## @@ -1658,6 +1745,46 @@ class BedrockConverseLLM(BaseLLM): ) ### ROUTING (ASYNC, STREAMING, SYNC) + if acompletion: + if isinstance(client, HTTPHandler): + client = None + if stream is True and provider != "ai21": + return self.async_streaming( + model=model, + messages=messages, + data=data, + api_base=prepped.url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=True, + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + ### ASYNC COMPLETION + return self.async_completion( + model=model, + messages=messages, + data=data, + api_base=prepped.url, + model_response=model_response, + print_verbose=print_verbose, + encoding=encoding, + logging_obj=logging_obj, + optional_params=optional_params, + stream=stream, # type: ignore + litellm_params=litellm_params, + logger_fn=logger_fn, + headers=prepped.headers, + timeout=timeout, + client=client, + ) # type: ignore + if (stream is not None and stream is True) and provider != "ai21": streaming_response = CustomStreamWrapper( @@ -1666,7 +1793,7 @@ class BedrockConverseLLM(BaseLLM): make_sync_call, client=None, api_base=prepped.url, - headers=prepped.headers, + headers=prepped.headers, # type: ignore data=data, model=model, messages=messages, @@ -1702,7 +1829,7 @@ class BedrockConverseLLM(BaseLLM): except httpx.HTTPStatusError as err: error_code = err.response.status_code raise BedrockError(status_code=error_code, message=response.text) - except httpx.TimeoutException as e: + except httpx.TimeoutException: raise BedrockError(status_code=408, message="Timeout error occurred.") return self.process_response( @@ -1737,7 +1864,7 @@ class AWSEventStreamDecoder: self.model = model self.parser = EventStreamJSONParser() - def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: + def converse_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: text = "" tool_str = "" is_finished = False @@ -1762,7 +1889,7 @@ class AWSEventStreamDecoder: ) return response - def _old_chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: + def _chunk_parser(self, chunk_data: dict) -> GenericStreamingChunk: text = "" is_finished = False finish_reason = "" @@ -1774,19 +1901,8 @@ class AWSEventStreamDecoder: is_finished = True finish_reason = "stop" ######## bedrock.anthropic mappings ############### - elif "completion" in chunk_data: # not claude-3 - text = chunk_data["completion"] # bedrock.anthropic - stop_reason = chunk_data.get("stop_reason", None) - if stop_reason != None: - is_finished = True - finish_reason = stop_reason elif "delta" in chunk_data: - if chunk_data["delta"].get("text", None) is not None: - text = chunk_data["delta"]["text"] - stop_reason = chunk_data["delta"].get("stop_reason", None) - if stop_reason != None: - is_finished = True - finish_reason = stop_reason + return self.converse_chunk_parser(chunk_data=chunk_data) ######## bedrock.mistral mappings ############### elif "outputs" in chunk_data: if ( @@ -1851,11 +1967,17 @@ class AWSEventStreamDecoder: def _parse_message_from_event(self, event) -> Optional[str]: response_dict = event.to_response_dict() + parsed_response = self.parser.parse(response_dict, get_response_stream_shape()) if response_dict["status_code"] != 200: raise ValueError(f"Bad response code, expected 200: {response_dict}") + if "chunk" in parsed_response: + chunk = 
parsed_response.get("chunk") + if not chunk: + return None + return chunk.get("bytes").decode() # type: ignore[no-any-return] + else: + chunk = response_dict.get("body") + if not chunk: + return None - chunk = response_dict.get("body") - if not chunk: - return None - - return chunk.decode() # type: ignore[no-any-return] + return chunk.decode() # type: ignore[no-any-return] diff --git a/litellm/llms/custom_httpx/http_handler.py b/litellm/llms/custom_httpx/http_handler.py index d8dd4f01e..5ec9c79bb 100644 --- a/litellm/llms/custom_httpx/http_handler.py +++ b/litellm/llms/custom_httpx/http_handler.py @@ -168,7 +168,6 @@ class HTTPHandler: return response def __del__(self) -> None: - traceback.print_stack() try: self.close() except Exception: diff --git a/litellm/main.py b/litellm/main.py index c95b419ba..15334d041 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -121,7 +121,8 @@ azure_text_completions = AzureTextCompletion() huggingface = Huggingface() predibase_chat_completions = PredibaseChatCompletion() triton_chat_completions = TritonChatCompletion() -bedrock_chat_completion = BedrockConverseLLM() +bedrock_chat_completion = BedrockLLM() +bedrock_converse_chat_completion = BedrockConverseLLM() vertex_chat_completion = VertexLLM() ####### COMPLETION ENDPOINTS ################ @@ -2097,22 +2098,40 @@ def completion( logging_obj=logging, ) else: - response = bedrock_chat_completion.completion( - model=model, - messages=messages, - custom_prompt_dict=custom_prompt_dict, - model_response=model_response, - print_verbose=print_verbose, - optional_params=optional_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - logging_obj=logging, - extra_headers=extra_headers, - timeout=timeout, - acompletion=acompletion, - client=client, - ) + if model.startswith("anthropic"): + response = bedrock_converse_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, + ) + else: + response = bedrock_chat_completion.completion( + model=model, + messages=messages, + custom_prompt_dict=custom_prompt_dict, + model_response=model_response, + print_verbose=print_verbose, + optional_params=optional_params, + litellm_params=litellm_params, + logger_fn=logger_fn, + encoding=encoding, + logging_obj=logging, + extra_headers=extra_headers, + timeout=timeout, + acompletion=acompletion, + client=client, + ) if optional_params.get("stream", False): ## LOGGING logging.post_call( diff --git a/litellm/tests/log.txt b/litellm/tests/log.txt deleted file mode 100644 index ea07ca7e1..000000000 --- a/litellm/tests/log.txt +++ /dev/null @@ -1,4274 +0,0 @@ -============================= test session starts ============================== -platform darwin -- Python 3.11.4, pytest-8.2.0, pluggy-1.5.0 -rootdir: /Users/krrishdholakia/Documents/litellm -configfile: pyproject.toml -plugins: asyncio-0.23.6, mock-3.14.0, anyio-4.2.0 -asyncio: mode=Mode.STRICT -collected 1 item - -test_amazing_vertex_completion.py F [100%] - -=================================== FAILURES =================================== -____________________________ test_gemini_pro_vision ____________________________ - -model = 'gemini-1.5-flash-preview-0514' -messages = [{'content': [{'text': 'Whats in 
if "claude-3" in model: - model_response = vertex_ai_anthropic.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - else: -> model_response = vertex_ai.completion( - model=model, - messages=messages, - model_response=model_response, - print_verbose=print_verbose, - optional_params=new_params, - litellm_params=litellm_params, - logger_fn=logger_fn, - encoding=encoding, - vertex_location=vertex_ai_location, - vertex_project=vertex_ai_project, - vertex_credentials=vertex_credentials, - logging_obj=logging, - acompletion=acompletion, - ) - -../main.py:1824: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -messages = [{'content': [{'text': 'Whats in this image?', 'type': 'text'}, {'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}, 'type': 'image_url'}], 'role': 'user'}] -model_response = ModelResponse(id='chatcmpl-722df0e7-4e2d-44e6-9e2c-49823faa0189', choices=[Choices(finish_reason='stop', index=0, mess... role='assistant'))], created=1716145725, model=None, object='chat.completion', system_fingerprint=None, usage=Usage()) -print_verbose = -encoding = -logging_obj = -vertex_project = None, vertex_location = None, vertex_credentials = None -optional_params = {} -litellm_params = {'acompletion': False, 'api_base': '', 'api_key': None, 'completion_call_id': None, ...} -logger_fn = None, acompletion = False - - def completion( - model: str, - messages: list, - model_response: ModelResponse, - print_verbose: Callable, - encoding, - logging_obj, - vertex_project=None, - vertex_location=None, - vertex_credentials=None, - optional_params=None, - litellm_params=None, - logger_fn=None, - acompletion: bool = False, - ): - try: - import vertexai - except: - raise VertexAIError( - status_code=400, - message="vertexai import failed please run `pip install google-cloud-aiplatform`", - ) - - if not ( - hasattr(vertexai, "preview") or hasattr(vertexai.preview, "language_models") - ): - raise VertexAIError( - status_code=400, - message="""Upgrade vertex ai. 
Run `pip install "google-cloud-aiplatform>=1.38"`""", - ) - try: - from vertexai.preview.language_models import ( - ChatModel, - CodeChatModel, - InputOutputTextPair, - ) - from vertexai.language_models import TextGenerationModel, CodeGenerationModel - from vertexai.preview.generative_models import ( - GenerativeModel, - Part, - GenerationConfig, - ) - from google.cloud import aiplatform # type: ignore - from google.protobuf import json_format # type: ignore - from google.protobuf.struct_pb2 import Value # type: ignore - from google.cloud.aiplatform_v1beta1.types import content as gapic_content_types # type: ignore - import google.auth # type: ignore - import proto # type: ignore - - ## Load credentials with the correct quota project ref: https://github.com/googleapis/python-aiplatform/issues/2557#issuecomment-1709284744 - print_verbose( - f"VERTEX AI: vertex_project={vertex_project}; vertex_location={vertex_location}" - ) - if vertex_credentials is not None and isinstance(vertex_credentials, str): - import google.oauth2.service_account - - json_obj = json.loads(vertex_credentials) - - creds = google.oauth2.service_account.Credentials.from_service_account_info( - json_obj, - scopes=["https://www.googleapis.com/auth/cloud-platform"], - ) - else: - creds, _ = google.auth.default(quota_project_id=vertex_project) - print_verbose( - f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" - ) - vertexai.init( - project=vertex_project, location=vertex_location, credentials=creds - ) - - ## Load Config - config = litellm.VertexAIConfig.get_config() - for k, v in config.items(): - if k not in optional_params: - optional_params[k] = v - - ## Process safety settings into format expected by vertex AI - safety_settings = None - if "safety_settings" in optional_params: - safety_settings = optional_params.pop("safety_settings") - if not isinstance(safety_settings, list): - raise ValueError("safety_settings must be a list") - if len(safety_settings) > 0 and not isinstance(safety_settings[0], dict): - raise ValueError("safety_settings must be a list of dicts") - safety_settings = [ - gapic_content_types.SafetySetting(x) for x in safety_settings - ] - - # vertexai does not use an API key, it looks for credentials.json in the environment - - prompt = " ".join( - [ - message["content"] - for message in messages - if isinstance(message["content"], str) - ] - ) - - mode = "" - - request_str = "" - response_obj = None - async_client = None - instances = None - client_options = { - "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com" - } - if ( - model in litellm.vertex_language_models - or model in litellm.vertex_vision_models - ): - llm_model = GenerativeModel(model) - mode = "vision" - request_str += f"llm_model = GenerativeModel({model})\n" - elif model in litellm.vertex_chat_models: - llm_model = ChatModel.from_pretrained(model) - mode = "chat" - request_str += f"llm_model = ChatModel.from_pretrained({model})\n" - elif model in litellm.vertex_text_models: - llm_model = TextGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = TextGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_text_models: - llm_model = CodeGenerationModel.from_pretrained(model) - mode = "text" - request_str += f"llm_model = CodeGenerationModel.from_pretrained({model})\n" - elif model in litellm.vertex_code_chat_models: # vertex_code_llm_models - llm_model = CodeChatModel.from_pretrained(model) - mode = "chat" - 
request_str += f"llm_model = CodeChatModel.from_pretrained({model})\n" - elif model == "private": - mode = "private" - model = optional_params.pop("model_id", None) - # private endpoint requires a dict instead of JSON - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - llm_model = aiplatform.PrivateEndpoint( - endpoint_name=model, - project=vertex_project, - location=vertex_location, - ) - request_str += f"llm_model = aiplatform.PrivateEndpoint(endpoint_name={model}, project={vertex_project}, location={vertex_location})\n" - else: # assume vertex model garden on public endpoint - mode = "custom" - - instances = [optional_params.copy()] - instances[0]["prompt"] = prompt - instances = [ - json_format.ParseDict(instance_dict, Value()) - for instance_dict in instances - ] - # Will determine the API used based on async parameter - llm_model = None - - # NOTE: async prediction and streaming under "private" mode isn't supported by aiplatform right now - if acompletion == True: - data = { - "llm_model": llm_model, - "mode": mode, - "prompt": prompt, - "logging_obj": logging_obj, - "request_str": request_str, - "model": model, - "model_response": model_response, - "encoding": encoding, - "messages": messages, - "print_verbose": print_verbose, - "client_options": client_options, - "instances": instances, - "vertex_location": vertex_location, - "vertex_project": vertex_project, - "safety_settings": safety_settings, - **optional_params, - } - if optional_params.get("stream", False) is True: - # async streaming - return async_streaming(**data) - - return async_completion(**data) - - if mode == "vision": - print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") - print_verbose(f"\nProcessing input messages = {messages}") - tools = optional_params.pop("tools", None) - content = _gemini_convert_messages_text(messages=messages) - stream = optional_params.pop("stream", False) - if stream == True: - request_str += f"response = llm_model.generate_content({content}, generation_config=GenerationConfig(**{optional_params}), safety_settings={safety_settings}, stream={stream})\n" - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - model_response = llm_model.generate_content( - contents={"content": content}, - generation_config=optional_params, - safety_settings=safety_settings, - stream=True, - tools=tools, - ) - - return model_response - - request_str += f"response = llm_model.generate_content({content})\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - - ## LLM Call - response = llm_model.generate_content( - contents=content, - generation_config=optional_params, - safety_settings=safety_settings, - tools=tools, - ) - - if tools is not None and bool( - getattr(response.candidates[0].content.parts[0], "function_call", None) - ): - function_call = response.candidates[0].content.parts[0].function_call - args_dict = {} - - # Check if it's a RepeatedComposite instance - for key, val in function_call.args.items(): - if isinstance( - val, proto.marshal.collections.repeated.RepeatedComposite - ): - # If so, convert to list - args_dict[key] = [v for v in val] - else: - args_dict[key] = val - - try: - args_str = json.dumps(args_dict) - except Exception as e: - raise VertexAIError(status_code=422, message=str(e)) - message = litellm.Message( - content=None, - 
tool_calls=[ - { - "id": f"call_{str(uuid.uuid4())}", - "function": { - "arguments": args_str, - "name": function_call.name, - }, - "type": "function", - } - ], - ) - completion_response = message - else: - completion_response = response.text - response_obj = response._raw_response - optional_params["tools"] = tools - elif mode == "chat": - chat = llm_model.start_chat() - request_str += f"chat = llm_model.start_chat()\n" - - if "stream" in optional_params and optional_params["stream"] == True: - # NOTE: VertexAI does not accept stream=True as a param and raises an error, - # we handle this by removing 'stream' from optional params and sending the request - # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format - optional_params.pop( - "stream", None - ) # vertex ai raises an error when passing stream in optional params - request_str += ( - f"chat.send_message_streaming({prompt}, **{optional_params})\n" - ) - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - model_response = chat.send_message_streaming(prompt, **optional_params) - - return model_response - - request_str += f"chat.send_message({prompt}, **{optional_params}).text\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - completion_response = chat.send_message(prompt, **optional_params).text - elif mode == "text": - if "stream" in optional_params and optional_params["stream"] == True: - optional_params.pop( - "stream", None - ) # See note above on handling streaming for vertex ai - request_str += ( - f"llm_model.predict_streaming({prompt}, **{optional_params})\n" - ) - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - model_response = llm_model.predict_streaming(prompt, **optional_params) - - return model_response - - request_str += f"llm_model.predict({prompt}, **{optional_params}).text\n" - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - completion_response = llm_model.predict(prompt, **optional_params).text - elif mode == "custom": - """ - Vertex AI Model Garden - """ - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - llm_model = aiplatform.gapic.PredictionServiceClient( - client_options=client_options - ) - request_str += f"llm_model = aiplatform.gapic.PredictionServiceClient(client_options={client_options})\n" - endpoint_path = llm_model.endpoint_path( - project=vertex_project, location=vertex_location, endpoint=model - ) - request_str += ( - f"llm_model.predict(endpoint={endpoint_path}, instances={instances})\n" - ) - response = llm_model.predict( - endpoint=endpoint_path, instances=instances - ).predictions - - completion_response = response[0] - if ( - isinstance(completion_response, str) - and "\nOutput:\n" in completion_response - ): - completion_response = completion_response.split("\nOutput:\n", 1)[1] - if "stream" in optional_params and optional_params["stream"] == True: - response = TextStreamer(completion_response) 
- return response - elif mode == "private": - """ - Vertex AI Model Garden deployed on private endpoint - """ - ## LOGGING - logging_obj.pre_call( - input=prompt, - api_key=None, - additional_args={ - "complete_input_dict": optional_params, - "request_str": request_str, - }, - ) - request_str += f"llm_model.predict(instances={instances})\n" - response = llm_model.predict(instances=instances).predictions - - completion_response = response[0] - if ( - isinstance(completion_response, str) - and "\nOutput:\n" in completion_response - ): - completion_response = completion_response.split("\nOutput:\n", 1)[1] - if "stream" in optional_params and optional_params["stream"] == True: - response = TextStreamer(completion_response) - return response - - ## LOGGING - logging_obj.post_call( - input=prompt, api_key=None, original_response=completion_response - ) - - ## RESPONSE OBJECT - if isinstance(completion_response, litellm.Message): - model_response["choices"][0]["message"] = completion_response - elif len(str(completion_response)) > 0: - model_response["choices"][0]["message"]["content"] = str( - completion_response - ) - model_response["created"] = int(time.time()) - model_response["model"] = model - ## CALCULATING USAGE - if model in litellm.vertex_language_models and response_obj is not None: - model_response["choices"][0].finish_reason = map_finish_reason( - response_obj.candidates[0].finish_reason.name - ) - usage = Usage( - prompt_tokens=response_obj.usage_metadata.prompt_token_count, - completion_tokens=response_obj.usage_metadata.candidates_token_count, - total_tokens=response_obj.usage_metadata.total_token_count, - ) - else: - # init prompt tokens - # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter - prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 - if response_obj is not None: - if hasattr(response_obj, "usage_metadata") and hasattr( - response_obj.usage_metadata, "prompt_token_count" - ): - prompt_tokens = response_obj.usage_metadata.prompt_token_count - completion_tokens = ( - response_obj.usage_metadata.candidates_token_count - ) - else: - prompt_tokens = len(encoding.encode(prompt)) - completion_tokens = len( - encoding.encode( - model_response["choices"][0]["message"].get("content", "") - ) - ) - - usage = Usage( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=prompt_tokens + completion_tokens, - ) - setattr(model_response, "usage", usage) - return model_response - except Exception as e: - if isinstance(e, VertexAIError): - raise e -> raise VertexAIError(status_code=500, message=str(e)) -E litellm.llms.vertex_ai.VertexAIError: Parameter to MergeFrom() must be instance of same class: expected got . - -../llms/vertex_ai.py:971: VertexAIError - -During handling of the above exception, another exception occurred: - -args = () -kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': -call_type = 'completion', model = 'vertex_ai/gemini-1.5-flash-preview-0514' -k = 'litellm_logging_obj' - - @wraps(original_function) - def wrapper(*args, **kwargs): - # DO NOT MOVE THIS. It always needs to run first - # Check if this is an async function. 
If so only execute the async function - if ( - kwargs.get("acompletion", False) == True - or kwargs.get("aembedding", False) == True - or kwargs.get("aimg_generation", False) == True - or kwargs.get("amoderation", False) == True - or kwargs.get("atext_completion", False) == True - or kwargs.get("atranscription", False) == True - ): - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # MODEL CALL - result = original_function(*args, **kwargs) - if "stream" in kwargs and kwargs["stream"] == True: - if ( - "complete_response" in kwargs - and kwargs["complete_response"] == True - ): - chunks = [] - for idx, chunk in enumerate(result): - chunks.append(chunk) - return litellm.stream_chunk_builder( - chunks, messages=kwargs.get("messages", None) - ) - else: - return result - - return result - - # Prints Exactly what was passed to litellm function - don't execute any logic here - it should just print - print_args_passed_to_litellm(original_function, args, kwargs) - start_time = datetime.datetime.now() - result = None - logging_obj = kwargs.get("litellm_logging_obj", None) - - # only set litellm_call_id if its not in kwargs - call_type = original_function.__name__ - if "litellm_call_id" not in kwargs: - kwargs["litellm_call_id"] = str(uuid.uuid4()) - try: - model = args[0] if len(args) > 0 else kwargs["model"] - except: - model = None - if ( - call_type != CallTypes.image_generation.value - and call_type != CallTypes.text_completion.value - ): - raise ValueError("model param not passed in.") - - try: - if logging_obj is None: - logging_obj, kwargs = function_setup( - original_function.__name__, rules_obj, start_time, *args, **kwargs - ) - kwargs["litellm_logging_obj"] = logging_obj - - # CHECK FOR 'os.environ/' in kwargs - for k, v in kwargs.items(): - if v is not None and isinstance(v, str) and v.startswith("os.environ/"): - kwargs[k] = litellm.get_secret(v) - # [OPTIONAL] CHECK BUDGET - if litellm.max_budget: - if litellm._current_cost > litellm.max_budget: - raise BudgetExceededError( - current_cost=litellm._current_cost, - max_budget=litellm.max_budget, - ) - - # [OPTIONAL] CHECK MAX RETRIES / REQUEST - if litellm.num_retries_per_request is not None: - # check if previous_models passed in as ['litellm_params']['metadata]['previous_models'] - previous_models = kwargs.get("metadata", {}).get( - "previous_models", None - ) - if previous_models is not None: - if litellm.num_retries_per_request <= len(previous_models): - raise Exception(f"Max retries per request hit!") - - # [OPTIONAL] CHECK CACHE - print_verbose( - f"SYNC kwargs[caching]: {kwargs.get('caching', False)}; litellm.cache: {litellm.cache}; kwargs.get('cache')['no-cache']: {kwargs.get('cache', {}).get('no-cache', False)}" - ) - # if caching is false or cache["no-cache"]==True, don't run this - if ( - ( - ( - ( - kwargs.get("caching", None) is None - and litellm.cache is not None - ) - or kwargs.get("caching", False) == True - ) - and kwargs.get("cache", {}).get("no-cache", False) != True - ) - and kwargs.get("aembedding", False) != True - and kwargs.get("atext_completion", False) != True - and kwargs.get("acompletion", False) != True - and kwargs.get("aimg_generation", False) != 
True - and kwargs.get("atranscription", False) != True - ): # allow users to control returning cached responses from the completion function - # checking cache - print_verbose(f"INSIDE CHECKING CACHE") - if ( - litellm.cache is not None - and str(original_function.__name__) - in litellm.cache.supported_call_types - ): - print_verbose(f"Checking Cache") - preset_cache_key = litellm.cache.get_cache_key(*args, **kwargs) - kwargs["preset_cache_key"] = ( - preset_cache_key # for streaming calls, we need to pass the preset_cache_key - ) - cached_result = litellm.cache.get_cache(*args, **kwargs) - if cached_result != None: - if "detail" in cached_result: - # implies an error occurred - pass - else: - call_type = original_function.__name__ - print_verbose( - f"Cache Response Object routing: call_type - {call_type}; cached_result instace: {type(cached_result)}" - ) - if call_type == CallTypes.completion.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - model_response_object=ModelResponse(), - stream=kwargs.get("stream", False), - ) - - if kwargs.get("stream", False) == True: - cached_result = CustomStreamWrapper( - completion_stream=cached_result, - model=model, - custom_llm_provider="cached_response", - logging_obj=logging_obj, - ) - elif call_type == CallTypes.embedding.value and isinstance( - cached_result, dict - ): - cached_result = convert_to_model_response_object( - response_object=cached_result, - response_type="embedding", - ) - - # LOG SUCCESS - cache_hit = True - end_time = datetime.datetime.now() - ( - model, - custom_llm_provider, - dynamic_api_key, - api_base, - ) = litellm.get_llm_provider( - model=model, - custom_llm_provider=kwargs.get( - "custom_llm_provider", None - ), - api_base=kwargs.get("api_base", None), - api_key=kwargs.get("api_key", None), - ) - print_verbose( - f"Async Wrapper: Completed Call, calling async_success_handler: {logging_obj.async_success_handler}" - ) - logging_obj.update_environment_variables( - model=model, - user=kwargs.get("user", None), - optional_params={}, - litellm_params={ - "logger_fn": kwargs.get("logger_fn", None), - "acompletion": False, - "metadata": kwargs.get("metadata", {}), - "model_info": kwargs.get("model_info", {}), - "proxy_server_request": kwargs.get( - "proxy_server_request", None - ), - "preset_cache_key": kwargs.get( - "preset_cache_key", None - ), - "stream_response": kwargs.get( - "stream_response", {} - ), - }, - input=kwargs.get("messages", ""), - api_key=kwargs.get("api_key", None), - original_response=str(cached_result), - additional_args=None, - stream=kwargs.get("stream", False), - ) - threading.Thread( - target=logging_obj.success_handler, - args=(cached_result, start_time, end_time, cache_hit), - ).start() - return cached_result - - # CHECK MAX TOKENS - if ( - kwargs.get("max_tokens", None) is not None - and model is not None - and litellm.modify_params - == True # user is okay with params being modified - and ( - call_type == CallTypes.acompletion.value - or call_type == CallTypes.completion.value - ) - ): - try: - base_model = model - if kwargs.get("hf_model_name", None) is not None: - base_model = f"huggingface/{kwargs.get('hf_model_name')}" - max_output_tokens = ( - get_max_tokens(model=base_model) or 4096 - ) # assume min context window is 4k tokens - user_max_tokens = kwargs.get("max_tokens") - ## Scenario 1: User limit + prompt > model limit - messages = None - if len(args) > 1: - messages = args[1] - elif kwargs.get("messages", 
None): - messages = kwargs["messages"] - input_tokens = token_counter(model=base_model, messages=messages) - input_tokens += max( - 0.1 * input_tokens, 10 - ) # give at least a 10 token buffer. token counting can be imprecise. - if input_tokens > max_output_tokens: - pass # allow call to fail normally - elif user_max_tokens + input_tokens > max_output_tokens: - user_max_tokens = max_output_tokens - input_tokens - print_verbose(f"user_max_tokens: {user_max_tokens}") - kwargs["max_tokens"] = int( - round(user_max_tokens) - ) # make sure max tokens is always an int - except Exception as e: - print_verbose(f"Error while checking max token limit: {str(e)}") - # MODEL CALL -> result = original_function(*args, **kwargs) - -../utils.py:3211: -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ -../main.py:2368: in completion - raise exception_type( -../utils.py:9709: in exception_type - raise e -_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ - -model = 'gemini-1.5-flash-preview-0514' -original_exception = VertexAIError("Parameter to MergeFrom() must be instance of same class: expected got .") -custom_llm_provider = 'vertex_ai' -completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...} -extra_kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': } - - def exception_type( - model, - original_exception, - custom_llm_provider, - completion_kwargs={}, - extra_kwargs={}, - ): - global user_logger_fn, liteDebuggerClient - exception_mapping_worked = False - if litellm.suppress_debug_info is False: - print() # noqa - print( # noqa - "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa - ) # noqa - print( # noqa - "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'." 
# noqa - ) # noqa - print() # noqa - try: - if model: - error_str = str(original_exception) - if isinstance(original_exception, BaseException): - exception_type = type(original_exception).__name__ - else: - exception_type = "" - - ################################################################################ - # Common Extra information needed for all providers - # We pass num retries, api_base, vertex_deployment etc to the exception here - ################################################################################ - extra_information = "" - try: - _api_base = litellm.get_api_base( - model=model, optional_params=extra_kwargs - ) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" - - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) - except: - # DO NOT LET this Block raising the original exception - pass - - ################################################################################ - # End of Common Extra information Needed for all providers - ################################################################################ - - ################################################################################ - #################### Start of Provider Exception mapping #################### - ################################################################################ - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: - exception_mapping_worked = True - raise Timeout( - message=f"APITimeoutError - Request timed out. 
\nerror_str: {error_str}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "custom_openai" - or custom_llm_provider in litellm.openai_compatible_providers - ): - # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) - if message is not None and isinstance(message, str): - message = message.replace("OPENAI", custom_llm_provider.upper()) - message = message.replace("openai", custom_llm_provider) - message = message.replace("OpenAI", custom_llm_provider) - if custom_llm_provider == "openai": - exception_provider = "OpenAI" + "Exception" - else: - exception_provider = ( - custom_llm_provider[0].upper() - + custom_llm_provider[1:] - + "Exception" - ) - - if "This model's maximum context length is" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "model_not_found" in error_str - ): - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "content_policy_violation" in error_str - ): - exception_mapping_worked = True - raise ContentPolicyViolationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "Incorrect API key provided" not in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Request too large" in error_str: - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Mistral API raised a streaming error" in error_str: - exception_mapping_worked = True - _request = httpx.Request( - method="POST", url="https://api.openai.com/v1" - ) - raise APIError( - status_code=500, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=_request, - litellm_debug_info=extra_information, - ) - elif hasattr(original_exception, "status_code"): - exception_mapping_worked = True - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - 
{message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - else: - exception_mapping_worked = True - raise APIError( - status_code=original_exception.status_code, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=original_exception.request, - litellm_debug_info=extra_information, - ) - else: - # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors - raise APIConnectionError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - litellm_debug_info=extra_information, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif custom_llm_provider == "anthropic": # one of the anthropics - if hasattr(original_exception, "message"): - if ( - "prompt is too long" in original_exception.message - or "prompt: length" in original_exception.message - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if "Invalid API Key" in original_exception.message: - exception_mapping_worked = True - raise AuthenticationError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if hasattr(original_exception, "status_code"): - print_verbose(f"status_code: {original_exception.status_code}") - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or 
original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", - llm_provider="anthropic", - model=model, - request=original_exception.request, - ) - elif custom_llm_provider == "replicate": - if "Incorrect authentication token" in error_str: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif "input is too long" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif exception_type == "ModelError": - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif "Request was throttled" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 422 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"ReplicateException - {str(original_exception)}", - 
llm_provider="replicate", - model=model, - request=httpx.Request( - method="POST", - url="https://api.replicate.com/v1/deployments", - ), - ) - elif custom_llm_provider == "watsonx": - if "token_quota_reached" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"WatsonxException: Rate Limit Errror - {error_str}", - llm_provider="watsonx", - model=model, - response=original_exception.response, - ) - elif custom_llm_provider == "predibase": - if "authorization denied for" in error_str: - exception_mapping_worked = True - - # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception - if ( - error_str is not None - and isinstance(error_str, str) - and "bearer" in error_str.lower() - ): - # only keep the first 10 chars after the occurnence of "bearer" - _bearer_token_start_index = error_str.lower().find("bearer") - error_str = error_str[: _bearer_token_start_index + 14] - error_str += "XXXXXXX" + '"' - - raise AuthenticationError( - message=f"PredibaseException: Authentication Error - {error_str}", - llm_provider="predibase", - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "bedrock": - if ( - "too many tokens" in error_str - or "expected maxLength:" in error_str - or "Input is too long" in error_str - or "prompt: length: 1.." in error_str - or "Too many input tokens" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"BedrockException: Context Window Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "Malformed input request" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Unable to locate credentials" in error_str - or "The security token included in the request is invalid" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "AccessDeniedException" in error_str: - exception_mapping_worked = True - raise PermissionDeniedError( - message=f"BedrockException PermissionDeniedError - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "throttlingException" in error_str - or "ThrottlingException" in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Connect timeout on endpoint URL" in error_str - or "timed out" in error_str - ): - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException: Timeout Error - {error_str}", - model=model, - llm_provider="bedrock", - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=httpx.Response( - status_code=500, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ), - ) - elif original_exception.status_code == 401: - exception_mapping_worked = 
True - raise AuthenticationError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "sagemaker": - if "Unable to locate credentials" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "Input validation error: `best_of` must be > 0 and <= 2" - in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "`inputs` tokens + `max_new_tokens` must be <=" in error_str - or "instance type with more CPU capacity or memory" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif custom_llm_provider == "vertex_ai": - if ( - "Vertex AI API has not been used in project" in error_str - or "Unable to find your project" in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "None Unknown Error." 
in error_str - or "Content has no parts." in error_str - ): - exception_mapping_worked = True - raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - request=original_exception.request, - litellm_debug_info=extra_information, - ) - elif "403" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "The response was blocked." in error_str: - exception_mapping_worked = True - raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - elif ( - "429 Quota exceeded" in error_str - or "IndexError: list index out of range" in error_str - or "429 Unable to submit request because the service is temporarily out of capacity." - in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - if hasattr(original_exception, "status_code"): - if original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=original_exception.response, - ) - if original_exception.status_code == 500: - exception_mapping_worked = True -> raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - request=original_exception.request, -E litellm.exceptions.APIError: VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . 
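The snippet below is an illustrative sketch of how a caller can guard against the mapped exception shown above. It reuses only names visible in this traceback (litellm.completion, litellm.exceptions.APIError, and the gemini-1.5-flash vision payload from test_gemini_pro_vision); the wrapper function name and the log-and-reraise handling are assumptions added for illustration, not part of the test or of litellm itself.

# Hypothetical caller-side guard (illustration only, not the fix):
# exception_type() re-raises the VertexAIError above as litellm.exceptions.APIError,
# so a caller sees that type rather than the provider-specific error.
import litellm
from litellm.exceptions import APIError

def call_gemini_flash_vision():
    try:
        return litellm.completion(
            model="vertex_ai/gemini-1.5-flash-preview-0514",
            messages=[
                {
                    "role": "user",
                    "content": [
                        {"type": "text", "text": "Whats in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
                            },
                        },
                    ],
                }
            ],
        )
    except APIError as err:
        # The proto MergeFrom() failure surfaces here as a 500-level APIError
        # with llm_provider="vertex_ai"; log it and let the caller decide.
        print(f"vertex_ai call failed: {err}")
        raise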
-
-../utils.py:8922: APIError
-
-During handling of the above exception, another exception occurred:
-
-    def test_gemini_pro_vision():
-        try:
-            load_vertex_ai_credentials()
-            litellm.set_verbose = True
-            litellm.num_retries = 3
->           resp = litellm.completion(
-                model="vertex_ai/gemini-1.5-flash-preview-0514",
-                messages=[
-                    {
-                        "role": "user",
-                        "content": [
-                            {"type": "text", "text": "Whats in this image?"},
-                            {
-                                "type": "image_url",
-                                "image_url": {
-                                    "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg"
-                                },
-                            },
-                        ],
-                    }
-                ],
-            )
-
-test_amazing_vertex_completion.py:510:
-_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
-../utils.py:3289: in wrapper
-    return litellm.completion_with_retries(*args, **kwargs)
-../main.py:2401: in completion_with_retries
-    return retryer(original_function, *args, **kwargs)
-../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:379: in __call__
-    do = self.iter(retry_state=retry_state)
-../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:325: in iter
-    raise retry_exc.reraise()
-../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:158: in reraise
-    raise self.last_attempt.result()
-/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/concurrent/futures/_base.py:449: in result
-    return self.__get_result()
-/Library/Frameworks/Python.framework/Versions/3.11/lib/python3.11/concurrent/futures/_base.py:401: in __get_result
-    raise self._exception
-../proxy/myenv/lib/python3.11/site-packages/tenacity/__init__.py:382: in __call__
-    result = fn(*args, **kwargs)
-../utils.py:3317: in wrapper
-    raise e
-../utils.py:3211: in wrapper
-    result = original_function(*args, **kwargs)
-../main.py:2368: in completion
-    raise exception_type(
-../utils.py:9709: in exception_type
-    raise e
-_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
-
-model = 'gemini-1.5-flash-preview-0514'
-original_exception = VertexAIError("Parameter to MergeFrom() must be instance of same class: expected got .")
-custom_llm_provider = 'vertex_ai'
-completion_kwargs = {'acompletion': False, 'api_base': None, 'api_key': None, 'api_version': None, ...}
-extra_kwargs = {'litellm_call_id': '7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', 'litellm_logging_obj': }
-
-    def exception_type(
-        model,
-        original_exception,
-        custom_llm_provider,
-        completion_kwargs={},
-        extra_kwargs={},
-    ):
-        global user_logger_fn, liteDebuggerClient
-        exception_mapping_worked = False
-        if litellm.suppress_debug_info is False:
-            print() # noqa
-            print( # noqa
-                "\033[1;31mGive Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new\033[0m" # noqa
-            ) # noqa
-            print( # noqa
-                "LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'."
# noqa - ) # noqa - print() # noqa - try: - if model: - error_str = str(original_exception) - if isinstance(original_exception, BaseException): - exception_type = type(original_exception).__name__ - else: - exception_type = "" - - ################################################################################ - # Common Extra information needed for all providers - # We pass num retries, api_base, vertex_deployment etc to the exception here - ################################################################################ - extra_information = "" - try: - _api_base = litellm.get_api_base( - model=model, optional_params=extra_kwargs - ) - messages = litellm.get_first_chars_messages(kwargs=completion_kwargs) - _vertex_project = extra_kwargs.get("vertex_project") - _vertex_location = extra_kwargs.get("vertex_location") - _metadata = extra_kwargs.get("metadata", {}) or {} - _model_group = _metadata.get("model_group") - _deployment = _metadata.get("deployment") - extra_information = f"\nModel: {model}" - if _api_base: - extra_information += f"\nAPI Base: {_api_base}" - if messages and len(messages) > 0: - extra_information += f"\nMessages: {messages}" - - if _model_group is not None: - extra_information += f"\nmodel_group: {_model_group}\n" - if _deployment is not None: - extra_information += f"\ndeployment: {_deployment}\n" - if _vertex_project is not None: - extra_information += f"\nvertex_project: {_vertex_project}\n" - if _vertex_location is not None: - extra_information += f"\nvertex_location: {_vertex_location}\n" - - # on litellm proxy add key name + team to exceptions - extra_information = _add_key_name_and_team_to_alert( - request_info=extra_information, metadata=_metadata - ) - except: - # DO NOT LET this Block raising the original exception - pass - - ################################################################################ - # End of Common Extra information Needed for all providers - ################################################################################ - - ################################################################################ - #################### Start of Provider Exception mapping #################### - ################################################################################ - - if "Request Timeout Error" in error_str or "Request timed out" in error_str: - exception_mapping_worked = True - raise Timeout( - message=f"APITimeoutError - Request timed out. 
\nerror_str: {error_str}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - - if ( - custom_llm_provider == "openai" - or custom_llm_provider == "text-completion-openai" - or custom_llm_provider == "custom_openai" - or custom_llm_provider in litellm.openai_compatible_providers - ): - # custom_llm_provider is openai, make it OpenAI - if hasattr(original_exception, "message"): - message = original_exception.message - else: - message = str(original_exception) - if message is not None and isinstance(message, str): - message = message.replace("OPENAI", custom_llm_provider.upper()) - message = message.replace("openai", custom_llm_provider) - message = message.replace("OpenAI", custom_llm_provider) - if custom_llm_provider == "openai": - exception_provider = "OpenAI" + "Exception" - else: - exception_provider = ( - custom_llm_provider[0].upper() - + custom_llm_provider[1:] - + "Exception" - ) - - if "This model's maximum context length is" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "model_not_found" in error_str - ): - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "content_policy_violation" in error_str - ): - exception_mapping_worked = True - raise ContentPolicyViolationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "invalid_request_error" in error_str - and "Incorrect API key provided" not in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Request too large" in error_str: - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "The api_key client option must be set either by passing api_key to the client or by setting the OPENAI_API_KEY environment variable" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "Mistral API raised a streaming error" in error_str: - exception_mapping_worked = True - _request = httpx.Request( - method="POST", url="https://api.openai.com/v1" - ) - raise APIError( - status_code=500, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=_request, - litellm_debug_info=extra_information, - ) - elif hasattr(original_exception, "status_code"): - exception_mapping_worked = True - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"{exception_provider} - 
{message}", - llm_provider=custom_llm_provider, - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"{exception_provider} - {message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - else: - exception_mapping_worked = True - raise APIError( - status_code=original_exception.status_code, - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - request=original_exception.request, - litellm_debug_info=extra_information, - ) - else: - # if no status code then it is an APIConnectionError: https://github.com/openai/openai-python#handling-errors - raise APIConnectionError( - message=f"{exception_provider} - {message}", - llm_provider=custom_llm_provider, - model=model, - litellm_debug_info=extra_information, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ) - elif custom_llm_provider == "anthropic": # one of the anthropics - if hasattr(original_exception, "message"): - if ( - "prompt is too long" in original_exception.message - or "prompt: length" in original_exception.message - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if "Invalid API Key" in original_exception.message: - exception_mapping_worked = True - raise AuthenticationError( - message=original_exception.message, - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - if hasattr(original_exception, "status_code"): - print_verbose(f"status_code: {original_exception.status_code}") - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or 
original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"AnthropicException - {original_exception.message}", - model=model, - llm_provider="anthropic", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"AnthropicException - {original_exception.message}", - llm_provider="anthropic", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"AnthropicException - {original_exception.message}. Handle with `litellm.APIError`.", - llm_provider="anthropic", - model=model, - request=original_exception.request, - ) - elif custom_llm_provider == "replicate": - if "Incorrect authentication token" in error_str: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif "input is too long" in error_str: - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif exception_type == "ModelError": - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {error_str}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif "Request was throttled" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {error_str}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 401: - exception_mapping_worked = True - raise AuthenticationError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif ( - original_exception.status_code == 400 - or original_exception.status_code == 422 - or original_exception.status_code == 413 - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"ReplicateException - {original_exception.message}", - model=model, - llm_provider="replicate", - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"ReplicateException - {original_exception.message}", - llm_provider="replicate", - model=model, - response=original_exception.response, - ) - exception_mapping_worked = True - raise APIError( - status_code=500, - message=f"ReplicateException - {str(original_exception)}", - 
llm_provider="replicate", - model=model, - request=httpx.Request( - method="POST", - url="https://api.replicate.com/v1/deployments", - ), - ) - elif custom_llm_provider == "watsonx": - if "token_quota_reached" in error_str: - exception_mapping_worked = True - raise RateLimitError( - message=f"WatsonxException: Rate Limit Errror - {error_str}", - llm_provider="watsonx", - model=model, - response=original_exception.response, - ) - elif custom_llm_provider == "predibase": - if "authorization denied for" in error_str: - exception_mapping_worked = True - - # Predibase returns the raw API Key in the response - this block ensures it's not returned in the exception - if ( - error_str is not None - and isinstance(error_str, str) - and "bearer" in error_str.lower() - ): - # only keep the first 10 chars after the occurnence of "bearer" - _bearer_token_start_index = error_str.lower().find("bearer") - error_str = error_str[: _bearer_token_start_index + 14] - error_str += "XXXXXXX" + '"' - - raise AuthenticationError( - message=f"PredibaseException: Authentication Error - {error_str}", - llm_provider="predibase", - model=model, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "bedrock": - if ( - "too many tokens" in error_str - or "expected maxLength:" in error_str - or "Input is too long" in error_str - or "prompt: length: 1.." in error_str - or "Too many input tokens" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"BedrockException: Context Window Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "Malformed input request" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Unable to locate credentials" in error_str - or "The security token included in the request is invalid" - in error_str - ): - exception_mapping_worked = True - raise AuthenticationError( - message=f"BedrockException Invalid Authentication - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif "AccessDeniedException" in error_str: - exception_mapping_worked = True - raise PermissionDeniedError( - message=f"BedrockException PermissionDeniedError - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "throttlingException" in error_str - or "ThrottlingException" in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException: Rate Limit Error - {error_str}", - model=model, - llm_provider="bedrock", - response=original_exception.response, - ) - elif ( - "Connect timeout on endpoint URL" in error_str - or "timed out" in error_str - ): - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException: Timeout Error - {error_str}", - model=model, - llm_provider="bedrock", - ) - elif hasattr(original_exception, "status_code"): - if original_exception.status_code == 500: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=httpx.Response( - status_code=500, - request=httpx.Request( - method="POST", url="https://api.openai.com/v1/" - ), - ), - ) - elif original_exception.status_code == 401: - exception_mapping_worked = 
True - raise AuthenticationError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 404: - exception_mapping_worked = True - raise NotFoundError( - message=f"BedrockException - {original_exception.message}", - llm_provider="bedrock", - model=model, - response=original_exception.response, - ) - elif original_exception.status_code == 408: - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 422: - exception_mapping_worked = True - raise BadRequestError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 429: - exception_mapping_worked = True - raise RateLimitError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 503: - exception_mapping_worked = True - raise ServiceUnavailableError( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif original_exception.status_code == 504: # gateway timeout error - exception_mapping_worked = True - raise Timeout( - message=f"BedrockException - {original_exception.message}", - model=model, - llm_provider=custom_llm_provider, - litellm_debug_info=extra_information, - ) - elif custom_llm_provider == "sagemaker": - if "Unable to locate credentials" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "Input validation error: `best_of` must be > 0 and <= 2" - in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif ( - "`inputs` tokens + `max_new_tokens` must be <=" in error_str - or "instance type with more CPU capacity or memory" in error_str - ): - exception_mapping_worked = True - raise ContextWindowExceededError( - message=f"SagemakerException - {error_str}", - model=model, - llm_provider="sagemaker", - response=original_exception.response, - ) - elif custom_llm_provider == "vertex_ai": - if ( - "Vertex AI API has not been used in project" in error_str - or "Unable to find your project" in error_str - ): - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif ( - "None Unknown Error." 
in error_str - or "Content has no parts." in error_str - ): - exception_mapping_worked = True - raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - request=original_exception.request, - litellm_debug_info=extra_information, - ) - elif "403" in error_str: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - response=original_exception.response, - litellm_debug_info=extra_information, - ) - elif "The response was blocked." in error_str: - exception_mapping_worked = True - raise UnprocessableEntityError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - elif ( - "429 Quota exceeded" in error_str - or "IndexError: list index out of range" in error_str - or "429 Unable to submit request because the service is temporarily out of capacity." - in error_str - ): - exception_mapping_worked = True - raise RateLimitError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=httpx.Response( - status_code=429, - request=httpx.Request( - method="POST", - url=" https://cloud.google.com/vertex-ai/", - ), - ), - ) - if hasattr(original_exception, "status_code"): - if original_exception.status_code == 400: - exception_mapping_worked = True - raise BadRequestError( - message=f"VertexAIException - {error_str}", - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - response=original_exception.response, - ) - if original_exception.status_code == 500: - exception_mapping_worked = True -> raise APIError( - message=f"VertexAIException - {error_str}", - status_code=500, - model=model, - llm_provider="vertex_ai", - litellm_debug_info=extra_information, - request=original_exception.request, -E litellm.exceptions.APIError: VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . - -../utils.py:8922: APIError - -During handling of the above exception, another exception occurred: - - def test_gemini_pro_vision(): - try: - load_vertex_ai_credentials() - litellm.set_verbose = True - litellm.num_retries = 3 - resp = litellm.completion( - model="vertex_ai/gemini-1.5-flash-preview-0514", - messages=[ - { - "role": "user", - "content": [ - {"type": "text", "text": "Whats in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "gs://cloud-samples-data/generative-ai/image/boats.jpeg" - }, - }, - ], - } - ], - ) - print(resp) - - prompt_tokens = resp.usage.prompt_tokens - - # DO Not DELETE this ASSERT - # Google counts the prompt tokens for us, we should ensure we use the tokens from the orignal response - assert prompt_tokens == 263 # the gemini api returns 263 to us - except litellm.RateLimitError as e: - pass - except Exception as e: - if "500 Internal error encountered.'" in str(e): - pass - else: -> pytest.fail(f"An exception occurred - {str(e)}") -E Failed: An exception occurred - VertexAIException - Parameter to MergeFrom() must be instance of same class: expected got . 
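For context on what the mapping above buys callers: exception_type() re-raises provider failures as litellm's own typed exception classes, so client code can branch on the exception type instead of parsing provider error strings. A hedged usage sketch, reusing the model from the failing test; the handler bodies are placeholders:

    import litellm

    try:
        resp = litellm.completion(
            model="vertex_ai/gemini-1.5-flash-preview-0514",
            messages=[{"role": "user", "content": "Whats in this image?"}],
        )
    except litellm.RateLimitError:
        # 429 / quota-exceeded responses are mapped here and are safe to retry with backoff
        ...
    except litellm.ContextWindowExceededError:
        # prompt-too-long style errors; shrink the input before retrying
        ...
    except litellm.APIError:
        # 5xx-style provider failures, e.g. the MergeFrom() error in this run
        raise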
- -test_amazing_vertex_completion.py:540: Failed ----------------------------- Captured stdout setup ----------------------------- - ------------------------------ Captured stdout call ----------------------------- -loading vertex ai credentials -Read vertexai file path - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}]) - - -self.optional_params: {} -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. - -Logging Details: logger_fn - None | callable(logger_fn) - False - - -Request to litellm: -litellm.completion(model='vertex_ai/gemini-1.5-flash-preview-0514', messages=[{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}], litellm_call_id='7f48b7ab-47b3-4beb-b2b5-fa298be49d3f', litellm_logging_obj=) - - -SYNC kwargs[caching]: False; litellm.cache: None; kwargs.get('cache')['no-cache']: False -(start) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK -(end) INSIDE THE VERTEX AI OPTIONAL PARAM BLOCK - optional_params: {} -Final returned optional params: {} -self.optional_params: {} -VERTEX AI: vertex_project=None; vertex_location=None -VERTEX AI: creds=; google application credentials: /var/folders/gf/5h3fnlwx40sdrycs4y5qzqx40000gn/T/tmpolsest5s - -Making VertexAI Gemini Pro / Pro Vision Call - -Processing input messages = [{'role': 'user', 'content': [{'type': 'text', 'text': 'Whats in this image?'}, {'type': 'image_url', 'image_url': {'url': 'gs://cloud-samples-data/generative-ai/image/boats.jpeg'}}]}] - -Request Sent from LiteLLM: -llm_model = GenerativeModel(gemini-1.5-flash-preview-0514) -response = llm_model.generate_content([{'role': 'user', 'parts': [{'text': 'Whats in this image?'}, file_data { - mime_type: "image/jpeg" - file_uri: "gs://cloud-samples-data/generative-ai/image/boats.jpeg" -} -]}]) - - - -Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new -LiteLLM.Info: If you need to debug this error, use `litellm.set_verbose=True'. 
- -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] - -[... the remaining retry attempts repeat the identical "Request to litellm" / "Making VertexAI Gemini Pro / Pro Vision Call" log shown above, ending with the same "Give Feedback / Get Help" and "LiteLLM.Info" lines ...]
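The repetition flagged above is the retry path at work: with litellm.num_retries set, the failing call is re-issued through litellm.completion_with_retries (tenacity) before the error is finally surfaced. A small sketch of how a caller opts in; the per-call keyword is part of litellm's completion kwargs, but treat the exact backoff behaviour as an assumption:

    import litellm

    # Global default: retry every completion() call up to 3 times on failure.
    litellm.num_retries = 3

    # Per-call override.
    resp = litellm.completion(
        model="vertex_ai/gemini-1.5-flash-preview-0514",
        messages=[{"role": "user", "content": "Whats in this image?"}],
        num_retries=3,
    )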
- -Logging Details: logger_fn - None | callable(logger_fn) - False -Logging Details LiteLLM-Failure Call -self.failure_callback: [] -=============================== warnings summary =============================== -../proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: 25 warnings - /Users/krrishdholakia/Documents/litellm/litellm/proxy/myenv/lib/python3.11/site-packages/pydantic/_internal/_config.py:284: PydanticDeprecatedSince20: Support for class-based `config` is deprecated, use ConfigDict instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - warnings.warn(DEPRECATION_MESSAGE, DeprecationWarning) - -../proxy/_types.py:255 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:255: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:342 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:342: PydanticDeprecatedSince20: `pydantic.config.Extra` is deprecated, use literal values instead (e.g. `extra='allow'`). Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - extra = Extra.allow # Allow extra fields - -../proxy/_types.py:345 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:345: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:374 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:374: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:421 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:421: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -../proxy/_types.py:490 - /Users/krrishdholakia/Documents/litellm/litellm/proxy/_types.py:490: PydanticDeprecatedSince20: Pydantic V1 style `@root_validator` validators are deprecated. You should migrate to Pydantic V2 style `@model_validator` validators, see the migration guide for more details. Deprecated in Pydantic V2.0 to be removed in V3.0. 
See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.7/migration/ - @root_validator(pre=True) - -[... the same PydanticDeprecatedSince20 `@root_validator` deprecation warning repeats for ../proxy/_types.py lines 510, 523, 568, 605, 923, 950 and 971 ...] - -../utils.py:60 - /Users/krrishdholakia/Documents/litellm/litellm/utils.py:60: DeprecationWarning: open_text is deprecated. Use files() instead.
Refer to https://importlib-resources.readthedocs.io/en/latest/using.html#migrating-from-legacy for migration advice. - with resources.open_text("litellm.llms.tokenizers", "anthropic_tokenizer.json") as f: - --- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html -=========================== short test summary info ============================ -FAILED test_amazing_vertex_completion.py::test_gemini_pro_vision - Failed: An... -======================== 1 failed, 39 warnings in 2.09s ======================== diff --git a/litellm/tests/test_streaming.py b/litellm/tests/test_streaming.py index 1113adc40..a5e098b02 100644 --- a/litellm/tests/test_streaming.py +++ b/litellm/tests/test_streaming.py @@ -1288,14 +1288,14 @@ async def test_completion_replicate_llama3_streaming(sync_mode): @pytest.mark.parametrize( "model", [ - # "bedrock/cohere.command-r-plus-v1:0", - # "anthropic.claude-3-sonnet-20240229-v1:0", - # "anthropic.claude-instant-v1", - # "bedrock/ai21.j2-mid", - # "mistral.mistral-7b-instruct-v0:2", - # "bedrock/amazon.titan-tg1-large", - # "meta.llama3-8b-instruct-v1:0", - "cohere.command-text-v14" + "bedrock/cohere.command-r-plus-v1:0", + "anthropic.claude-3-sonnet-20240229-v1:0", + "anthropic.claude-instant-v1", + "bedrock/ai21.j2-mid", + "mistral.mistral-7b-instruct-v0:2", + "bedrock/amazon.titan-tg1-large", + "meta.llama3-8b-instruct-v1:0", + "cohere.command-text-v14", ], ) @pytest.mark.asyncio @@ -1324,8 +1324,6 @@ async def test_bedrock_httpx_streaming(sync_mode, model): raise Exception("finish reason not set") if complete_response.strip() == "": raise Exception("Empty response received") - - assert False else: response: litellm.CustomStreamWrapper = await litellm.acompletion( # type: ignore model=model, diff --git a/litellm/utils.py b/litellm/utils.py index 75dd85328..728173f38 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5620,13 +5620,80 @@ def get_optional_params( supported_params = get_supported_openai_params( model=model, custom_llm_provider=custom_llm_provider ) - _check_valid_arg(supported_params=supported_params) - optional_params = litellm.AmazonConverseConfig().map_openai_params( - model=model, - non_default_params=non_default_params, - optional_params=optional_params, - drop_params=drop_params, - ) + if "ai21" in model: + _check_valid_arg(supported_params=supported_params) + # params "maxTokens":200,"temperature":0,"topP":250,"stop_sequences":[], + # https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=j2-ultra + if max_tokens is not None: + optional_params["maxTokens"] = max_tokens + if temperature is not None: + optional_params["temperature"] = temperature + if top_p is not None: + optional_params["topP"] = top_p + if stream: + optional_params["stream"] = stream + elif "anthropic" in model: + _check_valid_arg(supported_params=supported_params) + optional_params = litellm.AmazonConverseConfig().map_openai_params( + model=model, + non_default_params=non_default_params, + optional_params=optional_params, + drop_params=( + drop_params + if drop_params is not None and isinstance(drop_params, bool) + else False + ), + ) + elif "amazon" in model: # amazon titan llms + _check_valid_arg(supported_params=supported_params) + # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large + if max_tokens is not None: + optional_params["maxTokenCount"] = max_tokens + if temperature is not None: + optional_params["temperature"] = temperature + if stop is not None: + filtered_stop 
= _map_and_modify_arg( + {"stop": stop}, provider="bedrock", model=model + ) + optional_params["stopSequences"] = filtered_stop["stop"] + if top_p is not None: + optional_params["topP"] = top_p + if stream: + optional_params["stream"] = stream + elif "meta" in model: # amazon / meta llms + _check_valid_arg(supported_params=supported_params) + # see https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=titan-large + if max_tokens is not None: + optional_params["max_gen_len"] = max_tokens + if temperature is not None: + optional_params["temperature"] = temperature + if top_p is not None: + optional_params["top_p"] = top_p + if stream: + optional_params["stream"] = stream + elif "cohere" in model: # cohere models on bedrock + _check_valid_arg(supported_params=supported_params) + # handle cohere params + if stream: + optional_params["stream"] = stream + if temperature is not None: + optional_params["temperature"] = temperature + if max_tokens is not None: + optional_params["max_tokens"] = max_tokens + elif "mistral" in model: + _check_valid_arg(supported_params=supported_params) + # mistral params on bedrock + # \"max_tokens\":400,\"temperature\":0.7,\"top_p\":0.7,\"stop\":[\"\\\\n\\\\nHuman:\"]}" + if max_tokens is not None: + optional_params["max_tokens"] = max_tokens + if temperature is not None: + optional_params["temperature"] = temperature + if top_p is not None: + optional_params["top_p"] = top_p + if stop is not None: + optional_params["stop"] = stop + if stream is not None: + optional_params["stream"] = stream elif custom_llm_provider == "aleph_alpha": supported_params = [ "max_tokens",
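To make the per-family branches in the hunk above concrete, an illustrative sketch (not part of the diff) of how the same OpenAI-style arguments are expected to translate into each Bedrock model family's native parameter names; the dictionaries in the comments mirror the branches added to get_optional_params, and the calls assume valid AWS credentials are configured:

    import litellm

    common = dict(
        messages=[{"role": "user", "content": "Hey, how's it going?"}],
        max_tokens=200,
        temperature=0.7,
        top_p=0.9,
    )

    # ai21    -> {"maxTokens": 200, "temperature": 0.7, "topP": 0.9}
    litellm.completion(model="bedrock/ai21.j2-mid", **common)
    # amazon  -> {"maxTokenCount": 200, "temperature": 0.7, "topP": 0.9}
    litellm.completion(model="bedrock/amazon.titan-tg1-large", **common)
    # meta    -> {"max_gen_len": 200, "temperature": 0.7, "top_p": 0.9}
    litellm.completion(model="bedrock/meta.llama3-8b-instruct-v1:0", **common)
    # mistral -> {"max_tokens": 200, "temperature": 0.7, "top_p": 0.9}
    litellm.completion(model="bedrock/mistral.mistral-7b-instruct-v0:2", **common)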