diff --git a/litellm/__init__.py b/litellm/__init__.py
index d3d3dd0d4b..60b8cf81a0 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -53,6 +53,7 @@ from litellm.constants import (
     cohere_embedding_models,
     bedrock_embedding_models,
     known_tokenizer_config,
+    BEDROCK_INVOKE_PROVIDERS_LITERAL,
 )
 from litellm.types.guardrails import GuardrailItem
 from litellm.proxy._types import (
@@ -361,17 +362,7 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-2-11b-instruct-v1:0",
     "meta.llama3-2-90b-instruct-v1:0",
 ]
-BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
-    "cohere",
-    "anthropic",
-    "mistral",
-    "amazon",
-    "meta",
-    "llama",
-    "ai21",
-    "nova",
-    "deepseek_r1",
-]
+
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []
diff --git a/litellm/constants.py b/litellm/constants.py
index 06756b8f20..0288c45e40 100644
--- a/litellm/constants.py
+++ b/litellm/constants.py
@@ -1,4 +1,4 @@
-from typing import List
+from typing import List, Literal
 
 ROUTER_MAX_FALLBACKS = 5
 DEFAULT_BATCH_SIZE = 512
@@ -320,6 +320,17 @@ baseten_models: List = [
     "31dxrj3",
 ]  # FALCON 7B # WizardLM # Mosaic ML
 
+BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
+    "cohere",
+    "anthropic",
+    "mistral",
+    "amazon",
+    "meta",
+    "llama",
+    "ai21",
+    "nova",
+    "deepseek_r1",
+]
 
 open_ai_embedding_models: List = ["text-embedding-ada-002"]
 cohere_embedding_models: List = [
diff --git a/litellm/llms/anthropic/chat/transformation.py b/litellm/llms/anthropic/chat/transformation.py
index 9f9c810233..e4f87aa5b4 100644
--- a/litellm/llms/anthropic/chat/transformation.py
+++ b/litellm/llms/anthropic/chat/transformation.py
@@ -308,7 +308,6 @@ class AnthropicConfig(BaseConfig):
         model: str,
         drop_params: bool,
     ) -> dict:
-
         for param, value in non_default_params.items():
             if param == "max_tokens":
                 optional_params["max_tokens"] = value
@@ -342,6 +341,10 @@ class AnthropicConfig(BaseConfig):
                 optional_params["top_p"] = value
             if param == "response_format" and isinstance(value, dict):
 
+                ignore_response_format_types = ["text"]
+                if value["type"] in ignore_response_format_types:  # value is a no-op
+                    continue
+
                 json_schema: Optional[dict] = None
                 if "response_schema" in value:
                     json_schema = value["response_schema"]
diff --git a/litellm/llms/base_llm/chat/transformation.py b/litellm/llms/base_llm/chat/transformation.py
index 020223f98e..8c9c5acda3 100644
--- a/litellm/llms/base_llm/chat/transformation.py
+++ b/litellm/llms/base_llm/chat/transformation.py
@@ -317,6 +317,7 @@ class BaseConfig(ABC):
         data: dict,
         messages: list,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         raise NotImplementedError
 
@@ -330,6 +331,7 @@ class BaseConfig(ABC):
         data: dict,
         messages: list,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         raise NotImplementedError
 
diff --git a/litellm/llms/bedrock/base_aws_llm.py b/litellm/llms/bedrock/base_aws_llm.py
index 8158ceab8f..bf9a070f26 100644
--- a/litellm/llms/bedrock/base_aws_llm.py
+++ b/litellm/llms/bedrock/base_aws_llm.py
@@ -2,13 +2,14 @@ import hashlib
 import json
 import os
 from datetime import datetime
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, cast, get_args
 
 import httpx
 from pydantic import BaseModel
 
 from litellm._logging import verbose_logger
 from litellm.caching.caching import DualCache
+from litellm.constants import BEDROCK_INVOKE_PROVIDERS_LITERAL
 from litellm.litellm_core_utils.dd_tracing import tracer
 from litellm.secret_managers.main import get_secret
