diff --git a/docs/my-website/docs/providers/bedrock.md b/docs/my-website/docs/providers/bedrock.md
index 61e6e80487..4b85b55533 100644
--- a/docs/my-website/docs/providers/bedrock.md
+++ b/docs/my-website/docs/providers/bedrock.md
@@ -7,7 +7,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
| Property | Details |
|-------|-------|
| Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). |
-| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy) |
+| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#bedrock-imported-models-deepseek) |
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
@@ -1277,6 +1277,74 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
https://some-api-url/models
```
+## Bedrock Imported Models (DeepSeek)
+
+| Property | Details |
+|----------|---------|
+| Provider Route | `bedrock/llama/{model_arn}` |
+| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [DeepSeek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
+
+Use this route to call Bedrock Imported Models that follow the `llama` Invoke request/response spec (i.e. the request body is a single `prompt` string plus Llama parameters such as `max_gen_len`).
+
+**LiteLLM Python SDK**
+
+```python
+from litellm import completion
+import os
+
+os.environ["AWS_ACCESS_KEY_ID"] = ""
+os.environ["AWS_SECRET_ACCESS_KEY"] = ""
+os.environ["AWS_REGION_NAME"] = ""
+
+response = completion(
+ model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/llama/{your-model-arn}
+ messages=[{"role": "user", "content": "Tell me a joke"}],
+)
+```
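+
+Streaming should work like any other LiteLLM `completion` call; a minimal sketch, reusing the example imported-model ARN from above (swap in your own ARN):
+
+```python
+from litellm import completion
+
+response = completion(
+    model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # bedrock/llama/{your-model-arn}
+    messages=[{"role": "user", "content": "Tell me a joke"}],
+    stream=True,
+)
+
+for chunk in response:
+    # each chunk follows the OpenAI streaming delta format
+    print(chunk.choices[0].delta.content or "", end="")
+```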
+
+**LiteLLM Proxy**
+
+**1. Add to config**
+
+```yaml
+model_list:
+ - model_name: DeepSeek-R1-Distill-Llama-70B
+ litellm_params:
+ model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
+```
+
+**2. Start proxy**
+
+```bash
+litellm --config /path/to/config.yaml
+
+# RUNNING at http://0.0.0.0:4000
+```
+
+**3. Test it!**
+
+Use the `model_name` from your config (`DeepSeek-R1-Distill-Llama-70B` in this example) as the `model` in the request:
+
+```bash
+curl --location 'http://0.0.0.0:4000/chat/completions' \
+    --header 'Authorization: Bearer sk-1234' \
+    --header 'Content-Type: application/json' \
+    --data '{
+      "model": "DeepSeek-R1-Distill-Llama-70B",
+      "messages": [
+        {
+          "role": "user",
+          "content": "what llm are you"
+        }
+      ]
+    }'
+```
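+
+You can also call the LiteLLM proxy with the OpenAI Python SDK. A minimal sketch, assuming the proxy URL and key from the steps above:
+
+```python
+import openai
+
+client = openai.OpenAI(
+    api_key="sk-1234",               # your LiteLLM proxy key
+    base_url="http://0.0.0.0:4000",  # your LiteLLM proxy URL
+)
+
+response = client.chat.completions.create(
+    model="DeepSeek-R1-Distill-Llama-70B",  # the `model_name` from your config
+    messages=[{"role": "user", "content": "what llm are you"}],
+)
+print(response.choices[0].message.content)
+```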
+
## Provisioned throughput models
To use provisioned throughput Bedrock models pass
- `model=bedrock/`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
diff --git a/litellm/__init__.py b/litellm/__init__.py
index 3032d1b8c6..4c6014e309 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -359,6 +359,9 @@ BEDROCK_CONVERSE_MODELS = [
"meta.llama3-2-11b-instruct-v1:0",
"meta.llama3-2-90b-instruct-v1:0",
]
+BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
+    "cohere", "anthropic", "mistral", "amazon", "meta", "llama", "ai21"
+]
####### COMPLETION MODELS ###################
open_ai_chat_completion_models: List = []
open_ai_text_completion_models: List = []
diff --git a/litellm/llms/bedrock/chat/invoke_handler.py b/litellm/llms/bedrock/chat/invoke_handler.py
index 5ade1dc2dc..59d9917aa2 100644
--- a/litellm/llms/bedrock/chat/invoke_handler.py
+++ b/litellm/llms/bedrock/chat/invoke_handler.py
@@ -19,12 +19,14 @@ from typing import (
Tuple,
Union,
cast,
+ get_args,
)
import httpx # type: ignore
import litellm
from litellm import verbose_logger
+from litellm._logging import print_verbose
from litellm.caching.caching import InMemoryCache
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging
@@ -206,7 +208,7 @@ async def make_call(
api_key="",
data=data,
messages=messages,
- print_verbose=litellm.print_verbose,
+ print_verbose=print_verbose,
encoding=litellm.encoding,
) # type: ignore
completion_stream: Any = MockResponseIterator(
@@ -286,7 +288,7 @@ class BedrockLLM(BaseAWSLLM):
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
- elif provider == "meta":
+ elif provider == "meta" or provider == "llama":
prompt = prompt_factory(
model=model, messages=messages, custom_llm_provider="bedrock"
)
@@ -318,7 +320,7 @@ class BedrockLLM(BaseAWSLLM):
print_verbose,
encoding,
) -> Union[ModelResponse, CustomStreamWrapper]:
- provider = model.split(".")[0]
+ provider = self.get_bedrock_invoke_provider(model)
## LOGGING
logging_obj.post_call(
input=messages,
@@ -465,7 +467,7 @@ class BedrockLLM(BaseAWSLLM):
outputText = (
completion_response.get("completions")[0].get("data").get("text")
)
- elif provider == "meta":
+ elif provider == "meta" or provider == "llama":
outputText = completion_response["generation"]
elif provider == "mistral":
outputText = completion_response["outputs"][0]["text"]
@@ -597,13 +599,13 @@ class BedrockLLM(BaseAWSLLM):
## SETUP ##
stream = optional_params.pop("stream", None)
- modelId = optional_params.pop("model_id", None)
- if modelId is not None:
- modelId = self.encode_model_id(model_id=modelId)
- else:
- modelId = model
- provider = model.split(".")[0]
+ provider = self.get_bedrock_invoke_provider(model)
+ modelId = self.get_bedrock_model_id(
+ model=model,
+ provider=provider,
+ optional_params=optional_params,
+ )
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
@@ -785,7 +787,7 @@ class BedrockLLM(BaseAWSLLM):
"textGenerationConfig": inference_params,
}
)
- elif provider == "meta":
+ elif provider == "meta" or provider == "llama":
## LOAD CONFIG
config = litellm.AmazonLlamaConfig.get_config()
for k, v in config.items():
@@ -1044,6 +1046,74 @@ class BedrockLLM(BaseAWSLLM):
)
return streaming_response
+ @staticmethod
+ def get_bedrock_invoke_provider(
+ model: str,
+ ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+ """
+ Helper function to get the bedrock provider from the model
+
+        Handles 2 scenarios:
+ 1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
+ 2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
+ """
+ _split_model = model.split(".")[0]
+ if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+ return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # Not a known provider prefix; check for a `provider/model-name` path (e.g. `llama/{model_arn}`)
+        return BedrockLLM._get_provider_from_model_path(model)
+
+ @staticmethod
+ def _get_provider_from_model_path(
+ model_path: str,
+ ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+ """
+ Helper function to get the provider from a model path with format: provider/model-name
+
+ Args:
+ model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+ Returns:
+            Optional[BEDROCK_INVOKE_PROVIDERS_LITERAL]: The provider name, or None if no valid provider is found
+        """
+        provider = model_path.split("/")[0]
+        if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
+ def get_bedrock_model_id(
+ self,
+ optional_params: dict,
+ provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL],
+ model: str,
+ ) -> str:
+        # Prefer an explicit `model_id` (e.g. a provisioned throughput ARN) over the model name
+        modelId = optional_params.pop("model_id", None)
+        if modelId is not None:
+            modelId = self.encode_model_id(model_id=modelId)
+        else:
+            modelId = model
+
+        # `llama/` is only a spec marker for imported models; strip it before building the request
+        if provider == "llama" and "llama/" in modelId:
+            modelId = self._get_model_id_for_llama_like_model(modelId)
+
+        return modelId
+
+ def _get_model_id_for_llama_like_model(
+ self,
+ model: str,
+ ) -> str:
+ """
+        Strip the `llama/` prefix from the model id, since `llama` is simply a spec to follow for custom Bedrock imported models
+ """
+ model_id = model.replace("llama/", "")
+ return self.encode_model_id(model_id=model_id)
+
def get_response_stream_shape():
global _response_stream_shape_cache
diff --git a/litellm/utils.py b/litellm/utils.py
index 0a69c861bd..9c4beaea90 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -6045,20 +6045,23 @@ class ProviderConfigManager:
return litellm.PetalsConfig()
elif litellm.LlmProviders.BEDROCK == provider:
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
+ bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
if (
base_model in litellm.bedrock_converse_models
or "converse_like" in model
):
return litellm.AmazonConverseConfig()
- elif "amazon" in model: # amazon titan llms
+ elif bedrock_provider == "amazon": # amazon titan llms
return litellm.AmazonTitanConfig()
- elif "meta" in model: # amazon / meta llms
+        elif (
+            bedrock_provider == "meta" or bedrock_provider == "llama"
+        ):  # meta llama models / llama-spec imported models
return litellm.AmazonLlamaConfig()
- elif "ai21" in model: # ai21 llms
+ elif bedrock_provider == "ai21": # ai21 llms
return litellm.AmazonAI21Config()
- elif "cohere" in model: # cohere models on bedrock
+ elif bedrock_provider == "cohere": # cohere models on bedrock
return litellm.AmazonCohereConfig()
- elif "mistral" in model: # mistral models on bedrock
+ elif bedrock_provider == "mistral": # mistral models on bedrock
return litellm.AmazonMistralConfig()
return litellm.OpenAIGPTConfig()
diff --git a/tests/llm_translation/test_bedrock_completion.py b/tests/llm_translation/test_bedrock_completion.py
index dd59415443..5f9c01f7bb 100644
--- a/tests/llm_translation/test_bedrock_completion.py
+++ b/tests/llm_translation/test_bedrock_completion.py
@@ -2529,3 +2529,54 @@ def test_bedrock_custom_proxy():
assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"
assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"
+
+
+def test_bedrock_custom_deepseek():
+ from litellm.llms.custom_httpx.http_handler import HTTPHandler
+ import json
+
+ litellm._turn_on_debug()
+ client = HTTPHandler()
+
+ with patch.object(client, "post") as mock_post:
+ # Mock the response
+ mock_response = Mock()
+ mock_response.text = json.dumps(
+ {"generation": "Here's a joke...", "stop_reason": "stop"}
+ )
+ mock_response.status_code = 200
+ # Add required response attributes
+ mock_response.headers = {"Content-Type": "application/json"}
+ mock_response.json = lambda: json.loads(mock_response.text)
+ mock_post.return_value = mock_response
+
+ try:
+ response = completion(
+                model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # bedrock/llama/{imported-model-arn}
+ messages=[{"role": "user", "content": "Tell me a joke"}],
+ max_tokens=100,
+ client=client,
+ )
+
+ # Print request details
+ print("\nRequest Details:")
+ print(f"URL: {mock_post.call_args.kwargs['url']}")
+
+ # Verify the URL
+ assert (
+ mock_post.call_args.kwargs["url"]
+ == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
+ )
+
+ # Verify the request body format
+ request_body = json.loads(mock_post.call_args.kwargs["data"])
+ print("request_body=", json.dumps(request_body, indent=4, default=str))
+ assert "prompt" in request_body
+ assert request_body["prompt"] == "Tell me a joke"
+
+ # follows the llama spec
+ assert request_body["max_gen_len"] == 100
+
+ except Exception as e:
+ print(f"Error: {str(e)}")
+ raise e