diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md
index 61e6e80487..4b85b55533 100644
--- a/docs/my-website/docs/providers/bedrock.md
+++ b/docs/my-website/docs/providers/bedrock.md
@@ -7,7 +7,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
 | Property | Details |
 |-------|-------|
 | Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). |
-| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy) |
+| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#bedrock-imported-models-deepseek) |
 | Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
 | Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
 | Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
@@ -1277,6 +1277,74 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
 https://some-api-url/models
 ```
 
+## Bedrock Imported Models (Deepseek)
+
+| Property | Details |
+|----------|---------|
+| Provider Route | `bedrock/llama/{model_arn}` |
+| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
+
+Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec.
+
+
+
+
+```python
+from litellm import completion
+import os
+
+response = completion(
+    model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/llama/{your-model-arn}
+    messages=[{"role": "user", "content": "Tell me a joke"}],
+)
+```
+
+
+
+
+
+**1. Add to config**
+
+```yaml
+model_list:
+  - model_name: DeepSeek-R1-Distill-Llama-70B
+    litellm_params:
+      model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
+
+```
+
+**2. Start proxy**
+
+```bash
+litellm --config /path/to/config.yaml
+
+# RUNNING at http://0.0.0.0:4000
+```
+
+**3. Test it!**
+
+```bash
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+--header 'Authorization: Bearer sk-1234' \
+--header 'Content-Type: application/json' \
+--data '{
+    "model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
+    "messages": [
+        {
+            "role": "user",
+            "content": "what llm are you"
+        }
+    ],
+}'
+```
+
+
+
+
+
 ## Provisioned throughput models
 To use provisioned throughput Bedrock models pass - `model=bedrock/`, example `model=bedrock/anthropic.claude-v2`.
 Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 3032d1b8c6..4c6014e309 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -359,6 +359,9 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-2-11b-instruct-v1:0",
     "meta.llama3-2-90b-instruct-v1:0",
 ]
+BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
+    "cohere", "anthropic", "mistral", "amazon", "meta", "llama"
+]
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []
diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py
index 5ade1dc2dc..59d9917aa2 100644
--- a/litellm/llms/bedrock/chat/invoke_handler.py
+++ b/litellm/llms/bedrock/chat/invoke_handler.py
@@ -19,12 +19,14 @@ from typing import (
     Tuple,
     Union,
     cast,
+    get_args,
 )
 
 import httpx  # type: ignore
 
 import litellm
 from litellm import verbose_logger
+from litellm._logging import print_verbose
 from litellm.caching.caching import InMemoryCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging
@@ -206,7 +208,7 @@ async def make_call(
             api_key="",
             data=data,
             messages=messages,
-            print_verbose=litellm.print_verbose,
+            print_verbose=print_verbose,
             encoding=litellm.encoding,
         )  # type: ignore
         completion_stream: Any = MockResponseIterator(
@@ -286,7 +288,7 @@ class BedrockLLM(BaseAWSLLM):
             prompt = prompt_factory(
                 model=model, messages=messages, custom_llm_provider="bedrock"
             )
-        elif provider == "meta":
+        elif provider == "meta" or provider == "llama":
             prompt = prompt_factory(
                 model=model, messages=messages, custom_llm_provider="bedrock"
             )
@@ -318,7 +320,7 @@ class BedrockLLM(BaseAWSLLM):
         print_verbose,
         encoding,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        provider = model.split(".")[0]
+        provider = self.get_bedrock_invoke_provider(model)
         ## LOGGING
         logging_obj.post_call(
             input=messages,
@@ -465,7 +467,7 @@
                 outputText = (
                     completion_response.get("completions")[0].get("data").get("text")
                 )
-            elif provider == "meta":
+            elif provider == "meta" or provider == "llama":
                 outputText = completion_response["generation"]
             elif provider == "mistral":
                 outputText = completion_response["outputs"][0]["text"]
@@ -597,13 +599,13 @@
         ## SETUP ##
         stream = optional_params.pop("stream", None)
-        modelId = optional_params.pop("model_id", None)
-        if modelId is not None:
-            modelId = self.encode_model_id(model_id=modelId)
-        else:
-            modelId = model
-        provider = model.split(".")[0]
+        provider = self.get_bedrock_invoke_provider(model)
+        modelId = self.get_bedrock_model_id(
+            model=model,
+            provider=provider,
+            optional_params=optional_params,
+        )
 
         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
@@ -785,7 +787,7 @@
                     "textGenerationConfig": inference_params,
                 }
             )
-        elif provider == "meta":
+        elif provider == "meta" or provider == "llama":
             ## LOAD CONFIG
             config = litellm.AmazonLlamaConfig.get_config()
             for k, v in config.items():
@@ -1044,6 +1046,74 @@
             )
         return streaming_response
 
+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model
+
+        handles 2 scenarios:
+        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+        """
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = BedrockLLM._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+        return None
+
+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
+    def get_bedrock_model_id(
+        self,
+        optional_params: dict,
+        provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL],
+        model: str,
+    ) -> str:
+        modelId = optional_params.pop("model_id", None)
+        if modelId is not None:
+            modelId = self.encode_model_id(model_id=modelId)
+        else:
+            modelId = model
+
+        if provider == "llama" and "llama/" in modelId:
+            modelId = self._get_model_id_for_llama_like_model(modelId)
+
+        return modelId
+
+    def _get_model_id_for_llama_like_model(
+        self,
+        model: str,
+    ) -> str:
+        """
+        Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
+        """
+        model_id = model.replace("llama/", "")
+        return self.encode_model_id(model_id=model_id)
+
 
 def get_response_stream_shape():
     global _response_stream_shape_cache
diff --git a/litellm/utils.py b/litellm/utils.py
index 0a69c861bd..9c4beaea90 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6045,20 +6045,23 @@ class ProviderConfigManager:
             return litellm.PetalsConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
             base_model = litellm.AmazonConverseConfig()._get_base_model(model)
+            bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
             if (
                 base_model in litellm.bedrock_converse_models
                 or "converse_like" in model
             ):
                 return litellm.AmazonConverseConfig()
-            elif "amazon" in model:  # amazon titan llms
+            elif bedrock_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
-            elif "meta" in model:  # amazon / meta llms
+            elif (
+                bedrock_provider == "meta" or bedrock_provider == "llama"
+            ):  # amazon / meta llms
                 return litellm.AmazonLlamaConfig()
-            elif "ai21" in model:  # ai21 llms
+            elif bedrock_provider == "ai21":  # ai21 llms
                 return litellm.AmazonAI21Config()
-            elif "cohere" in model:  # cohere models on bedrock
+            elif bedrock_provider == "cohere":  # cohere models on bedrock
                 return litellm.AmazonCohereConfig()
-            elif "mistral" in model:  # mistral models on bedrock
+            elif bedrock_provider == "mistral":  # mistral models on bedrock
                 return litellm.AmazonMistralConfig()
         return litellm.OpenAIGPTConfig()
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index dd59415443..5f9c01f7bb 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2529,3 +2529,54 @@ def test_bedrock_custom_proxy():
 
     assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"
     assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"
+
+
+def test_bedrock_custom_deepseek():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    import json
+
+    litellm._turn_on_debug()
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_post:
+        # Mock the response
+        mock_response = Mock()
+        mock_response.text = json.dumps(
+            {"generation": "Here's a joke...", "stop_reason": "stop"}
+        )
+        mock_response.status_code = 200
+        # Add required response attributes
+        mock_response.headers = {"Content-Type": "application/json"}
+        mock_response.json = lambda: json.loads(mock_response.text)
+        mock_post.return_value = mock_response
+
+        try:
+            response = completion(
+                model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # Updated to specify provider
+                messages=[{"role": "user", "content": "Tell me a joke"}],
+                max_tokens=100,
+                client=client,
+            )
+
+            # Print request details
+            print("\nRequest Details:")
+            print(f"URL: {mock_post.call_args.kwargs['url']}")
+
+            # Verify the URL
+            assert (
+                mock_post.call_args.kwargs["url"]
+                == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
+            )
+
+            # Verify the request body format
+            request_body = json.loads(mock_post.call_args.kwargs["data"])
+            print("request_body=", json.dumps(request_body, indent=4, default=str))
+            assert "prompt" in request_body
+            assert request_body["prompt"] == "Tell me a joke"
+
+            # follows the llama spec
+            assert request_body["max_gen_len"] == 100
+
+        except Exception as e:
+            print(f"Error: {str(e)}")
+            raise e
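For review convenience, below is a minimal standalone sketch of the provider-resolution and model-id handling this diff introduces in `BedrockLLM.get_bedrock_invoke_provider`, `get_bedrock_model_id`, and `_get_model_id_for_llama_like_model`. The function names and the provider tuple here are illustrative only (they are not litellm APIs), and the input assumes the `bedrock/` routing prefix has already been stripped before these helpers run.

```python
from typing import Optional
from urllib.parse import quote

# Mirrors BEDROCK_INVOKE_PROVIDERS_LITERAL from the diff; a plain tuple is used
# here only so this sketch runs on its own.
SUPPORTED_INVOKE_PROVIDERS = ("cohere", "anthropic", "mistral", "amazon", "meta", "llama")


def get_invoke_provider(model: str) -> Optional[str]:
    """Resolve the invoke provider from a Bedrock model string.

    Covers both spellings handled in the diff:
      "anthropic.claude-3-5-sonnet-20240620-v1:0"    -> "anthropic"
      "llama/arn:aws:bedrock:...:imported-model/xyz" -> "llama"
    """
    head = model.split(".")[0]
    if head in SUPPORTED_INVOKE_PROVIDERS:
        return head
    head = model.split("/")[0]
    if head in SUPPORTED_INVOKE_PROVIDERS:
        return head
    return None


def resolve_invoke_model_id(model: str, provider: Optional[str]) -> str:
    """Strip the `llama/` routing prefix and URL-encode the remaining model id / ARN."""
    if provider == "llama" and model.startswith("llama/"):
        model = model[len("llama/"):]
    return quote(model, safe="")


if __name__ == "__main__":
    m = "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    p = get_invoke_provider(m)
    print(p)  # llama
    print(resolve_invoke_model_id(m, p))
    # arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n
```

The percent-encoded ARN printed at the end is the form `test_bedrock_custom_deepseek` asserts for in the invoke URL.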