@@ -223,6 +224,60 @@ class BaseAWSLLM:
             # Catch any unexpected errors and return None
             return None
 
+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model
+
+        handles 4 scenarios:
+        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
+        """
+        if model.startswith("invoke/"):
+            model = model.replace("invoke/", "", 1)
+
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = BaseAWSLLM._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+
+        # check if provider == "nova"
+        if "nova" in model:
+            return "nova"
+        else:
+            for provider in get_args(BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                if provider in model:
+                    return provider
+        return None
+
     def _get_aws_region_name(
         self, optional_params: dict, model: Optional[str] = None
     ) -> str:
diff --git a/litellm/llms/bedrock/chat/converse_transformation.py b/litellm/llms/bedrock/chat/converse_transformation.py
index b86fb7f0f3..3837369a8e 100644
--- a/litellm/llms/bedrock/chat/converse_transformation.py
+++ b/litellm/llms/bedrock/chat/converse_transformation.py
@@ -206,7 +206,12 @@ class AmazonConverseConfig(BaseConfig):
         messages: Optional[List[AllMessageValues]] = None,
     ) -> dict:
         for param, value in non_default_params.items():
-            if param == "response_format":
+            if param == "response_format" and isinstance(value, dict):
+
+                ignore_response_format_types = ["text"]
+                if value["type"] in ignore_response_format_types:  # value is a no-op
+                    continue
+
                 json_schema: Optional[dict] = None
                 schema_name: str = ""
                 if "response_schema" in value:
diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py
index 32cd137d93..56cf891e76 100644
--- a/litellm/llms/bedrock/chat/invoke_handler.py
+++ b/litellm/llms/bedrock/chat/invoke_handler.py
@@ -226,6 +226,7 @@ async def make_call(
         decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
             model=model,
             sync_stream=False,
+            json_mode=json_mode,
         )
         completion_stream = decoder.aiter_bytes(
             response.aiter_bytes(chunk_size=1024)
@@ -311,6 +312,7 @@ def make_sync_call(
         decoder: AWSEventStreamDecoder = AmazonAnthropicClaudeStreamDecoder(
             model=model,
             sync_stream=True,
+            json_mode=json_mode,
         )
         completion_stream = decoder.iter_bytes(response.iter_bytes(chunk_size=1024))
     elif bedrock_invoke_provider == "deepseek_r1":
@@ -1149,27 +1151,6 @@ class BedrockLLM(BaseAWSLLM):
             )
             return streaming_response
 
-    @staticmethod
-    def get_bedrock_invoke_provider(
-        model: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the bedrock provider from the model
-
-        handles 2 scenarions:
-        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
-        """
-        _split_model = model.split(".")[0]
-        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
-
-        # If not a known provider, check for pattern with two slashes
-        provider = BedrockLLM._get_provider_from_model_path(model)
-        if provider is not None:
-            return provider
-        return None
-
     @staticmethod
     def _get_provider_from_model_path(
         model_path: str,
@@ -1524,6 +1505,7 @@ class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
         self,
         model: str,
         sync_stream: bool,
+        json_mode: Optional[bool] = None,
     ) -> None:
         """
         Child class of AWSEventStreamDecoder that handles the streaming response from the Anthropic family of models
@@ -1534,6 +1516,7 @@ class AmazonAnthropicClaudeStreamDecoder(AWSEventStreamDecoder):
         self.anthropic_model_response_iterator = AnthropicModelResponseIterator(
             streaming_response=None,
             sync_stream=sync_stream,
+            json_mode=json_mode,
         )
 
     def _chunk_parser(self, chunk_data: dict) -> ModelResponseStream:
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude2_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude2_transformation.py
index 085cf0b9ca..d0d06ef2b2 100644
--- a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude2_transformation.py
+++ b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude2_transformation.py
@@ -3,8 +3,10 @@ from typing import Optional
 
 import litellm
 
+from .base_invoke_transformation import AmazonInvokeConfig
 
-class AmazonAnthropicConfig:
+
+class AmazonAnthropicConfig(AmazonInvokeConfig):
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
 
@@ -57,9 +59,7 @@ class AmazonAnthropicConfig:
             and v is not None
         }
 
-    def get_supported_openai_params(
-        self,
-    ):
+    def get_supported_openai_params(self, model: str):
         return [
             "max_tokens",
             "max_completion_tokens",
@@ -69,7 +69,13 @@ class AmazonAnthropicConfig:
             "stream",
         ]
 
-    def map_openai_params(self, non_default_params: dict, optional_params: dict):
+    def map_openai_params(
+        self,
+        non_default_params: dict,
+        optional_params: dict,
+        model: str,
+        drop_params: bool,
+    ):
         for param, value in non_default_params.items():
             if param == "max_tokens" or param == "max_completion_tokens":
                 optional_params["max_tokens_to_sample"] = value
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py
index 09842aef01..0cac339a3c 100644
--- a/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py
+++ b/litellm/llms/bedrock/chat/invoke_transformations/anthropic_claude3_transformation.py
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, List, Optional
 
 import httpx
 
-import litellm
+from litellm.llms.anthropic.chat.transformation import AnthropicConfig
 from litellm.llms.bedrock.chat.invoke_transformations.base_invoke_transformation import (
     AmazonInvokeConfig,
 )
@@ -17,7 +17,7 @@ else:
     LiteLLMLoggingObj = Any
 
 
-class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
+class AmazonAnthropicClaude3Config(AmazonInvokeConfig, AnthropicConfig):
     """
     Reference: https://us-west-2.console.aws.amazon.com/bedrock/home?region=us-west-2#/providers?model=claude
 
