Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)
(Feat) add bedrock/deepseek custom import models (#8132)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 16s
* add support for using llama spec with bedrock
* fix get_bedrock_invoke_provider
* add support for using bedrock provider in mappings
* working request
* test_bedrock_custom_deepseek
* fix _get_model_id_for_llama_like_model
* doc DeepSeek-R1-Distill-Llama-70B
This commit is contained in: parent 29a8a613a7, commit 9ff27809b2
5 changed files with 212 additions and 17 deletions
@@ -7,7 +7,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Supported
 | Property | Details |
 |-------|-------|
 | Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). |
-| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy) |
+| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#bedrock-imported-models-deepseek) |
 | Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
 | Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
 | Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
@@ -1277,6 +1277,74 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \

## Bedrock Imported Models (Deepseek)

| Property | Details |
|----------|---------|
| Provider Route | `bedrock/llama/{model_arn}` |
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |

Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec.

<Tabs>
<TabItem value="sdk" label="SDK">

```python
from litellm import completion
import os

response = completion(
    model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # bedrock/llama/{your-model-arn}
    messages=[{"role": "user", "content": "Tell me a joke"}],
)
```

</TabItem>
<TabItem value="proxy" label="Proxy">

**1. Add to config**

```yaml
model_list:
  - model_name: DeepSeek-R1-Distill-Llama-70B
    litellm_params:
      model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
```

**2. Start proxy**

```bash
litellm --config /path/to/config.yaml

# RUNNING at http://0.0.0.0:4000
```

**3. Test it!** Use the `model_name` from your config as `model`:

```bash
curl --location 'http://0.0.0.0:4000/chat/completions' \
--header 'Authorization: Bearer sk-1234' \
--header 'Content-Type: application/json' \
--data '{
    "model": "DeepSeek-R1-Distill-Llama-70B",
    "messages": [
        {
            "role": "user",
            "content": "what llm are you"
        }
    ]
}'
```

</TabItem>
</Tabs>
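The SDK example above imports `os` but does not show credential setup. A minimal sketch, assuming the standard LiteLLM AWS environment variables (the variable names and region here are assumptions, not part of this commit):

```python
# Sketch (assumed setup): standard AWS credentials for LiteLLM's Bedrock provider.
import os
from litellm import completion

os.environ["AWS_ACCESS_KEY_ID"] = ""         # assumed env var name
os.environ["AWS_SECRET_ACCESS_KEY"] = ""     # assumed env var name
os.environ["AWS_REGION_NAME"] = "us-east-1"  # assumed; region hosting the imported model

response = completion(
    model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",
    messages=[{"role": "user", "content": "Tell me a joke"}],
)
print(response.choices[0].message.content)
```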
## Provisioned throughput models

To use provisioned throughput Bedrock models, pass:

- `model=bedrock/<base-model>`, for example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models). (A usage sketch follows below.)
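A minimal sketch of the resulting call, assuming the provisioned ARN is passed via `model_id` (consistent with `get_bedrock_model_id` in this commit, which pops `model_id` from `optional_params` and URL-encodes it):

```python
# Hedged sketch: provisioned throughput call. Passing the provisioned ARN as `model_id`
# is an assumption based on the handler code below (get_bedrock_model_id pops "model_id").
from litellm import completion

response = completion(
    model="bedrock/anthropic.claude-v2",  # base model from the supported list
    model_id="arn:aws:bedrock:us-east-1:...:provisioned-model/...",  # placeholder ARN
    messages=[{"role": "user", "content": "Hello"}],
)
```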
@@ -359,6 +359,9 @@ BEDROCK_CONVERSE_MODELS = [
     "meta.llama3-2-11b-instruct-v1:0",
     "meta.llama3-2-90b-instruct-v1:0",
 ]
+BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
+    "cohere", "anthropic", "mistral", "amazon", "meta", "llama"
+]
 ####### COMPLETION MODELS ###################
 open_ai_chat_completion_models: List = []
 open_ai_text_completion_models: List = []
@@ -19,12 +19,14 @@ from typing import (
     Tuple,
     Union,
     cast,
+    get_args,
 )

 import httpx  # type: ignore

 import litellm
 from litellm import verbose_logger
+from litellm._logging import print_verbose
 from litellm.caching.caching import InMemoryCache
 from litellm.litellm_core_utils.core_helpers import map_finish_reason
 from litellm.litellm_core_utils.litellm_logging import Logging
@@ -206,7 +208,7 @@ async def make_call(
             api_key="",
             data=data,
             messages=messages,
-            print_verbose=litellm.print_verbose,
+            print_verbose=print_verbose,
             encoding=litellm.encoding,
         )  # type: ignore
         completion_stream: Any = MockResponseIterator(
@@ -286,7 +288,7 @@ class BedrockLLM(BaseAWSLLM):
             prompt = prompt_factory(
                 model=model, messages=messages, custom_llm_provider="bedrock"
             )
-        elif provider == "meta":
+        elif provider == "meta" or provider == "llama":
             prompt = prompt_factory(
                 model=model, messages=messages, custom_llm_provider="bedrock"
             )
@@ -318,7 +320,7 @@ class BedrockLLM(BaseAWSLLM):
         print_verbose,
         encoding,
     ) -> Union[ModelResponse, CustomStreamWrapper]:
-        provider = model.split(".")[0]
+        provider = self.get_bedrock_invoke_provider(model)
         ## LOGGING
         logging_obj.post_call(
             input=messages,
@@ -465,7 +467,7 @@ class BedrockLLM(BaseAWSLLM):
                 outputText = (
                     completion_response.get("completions")[0].get("data").get("text")
                 )
-            elif provider == "meta":
+            elif provider == "meta" or provider == "llama":
                 outputText = completion_response["generation"]
             elif provider == "mistral":
                 outputText = completion_response["outputs"][0]["text"]
@@ -597,13 +599,13 @@ class BedrockLLM(BaseAWSLLM):

         ## SETUP ##
         stream = optional_params.pop("stream", None)
-        modelId = optional_params.pop("model_id", None)
-        if modelId is not None:
-            modelId = self.encode_model_id(model_id=modelId)
-        else:
-            modelId = model
-
-        provider = model.split(".")[0]
+        provider = self.get_bedrock_invoke_provider(model)
+        modelId = self.get_bedrock_model_id(
+            model=model,
+            provider=provider,
+            optional_params=optional_params,
+        )

         ## CREDENTIALS ##
         # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
@@ -785,7 +787,7 @@ class BedrockLLM(BaseAWSLLM):
                     "textGenerationConfig": inference_params,
                 }
             )
-        elif provider == "meta":
+        elif provider == "meta" or provider == "llama":
             ## LOAD CONFIG
             config = litellm.AmazonLlamaConfig.get_config()
             for k, v in config.items():
@@ -1044,6 +1046,74 @@ class BedrockLLM(BaseAWSLLM):
             )
             return streaming_response

+    @staticmethod
+    def get_bedrock_invoke_provider(
+        model: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the bedrock provider from the model.
+
+        Handles 2 scenarios:
+        1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> returns `anthropic`
+        2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> returns `llama`
+        """
+        _split_model = model.split(".")[0]
+        if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+            return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
+
+        # If not a known provider, check for pattern with two slashes
+        provider = BedrockLLM._get_provider_from_model_path(model)
+        if provider is not None:
+            return provider
+        return None
+
+    @staticmethod
+    def _get_provider_from_model_path(
+        model_path: str,
+    ) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
+        """
+        Helper function to get the provider from a model path with format: provider/model-name
+
+        Args:
+            model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
+
+        Returns:
+            Optional[str]: The provider name, or None if no valid provider found
+        """
+        parts = model_path.split("/")
+        if len(parts) >= 1:
+            provider = parts[0]
+            if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
+                return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
+        return None
+
+    def get_bedrock_model_id(
+        self,
+        optional_params: dict,
+        provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL],
+        model: str,
+    ) -> str:
+        modelId = optional_params.pop("model_id", None)
+        if modelId is not None:
+            modelId = self.encode_model_id(model_id=modelId)
+        else:
+            modelId = model
+
+        if provider == "llama" and "llama/" in modelId:
+            modelId = self._get_model_id_for_llama_like_model(modelId)
+
+        return modelId
+
+    def _get_model_id_for_llama_like_model(
+        self,
+        model: str,
+    ) -> str:
+        """
+        Remove `llama` from the model ID, since `llama` is simply a spec to follow for custom bedrock models
+        """
+        model_id = model.replace("llama/", "")
+        return self.encode_model_id(model_id=model_id)
+
+
 def get_response_stream_shape():
     global _response_stream_shape_cache
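A quick check of the two scenarios described in the `get_bedrock_invoke_provider` docstring, as a sketch (using the `litellm.BedrockLLM` entry point the mapping code below relies on):

```python
# Sketch: expected behavior of the new provider detection helpers above.
import litellm

# Scenario 1: plain Bedrock model id -> provider is the prefix before the first "."
assert (
    litellm.BedrockLLM.get_bedrock_invoke_provider(
        "anthropic.claude-3-5-sonnet-20240620-v1:0"
    )
    == "anthropic"
)

# Scenario 2: imported-model path -> provider taken from the "llama/" prefix
assert (
    litellm.BedrockLLM.get_bedrock_invoke_provider(
        "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
    )
    == "llama"
)
```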
@@ -6045,20 +6045,23 @@ class ProviderConfigManager:
             return litellm.PetalsConfig()
         elif litellm.LlmProviders.BEDROCK == provider:
             base_model = litellm.AmazonConverseConfig()._get_base_model(model)
+            bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
             if (
                 base_model in litellm.bedrock_converse_models
                 or "converse_like" in model
             ):
                 return litellm.AmazonConverseConfig()
-            elif "amazon" in model:  # amazon titan llms
+            elif bedrock_provider == "amazon":  # amazon titan llms
                 return litellm.AmazonTitanConfig()
-            elif "meta" in model:  # amazon / meta llms
+            elif (
+                bedrock_provider == "meta" or bedrock_provider == "llama"
+            ):  # amazon / meta llms
                 return litellm.AmazonLlamaConfig()
-            elif "ai21" in model:  # ai21 llms
+            elif bedrock_provider == "ai21":  # ai21 llms
                 return litellm.AmazonAI21Config()
-            elif "cohere" in model:  # cohere models on bedrock
+            elif bedrock_provider == "cohere":  # cohere models on bedrock
                 return litellm.AmazonCohereConfig()
-            elif "mistral" in model:  # mistral models on bedrock
+            elif bedrock_provider == "mistral":  # mistral models on bedrock
                 return litellm.AmazonMistralConfig()
         return litellm.OpenAIGPTConfig()
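The move from substring checks (`"meta" in model`) to `bedrock_provider` matters for imported models: the ARN in the new route carries the provider only as a `llama/` prefix, which no substring check would catch. A small illustration (sketch, not library code):

```python
# Sketch: why the substring checks were replaced with provider detection.
import litellm

model = "llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n"
print("meta" in model)  # False -> the old `"meta" in model` branch would never fire
print(litellm.BedrockLLM.get_bedrock_invoke_provider(model))  # "llama" -> maps to AmazonLlamaConfig
```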
@@ -2529,3 +2529,54 @@ def test_bedrock_custom_proxy():
         assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"

         assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"
+
+
+def test_bedrock_custom_deepseek():
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+    import json
+
+    litellm._turn_on_debug()
+    client = HTTPHandler()
+
+    with patch.object(client, "post") as mock_post:
+        # Mock the response
+        mock_response = Mock()
+        mock_response.text = json.dumps(
+            {"generation": "Here's a joke...", "stop_reason": "stop"}
+        )
+        mock_response.status_code = 200
+        # Add required response attributes
+        mock_response.headers = {"Content-Type": "application/json"}
+        mock_response.json = lambda: json.loads(mock_response.text)
+        mock_post.return_value = mock_response
+
+        try:
+            response = completion(
+                model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n",  # Updated to specify provider
+                messages=[{"role": "user", "content": "Tell me a joke"}],
+                max_tokens=100,
+                client=client,
+            )
+
+            # Print request details
+            print("\nRequest Details:")
+            print(f"URL: {mock_post.call_args.kwargs['url']}")
+
+            # Verify the URL
+            assert (
+                mock_post.call_args.kwargs["url"]
+                == "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
+            )
+
+            # Verify the request body format
+            request_body = json.loads(mock_post.call_args.kwargs["data"])
+            print("request_body=", json.dumps(request_body, indent=4, default=str))
+            assert "prompt" in request_body
+            assert request_body["prompt"] == "Tell me a joke"
+
+            # follows the llama spec
+            assert request_body["max_gen_len"] == 100
+
+        except Exception as e:
+            print(f"Error: {str(e)}")
+            raise e
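The last assertion is the interesting one: on the `llama` invoke spec, the OpenAI-style `max_tokens` is carried as `max_gen_len`. A sketch of the body subset the test asserts (illustrative only, not the handler's actual transform code):

```python
# Sketch: the llama invoke-spec request body subset asserted by test_bedrock_custom_deepseek.
def to_llama_invoke_body(prompt: str, max_tokens: int) -> dict:
    return {
        "prompt": prompt,           # prompt built from the chat messages
        "max_gen_len": max_tokens,  # llama-spec name for the token limit
    }

print(to_llama_invoke_body("Tell me a joke", 100))
# {'prompt': 'Tell me a joke', 'max_gen_len': 100}
```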