@@ -28,18 +28,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
 
     anthropic_version: str = "bedrock-2023-05-31"
 
-    def get_supported_openai_params(self, model: str):
-        return [
-            "max_tokens",
-            "max_completion_tokens",
-            "tools",
-            "tool_choice",
-            "stream",
-            "stop",
-            "temperature",
-            "top_p",
-            "extra_headers",
-        ]
+    def get_supported_openai_params(self, model: str) -> List[str]:
+        return AnthropicConfig.get_supported_openai_params(self, model)
 
     def map_openai_params(
         self,
@@ -47,21 +37,14 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         optional_params: dict,
         model: str,
         drop_params: bool,
-    ):
-        for param, value in non_default_params.items():
-            if param == "max_tokens" or param == "max_completion_tokens":
-                optional_params["max_tokens"] = value
-            if param == "tools":
-                optional_params["tools"] = value
-            if param == "stream":
-                optional_params["stream"] = value
-            if param == "stop":
-                optional_params["stop_sequences"] = value
-            if param == "temperature":
-                optional_params["temperature"] = value
-            if param == "top_p":
-                optional_params["top_p"] = value
-        return optional_params
+    ) -> dict:
+        return AnthropicConfig.map_openai_params(
+            self,
+            non_default_params,
+            optional_params,
+            model,
+            drop_params,
+        )
 
     def transform_request(
         self,
@@ -71,7 +54,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         litellm_params: dict,
         headers: dict,
     ) -> dict:
-        _anthropic_request = litellm.AnthropicConfig().transform_request(
+        _anthropic_request = AnthropicConfig.transform_request(
+            self,
             model=model,
             messages=messages,
             optional_params=optional_params,
@@ -80,6 +64,7 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         )
 
         _anthropic_request.pop("model", None)
+        _anthropic_request.pop("stream", None)
         if "anthropic_version" not in _anthropic_request:
             _anthropic_request["anthropic_version"] = self.anthropic_version
 
@@ -99,7 +84,8 @@ class AmazonAnthropicClaude3Config(AmazonInvokeConfig):
         api_key: Optional[str] = None,
         json_mode: Optional[bool] = None,
     ) -> ModelResponse:
-        return litellm.AnthropicConfig().transform_response(
+        return AnthropicConfig.transform_response(
+            self,
             model=model,
             raw_response=raw_response,
             model_response=model_response,
diff --git a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
index e98cb4fa94..e0da783897 100644
--- a/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
+++ b/litellm/llms/bedrock/chat/invoke_transformations/base_invoke_transformation.py
@@ -3,7 +3,7 @@ import json
 import time
 import urllib.parse
 from functools import partial
-from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union, cast, get_args
+from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
 import httpx
 
@@ -461,6 +461,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         data: dict,
         messages: list,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         streaming_response = CustomStreamWrapper(
             completion_stream=None,
@@ -475,6 +476,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
                 bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
+                json_mode=json_mode,
             ),
             model=model,
             custom_llm_provider="bedrock",
@@ -493,6 +495,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         data: dict,
         messages: list,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
+        json_mode: Optional[bool] = None,
     ) -> CustomStreamWrapper:
         if client is None or isinstance(client, AsyncHTTPHandler):
             client = _get_httpx_client(params={})
@@ -509,6 +512,7 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
                 logging_obj=logging_obj,
                 fake_stream=True if "ai21" in api_base else False,
                 bedrock_invoke_provider=self.get_bedrock_invoke_provider(model),
+                json_mode=json_mode,
             ),
             model=model,
             custom_llm_provider="bedrock",
@@ -527,56 +531,6 @@ class AmazonInvokeConfig(BaseConfig, BaseAWSLLM):
         """
         return False
 
-    @staticmethod
-    def get_bedrock_invoke_provider(
-        model: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the bedrock provider from the model
-
-        handles 3 scenarions:
-        1. model=invoke/anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        2. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
-        3. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
-        4. model=us.amazon.nova-pro-v1:0 -> Returns `nova`
-        """
-        if model.startswith("invoke/"):
-            model = model.replace("invoke/", "", 1)
-
-        _split_model = model.split(".")[0]
-        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
-
-        # If not a known provider, check for pattern with two slashes
-        provider = AmazonInvokeConfig._get_provider_from_model_path(model)
-        if provider is not None:
-            return provider
-
-        # check if provider == "nova"
-        if "nova" in model:
-            return "nova"
-        return None
-
-    @staticmethod
-    def _get_provider_from_model_path(
-        model_path: str,
-    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
-        """
-        Helper function to get the provider from a model path with format: provider/model-name
-
-        Args:
-            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
-
-        Returns:
-            Optional[str]: The provider name, or None if no valid provider found
-        """
-        parts = model_path.split("/")
-        if len(parts) >= 1:
-            provider = parts[0]
-            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
-                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
-        return None
-
     def get_bedrock_model_id(
         self,
         optional_params: dict,
diff --git a/litellm/llms/custom_httpx/llm_http_handler.py b/litellm/llms/custom_httpx/llm_http_handler.py
index ebe5308c1c..991e4aeaec 100644
--- a/litellm/llms/custom_httpx/llm_http_handler.py
+++ b/litellm/llms/custom_httpx/llm_http_handler.py
@@ -159,6 +159,7 @@ class BaseLLMHTTPHandler:
         encoding: Any,
         api_key: Optional[str] = None,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: bool = False,
     ):
         if client is None:
             async_httpx_client = get_async_httpx_client(
@@ -190,6 +191,7 @@ class BaseLLMHTTPHandler:
             optional_params=optional_params,
             litellm_params=litellm_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def completion(
@@ -211,6 +213,7 @@ class BaseLLMHTTPHandler:
         headers: Optional[dict] = {},
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ):
+        json_mode: bool = optional_params.pop("json_mode", False)
         provider_config = ProviderConfigManager.get_provider_chat_config(
             model=model, provider=litellm.LlmProviders(custom_llm_provider)
         )
@@ -286,6 +289,7 @@ class BaseLLMHTTPHandler:
                     else None
                 ),
                 litellm_params=litellm_params,
+                json_mode=json_mode,
             )
 
         else:
@@ -309,6 +313,7 @@ class BaseLLMHTTPHandler:
                     if client is not None and isinstance(client, AsyncHTTPHandler)
                     else None
                 ),
+                json_mode=json_mode,
             )
 
         if stream is True:
@@ -327,6 +332,7 @@ class BaseLLMHTTPHandler:
                     data=data,
                     messages=messages,
                     client=client,
+                    json_mode=json_mode,
                 )
             completion_stream, headers = self.make_sync_call(
                 provider_config=provider_config,
@@ -380,6 +386,7 @@ class BaseLLMHTTPHandler:
             optional_params=optional_params,
             litellm_params=litellm_params,
             encoding=encoding,
+            json_mode=json_mode,
         )
 
     def make_sync_call(
@@ -453,6 +460,7 @@ class BaseLLMHTTPHandler:
         litellm_params: dict,
         fake_stream: bool = False,
         client: Optional[AsyncHTTPHandler] = None,
+        json_mode: Optional[bool] = None,
     ):
         if provider_config.has_custom_stream_wrapper is True:
             return provider_config.get_async_custom_stream_wrapper(
@@ -464,6 +472,7 @@ class BaseLLMHTTPHandler:
                 data=data,
                 messages=messages,
                 client=client,
+                json_mode=json_mode,
             )
 
         completion_stream, _response_headers = await self.make_async_call_stream_helper(
@@ -720,7 +729,7 @@ class BaseLLMHTTPHandler:
         api_base: Optional[str] = None,
         client: Optional[Union[HTTPHandler, AsyncHTTPHandler]] = None,
     ) -> RerankResponse:
-
+        # get config from model, custom llm provider
         headers = provider_config.validate_environment(
             api_key=api_key,
diff --git a/litellm/proxy/_experimental/out/onboarding.html b/litellm/proxy/_experimental/out/onboarding.html
deleted file mode 100644
index 96d44e7a26..0000000000
--- a/litellm/proxy/_experimental/out/onboarding.html
+++ /dev/null
@@ -1 +0,0 @@
-LiteLLM Dashboard
\ No newline at end of file
diff --git a/litellm/proxy/_new_secret_config.yaml b/litellm/proxy/_new_secret_config.yaml
index d9180abd4e..fe36f5708b 100644
--- a/litellm/proxy/_new_secret_config.yaml
+++ b/litellm/proxy/_new_secret_config.yaml
@@ -16,7 +16,10 @@ model_list:
       api_key: os.environ/COHERE_API_KEY
   - model_name: bedrock-claude-3-7
     litellm_params:
-      model: bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0
+      model: bedrock/invoke/us.anthropic.claude-3-7-sonnet-20250219-v1:0
+  - model_name: bedrock-claude-3-5-sonnet
+    litellm_params:
+      model: bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0
 
 litellm_settings:
   callbacks: ["langfuse"]
\ No newline at end of file
diff --git a/litellm/utils.py b/litellm/utils.py
index 601594beda..495b0d45a6 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -3240,6 +3240,7 @@ def get_optional_params(  # noqa: PLR0915
         )
     elif custom_llm_provider == "bedrock":
         bedrock_route = BedrockModelInfo.get_bedrock_route(model)
+        bedrock_base_model = BedrockModelInfo.get_base_model(model)
         if bedrock_route == "converse" or bedrock_route == "converse_like":
             optional_params = litellm.AmazonConverseConfig().map_openai_params(
                 model=model,
@@ -3253,8 +3254,9 @@ def get_optional_params(  # noqa: PLR0915
                 messages=messages,
             )
 
-        elif "anthropic" in model and bedrock_route == "invoke":
-            if model.startswith("anthropic.claude-3"):
+        elif "anthropic" in bedrock_base_model and bedrock_route == "invoke":
+            if bedrock_base_model.startswith("anthropic.claude-3"):
+
                 optional_params = (
                     litellm.AmazonAnthropicClaude3Config().map_openai_params(
                         non_default_params=non_default_params,
@@ -3267,10 +3269,17 @@ def get_optional_params(  # noqa: PLR0915
                         ),
                     )
                 )
+
             else:
                 optional_params = litellm.AmazonAnthropicConfig().map_openai_params(
                     non_default_params=non_default_params,
                     optional_params=optional_params,
+                    model=model,
+                    drop_params=(
+                        drop_params
+                        if drop_params is not None and isinstance(drop_params, bool)
+                        else False
+                    ),
                 )
         elif provider_config is not None:
             optional_params = provider_config.map_openai_params(
@@ -6158,13 +6167,19 @@ class ProviderConfigManager:
         elif litellm.LlmProviders.BEDROCK == provider:
             bedrock_route = BedrockModelInfo.get_bedrock_route(model)
             bedrock_invoke_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(
-                model
+                model=model
             )
+            base_model = BedrockModelInfo.get_base_model(model)
 
             if bedrock_route == "converse" or bedrock_route == "converse_like":
                 return litellm.AmazonConverseConfig()
             elif bedrock_invoke_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
+            elif bedrock_invoke_provider == "anthropic":
+                if base_model.startswith("anthropic.claude-3"):
+                    return litellm.AmazonAnthropicClaude3Config()
+                else:
+                    return litellm.AmazonAnthropicConfig()
             elif (
                 bedrock_invoke_provider == "meta" or bedrock_invoke_provider == "llama"
             ):  # amazon / meta llms
diff --git a/tests/llm_translation/base_llm_unit_tests.py b/tests/llm_translation/base_llm_unit_tests.py
index ad7dd9e1d1..6948b8ad2d 100644
--- a/tests/llm_translation/base_llm_unit_tests.py
+++ b/tests/llm_translation/base_llm_unit_tests.py
@@ -18,6 +18,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     get_supported_openai_params,
     get_optional_params,
+    ProviderConfigManager,
 )
 from typing import Union
 
@@ -247,12 +248,45 @@ class BaseLLMChatTest(ABC):
             response_format=response_format,
         )
 
-        print(response)
+        print(f"response={response}")
 
         # OpenAI guarantees that the JSON schema is returned in the content
         # relevant issue: https://github.com/BerriAI/litellm/issues/6741
         assert response.choices[0].message.content is not None
 
+    def test_response_format_type_text(self):
+        """
+        Test that the response format type text does not lead to tool calls
+        """
+        from litellm import LlmProviders
+
+        base_completion_call_args = self.get_base_completion_call_args()
+        litellm.set_verbose = True
+
+        _, provider, _, _ = litellm.get_llm_provider(
+            model=base_completion_call_args["model"]
+        )
+
+        provider_config = ProviderConfigManager.get_provider_chat_config(
+            base_completion_call_args["model"], LlmProviders(provider)
+        )
+
+        print(f"provider_config={provider_config}")
+
+        translated_params = provider_config.map_openai_params(
+            non_default_params={"response_format": {"type": "text"}},
+            optional_params={},
+            model=base_completion_call_args["model"],
+            drop_params=False,
+        )
+
+        assert "tool_choice" not in translated_params
+        assert (
+            "tools" not in translated_params
+        ), f"Got tools={translated_params['tools']}, expected no tools"
+
+        print(f"translated_params={translated_params}")
+
     @pytest.mark.flaky(retries=6, delay=1)
     def test_json_response_pydantic_obj(self):
         litellm.set_verbose = True
diff --git a/tests/llm_translation/test_bedrock_invoke_tests.py b/tests/llm_translation/test_bedrock_invoke_tests.py
index 381a203d3d..93d7c73174 100644
--- a/tests/llm_translation/test_bedrock_invoke_tests.py
+++ b/tests/llm_translation/test_bedrock_invoke_tests.py
@@ -22,17 +22,9 @@ class TestBedrockInvokeClaudeJson(BaseLLMChatTest):
         """Test that tool calls with no arguments is translated correctly. Relevant issue: https://github.com/BerriAI/litellm/issues/6833"""
         pass
 
-    @pytest.fixture(autouse=True)
-    def skip_non_json_tests(self, request):
-        if not "json" in request.function.__name__.lower():
-            pytest.skip(
-                f"Skipping non-JSON test: {request.function.__name__} does not contain 'json'"
-            )
-
 
 class TestBedrockInvokeNovaJson(BaseLLMChatTest):
     def get_base_completion_call_args(self) -> dict:
-        litellm._turn_on_debug()
         return {
             "model": "bedrock/invoke/us.amazon.nova-micro-v1:0",
         }
diff --git a/tests/llm_translation/test_max_completion_tokens.py b/tests/llm_translation/test_max_completion_tokens.py
index c20ef3d1a0..a8f3dd50a8 100644
--- a/tests/llm_translation/test_max_completion_tokens.py
+++ b/tests/llm_translation/test_max_completion_tokens.py
@@ -319,12 +319,15 @@ def test_all_model_configs():
     ) == {"max_tokens": 10}
 
     assert (
-        "max_completion_tokens" in AmazonAnthropicConfig().get_supported_openai_params()
+        "max_completion_tokens"
+        in AmazonAnthropicConfig().get_supported_openai_params(model="")
     )
 
     assert AmazonAnthropicConfig().map_openai_params(
         non_default_params={"max_completion_tokens": 10},
        optional_params={},
+        model="",
+        drop_params=False,
     ) == {"max_tokens_to_sample": 10}
 
     from litellm.llms.databricks.chat.handler import DatabricksConfig
diff --git a/tests/llm_translation/test_optional_params.py b/tests/llm_translation/test_optional_params.py
index 09071debc8..ad698e54c7 100644
--- a/tests/llm_translation/test_optional_params.py
+++ b/tests/llm_translation/test_optional_params.py
@@ -1114,3 +1114,257 @@ def test_anthropic_thinking_param(model, expected_thinking):
         assert "thinking" in optional_params
     else:
         assert "thinking" not in optional_params
+
+
+def test_bedrock_invoke_anthropic_max_tokens():
+    passed_params = {
+        "model": "invoke/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "functions": None,
+        "function_call": None,
+        "temperature": 0.8,
+        "top_p": None,
+        "n": 1,
+        "stream": False,
+        "stream_options": None,
+        "stop": None,
+        "max_tokens": None,
+        "max_completion_tokens": 1024,
+        "modalities": None,
+        "prediction": None,
+        "audio": None,
+        "presence_penalty": None,
+        "frequency_penalty": None,
+        "logit_bias": None,
+        "user": None,
+        "custom_llm_provider": "bedrock",
+        "response_format": {"type": "text"},
+        "seed": None,
+        "tools": [
+            {
+                "type": "function",
+                "function": {
+                    "name": "generate_plan",
+                    "description": "Generate a plan to execute the task using only the tools outlined in your context.",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "steps": {
+                                "type": "array",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "type": {
+                                            "type": "string",
+                                            "description": "The type of step to execute",
+                                        },
+                                        "tool_name": {
+                                            "type": "string",
+                                            "description": "The name of the tool to use for this step",
+                                        },
+                                        "tool_input": {
+                                            "type": "object",
+                                            "description": "The input to pass to the tool. Make sure this complies with the schema for the tool.",
+                                        },
+                                        "tool_output": {
+                                            "type": "object",
+                                            "description": "(Optional) The output from the tool if needed for future steps. Make sure this complies with the schema for the tool.",
+                                        },
+                                    },
+                                    "required": ["type"],
+                                },
+                            }
+                        },
+                    },
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "generate_wire_tool",
+                    "description": "Create a wire transfer with complete wire instructions",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "company_id": {
+                                "type": "integer",
+                                "description": "The ID of the company receiving the investment",
+                            },
+                            "investment_id": {
+                                "type": "integer",
+                                "description": "The ID of the investment memo",
+                            },
+                            "dollar_amount": {
+                                "type": "number",
+                                "description": "The amount to wire in USD",
+                            },
+                            "wiring_instructions": {
+                                "type": "object",
+                                "description": "Complete bank account and routing information for the wire",
+                                "properties": {
+                                    "account_name": {
+                                        "type": "string",
+                                        "description": "Name on the bank account",
+                                    },
+                                    "address_1": {
+                                        "type": "string",
+                                        "description": "Primary address line",
+                                    },
+                                    "address_2": {
+                                        "type": "string",
+                                        "description": "Secondary address line (optional)",
+                                    },
+                                    "city": {"type": "string"},
+                                    "state": {"type": "string"},
+                                    "zip": {"type": "string"},
+                                    "country": {"type": "string", "default": "US"},
+                                    "bank_name": {"type": "string"},
+                                    "account_number": {"type": "string"},
+                                    "routing_number": {"type": "string"},
+                                    "account_type": {
+                                        "type": "string",
+                                        "enum": ["checking", "savings"],
+                                        "default": "checking",
+                                    },
+                                    "swift_code": {
+                                        "type": "string",
+                                        "description": "Required for international wires",
+                                    },
+                                    "iban": {
+                                        "type": "string",
+                                        "description": "Required for some international wires",
+                                    },
+                                    "bank_city": {"type": "string"},
+                                    "bank_state": {"type": "string"},
+                                    "bank_country": {"type": "string", "default": "US"},
+                                    "bank_to_bank_instructions": {
+                                        "type": "string",
+                                        "description": "Additional instructions for the bank (optional)",
+                                    },
+                                    "intermediary_bank_name": {
+                                        "type": "string",
+                                        "description": "Name of intermediary bank if required (optional)",
+                                    },
+                                },
+                                "required": [
+                                    "account_name",
+                                    "address_1",
+                                    "country",
+                                    "bank_name",
+                                    "account_number",
+                                    "routing_number",
+                                    "account_type",
+                                    "bank_country",
+                                ],
+                            },
+                        },
+                        "required": [
+                            "company_id",
+                            "investment_id",
+                            "dollar_amount",
+                            "wiring_instructions",
+                        ],
+                    },
+                },
+            },
+            {
+                "type": "function",
+                "function": {
+                    "name": "search_companies",
+                    "description": "Search for companies by name or other criteria to get their IDs",
+                    "input_schema": {
+                        "type": "object",
+                        "properties": {
+                            "query": {
+                                "type": "string",
+                                "description": "Name or part of name to search for",
+                            },
+                            "batch": {
+                                "type": "string",
+                                "description": 'Optional batch filter (e.g., "W21", "S22")',
+                            },
+                            "status": {
+                                "type": "string",
+                                "enum": [
+                                    "live",
+                                    "dead",
+                                    "adrift",
+                                    "exited",
+                                    "went_public",
+                                    "all",
+                                ],
+                                "description": "Filter by company status",
+                                "default": "live",
+                            },
+                            "limit": {
+                                "type": "integer",
+                                "description": "Maximum number of results to return",
+                                "default": 10,
+                            },
+                        },
+                        "required": ["query"],
+                    },
+                    "output_schema": {
+                        "type": "object",
+                        "properties": {
+                            "status": {
+                                "type": "string",
+                                "description": "Success or error status",
+                            },
+                            "results": {
+                                "type": "array",
+                                "description": "List of companies matching the search criteria",
+                                "items": {
+                                    "type": "object",
+                                    "properties": {
+                                        "id": {
+                                            "type": "integer",
+                                            "description": "Company ID to use in other API calls",
+                                        },
+                                        "name": {"type": "string"},
+                                        "batch": {"type": "string"},
+                                        "status": {"type": "string"},
+                                        "valuation": {"type": "string"},
+                                        "url": {"type": "string"},
"description": {"type": "string"}, + "founders": {"type": "string"}, + }, + }, + }, + "results_count": { + "type": "integer", + "description": "Number of companies returned", + }, + "total_matches": { + "type": "integer", + "description": "Total number of matches found", + }, + }, + }, + }, + }, + ], + "tool_choice": None, + "max_retries": 0, + "logprobs": None, + "top_logprobs": None, + "extra_headers": None, + "api_version": None, + "parallel_tool_calls": None, + "drop_params": True, + "reasoning_effort": None, + "additional_drop_params": None, + "messages": [ + { + "role": "system", + "content": "You are an AI assistant that helps prepare a wire for a pro rata investment.", + }, + {"role": "user", "content": [{"type": "text", "text": "hi"}]}, + ], + "thinking": None, + "kwargs": {}, + } + optional_params = get_optional_params(**passed_params) + print(f"optional_params: {optional_params}") + + assert "max_tokens_to_sample" not in optional_params + assert optional_params["max_tokens"] == 1024