mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Feat) add bedrock/deepseek custom import models (#8132)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 16s
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 16s
* add support for using llama spec with bedrock * fix get_bedrock_invoke_provider * add support for using bedrock provider in mappings * working request * test_bedrock_custom_deepseek * test_bedrock_custom_deepseek * fix _get_model_id_for_llama_like_model * test_bedrock_custom_deepseek * doc DeepSeek-R1-Distill-Llama-70B * test_bedrock_custom_deepseek
This commit is contained in:
parent
29a8a613a7
commit
9ff27809b2
5 changed files with 212 additions and 17 deletions
|
@ -7,7 +7,7 @@ ALL Bedrock models (Anthropic, Meta, Deepseek, Mistral, Amazon, etc.) are Suppor
|
|||
| Property | Details |
|
||||
|-------|-------|
|
||||
| Description | Amazon Bedrock is a fully managed service that offers a choice of high-performing foundation models (FMs). |
|
||||
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy) |
|
||||
| Provider Route on LiteLLM | `bedrock/`, [`bedrock/converse/`](#set-converse--invoke-route), [`bedrock/invoke/`](#set-invoke-route), [`bedrock/converse_like/`](#calling-via-internal-proxy), [`bedrock/llama/`](#bedrock-imported-models-deepseek) |
|
||||
| Provider Doc | [Amazon Bedrock ↗](https://docs.aws.amazon.com/bedrock/latest/userguide/what-is-bedrock.html) |
|
||||
| Supported OpenAI Endpoints | `/chat/completions`, `/completions`, `/embeddings`, `/images/generations` |
|
||||
| Pass-through Endpoint | [Supported](../pass_through/bedrock.md) |
|
||||
|
@ -1277,6 +1277,74 @@ curl -X POST 'http://0.0.0.0:4000/chat/completions' \
|
|||
https://some-api-url/models
|
||||
```
|
||||
|
||||
## Bedrock Imported Models (Deepseek)
|
||||
|
||||
| Property | Details |
|
||||
|----------|---------|
|
||||
| Provider Route | `bedrock/llama/{model_arn}` |
|
||||
| Provider Documentation | [Bedrock Imported Models](https://docs.aws.amazon.com/bedrock/latest/userguide/model-customization-import-model.html), [Deepseek Bedrock Imported Model](https://aws.amazon.com/blogs/machine-learning/deploy-deepseek-r1-distilled-llama-models-with-amazon-bedrock-custom-model-import/) |
|
||||
|
||||
Use this route to call Bedrock Imported Models that follow the `llama` Invoke Request / Response spec
|
||||
|
||||
|
||||
<Tabs>
|
||||
<TabItem value="sdk" label="SDK">
|
||||
|
||||
```python
|
||||
from litellm import completion
|
||||
import os
|
||||
|
||||
response = completion(
|
||||
model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # bedrock/llama/{your-model-arn}
|
||||
messages=[{"role": "user", "content": "Tell me a joke"}],
|
||||
)
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
|
||||
<TabItem value="proxy" label="Proxy">
|
||||
|
||||
|
||||
**1. Add to config**
|
||||
|
||||
```yaml
|
||||
model_list:
|
||||
- model_name: DeepSeek-R1-Distill-Llama-70B
|
||||
litellm_params:
|
||||
model: bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n
|
||||
|
||||
```
|
||||
|
||||
**2. Start proxy**
|
||||
|
||||
```bash
|
||||
litellm --config /path/to/config.yaml
|
||||
|
||||
# RUNNING at http://0.0.0.0:4000
|
||||
```
|
||||
|
||||
**3. Test it!**
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:4000/chat/completions' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"model": "DeepSeek-R1-Distill-Llama-70B", # 👈 the 'model_name' in config
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "what llm are you"
|
||||
}
|
||||
],
|
||||
}'
|
||||
```
|
||||
|
||||
</TabItem>
|
||||
</Tabs>
|
||||
|
||||
|
||||
|
||||
## Provisioned throughput models
|
||||
To use provisioned throughput Bedrock models pass
|
||||
- `model=bedrock/<base-model>`, example `model=bedrock/anthropic.claude-v2`. Set `model` to any of the [Supported AWS models](#supported-aws-bedrock-models)
|
||||
|
|
|
@ -359,6 +359,9 @@ BEDROCK_CONVERSE_MODELS = [
|
|||
"meta.llama3-2-11b-instruct-v1:0",
|
||||
"meta.llama3-2-90b-instruct-v1:0",
|
||||
]
|
||||
BEDROCK_INVOKE_PROVIDERS_LITERAL = Literal[
|
||||
"cohere", "anthropic", "mistral", "amazon", "meta", "llama"
|
||||
]
|
||||
####### COMPLETION MODELS ###################
|
||||
open_ai_chat_completion_models: List = []
|
||||
open_ai_text_completion_models: List = []
|
||||
|
|
|
@ -19,12 +19,14 @@ from typing import (
|
|||
Tuple,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
)
|
||||
|
||||
import httpx # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm._logging import print_verbose
|
||||
from litellm.caching.caching import InMemoryCache
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
|
@ -206,7 +208,7 @@ async def make_call(
|
|||
api_key="",
|
||||
data=data,
|
||||
messages=messages,
|
||||
print_verbose=litellm.print_verbose,
|
||||
print_verbose=print_verbose,
|
||||
encoding=litellm.encoding,
|
||||
) # type: ignore
|
||||
completion_stream: Any = MockResponseIterator(
|
||||
|
@ -286,7 +288,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
elif provider == "meta":
|
||||
elif provider == "meta" or provider == "llama":
|
||||
prompt = prompt_factory(
|
||||
model=model, messages=messages, custom_llm_provider="bedrock"
|
||||
)
|
||||
|
@ -318,7 +320,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
print_verbose,
|
||||
encoding,
|
||||
) -> Union[ModelResponse, CustomStreamWrapper]:
|
||||
provider = model.split(".")[0]
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=messages,
|
||||
|
@ -465,7 +467,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
outputText = (
|
||||
completion_response.get("completions")[0].get("data").get("text")
|
||||
)
|
||||
elif provider == "meta":
|
||||
elif provider == "meta" or provider == "llama":
|
||||
outputText = completion_response["generation"]
|
||||
elif provider == "mistral":
|
||||
outputText = completion_response["outputs"][0]["text"]
|
||||
|
@ -597,13 +599,13 @@ class BedrockLLM(BaseAWSLLM):
|
|||
|
||||
## SETUP ##
|
||||
stream = optional_params.pop("stream", None)
|
||||
modelId = optional_params.pop("model_id", None)
|
||||
if modelId is not None:
|
||||
modelId = self.encode_model_id(model_id=modelId)
|
||||
else:
|
||||
modelId = model
|
||||
|
||||
provider = model.split(".")[0]
|
||||
provider = self.get_bedrock_invoke_provider(model)
|
||||
modelId = self.get_bedrock_model_id(
|
||||
model=model,
|
||||
provider=provider,
|
||||
optional_params=optional_params,
|
||||
)
|
||||
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
|
@ -785,7 +787,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
"textGenerationConfig": inference_params,
|
||||
}
|
||||
)
|
||||
elif provider == "meta":
|
||||
elif provider == "meta" or provider == "llama":
|
||||
## LOAD CONFIG
|
||||
config = litellm.AmazonLlamaConfig.get_config()
|
||||
for k, v in config.items():
|
||||
|
@ -1044,6 +1046,74 @@ class BedrockLLM(BaseAWSLLM):
|
|||
)
|
||||
return streaming_response
|
||||
|
||||
@staticmethod
|
||||
def get_bedrock_invoke_provider(
|
||||
model: str,
|
||||
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
|
||||
"""
|
||||
Helper function to get the bedrock provider from the model
|
||||
|
||||
handles 2 scenarions:
|
||||
1. model=anthropic.claude-3-5-sonnet-20240620-v1:0 -> Returns `anthropic`
|
||||
2. model=llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n -> Returns `llama`
|
||||
"""
|
||||
_split_model = model.split(".")[0]
|
||||
if _split_model in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
|
||||
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, _split_model)
|
||||
|
||||
# If not a known provider, check for pattern with two slashes
|
||||
provider = BedrockLLM._get_provider_from_model_path(model)
|
||||
if provider is not None:
|
||||
return provider
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _get_provider_from_model_path(
|
||||
model_path: str,
|
||||
) -> Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL]:
|
||||
"""
|
||||
Helper function to get the provider from a model path with format: provider/model-name
|
||||
|
||||
Args:
|
||||
model_path (str): The model path (e.g., 'llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n' or 'anthropic/model-name')
|
||||
|
||||
Returns:
|
||||
Optional[str]: The provider name, or None if no valid provider found
|
||||
"""
|
||||
parts = model_path.split("/")
|
||||
if len(parts) >= 1:
|
||||
provider = parts[0]
|
||||
if provider in get_args(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL):
|
||||
return cast(litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL, provider)
|
||||
return None
|
||||
|
||||
def get_bedrock_model_id(
|
||||
self,
|
||||
optional_params: dict,
|
||||
provider: Optional[litellm.BEDROCK_INVOKE_PROVIDERS_LITERAL],
|
||||
model: str,
|
||||
) -> str:
|
||||
modelId = optional_params.pop("model_id", None)
|
||||
if modelId is not None:
|
||||
modelId = self.encode_model_id(model_id=modelId)
|
||||
else:
|
||||
modelId = model
|
||||
|
||||
if provider == "llama" and "llama/" in modelId:
|
||||
modelId = self._get_model_id_for_llama_like_model(modelId)
|
||||
|
||||
return modelId
|
||||
|
||||
def _get_model_id_for_llama_like_model(
|
||||
self,
|
||||
model: str,
|
||||
) -> str:
|
||||
"""
|
||||
Remove `llama` from modelID since `llama` is simply a spec to follow for custom bedrock models
|
||||
"""
|
||||
model_id = model.replace("llama/", "")
|
||||
return self.encode_model_id(model_id=model_id)
|
||||
|
||||
|
||||
def get_response_stream_shape():
|
||||
global _response_stream_shape_cache
|
||||
|
|
|
@ -6045,20 +6045,23 @@ class ProviderConfigManager:
|
|||
return litellm.PetalsConfig()
|
||||
elif litellm.LlmProviders.BEDROCK == provider:
|
||||
base_model = litellm.AmazonConverseConfig()._get_base_model(model)
|
||||
bedrock_provider = litellm.BedrockLLM.get_bedrock_invoke_provider(model)
|
||||
if (
|
||||
base_model in litellm.bedrock_converse_models
|
||||
or "converse_like" in model
|
||||
):
|
||||
return litellm.AmazonConverseConfig()
|
||||
elif "amazon" in model: # amazon titan llms
|
||||
elif bedrock_provider == "amazon": # amazon titan llms
|
||||
return litellm.AmazonTitanConfig()
|
||||
elif "meta" in model: # amazon / meta llms
|
||||
elif (
|
||||
bedrock_provider == "meta" or bedrock_provider == "llama"
|
||||
): # amazon / meta llms
|
||||
return litellm.AmazonLlamaConfig()
|
||||
elif "ai21" in model: # ai21 llms
|
||||
elif bedrock_provider == "ai21": # ai21 llms
|
||||
return litellm.AmazonAI21Config()
|
||||
elif "cohere" in model: # cohere models on bedrock
|
||||
elif bedrock_provider == "cohere": # cohere models on bedrock
|
||||
return litellm.AmazonCohereConfig()
|
||||
elif "mistral" in model: # mistral models on bedrock
|
||||
elif bedrock_provider == "mistral": # mistral models on bedrock
|
||||
return litellm.AmazonMistralConfig()
|
||||
return litellm.OpenAIGPTConfig()
|
||||
|
||||
|
|
|
@ -2529,3 +2529,54 @@ def test_bedrock_custom_proxy():
|
|||
assert mock_post.call_args.kwargs["url"] == "https://some-api-url/models"
|
||||
|
||||
assert mock_post.call_args.kwargs["headers"]["Authorization"] == "Bearer Token"
|
||||
|
||||
|
||||
def test_bedrock_custom_deepseek():
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
import json
|
||||
|
||||
litellm._turn_on_debug()
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post") as mock_post:
|
||||
# Mock the response
|
||||
mock_response = Mock()
|
||||
mock_response.text = json.dumps(
|
||||
{"generation": "Here's a joke...", "stop_reason": "stop"}
|
||||
)
|
||||
mock_response.status_code = 200
|
||||
# Add required response attributes
|
||||
mock_response.headers = {"Content-Type": "application/json"}
|
||||
mock_response.json = lambda: json.loads(mock_response.text)
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
try:
|
||||
response = completion(
|
||||
model="bedrock/llama/arn:aws:bedrock:us-east-1:086734376398:imported-model/r4c4kewx2s0n", # Updated to specify provider
|
||||
messages=[{"role": "user", "content": "Tell me a joke"}],
|
||||
max_tokens=100,
|
||||
client=client,
|
||||
)
|
||||
|
||||
# Print request details
|
||||
print("\nRequest Details:")
|
||||
print(f"URL: {mock_post.call_args.kwargs['url']}")
|
||||
|
||||
# Verify the URL
|
||||
assert (
|
||||
mock_post.call_args.kwargs["url"]
|
||||
== "https://bedrock-runtime.us-west-2.amazonaws.com/model/arn%3Aaws%3Abedrock%3Aus-east-1%3A086734376398%3Aimported-model%2Fr4c4kewx2s0n/invoke"
|
||||
)
|
||||
|
||||
# Verify the request body format
|
||||
request_body = json.loads(mock_post.call_args.kwargs["data"])
|
||||
print("request_body=", json.dumps(request_body, indent=4, default=str))
|
||||
assert "prompt" in request_body
|
||||
assert request_body["prompt"] == "Tell me a joke"
|
||||
|
||||
# follows the llama spec
|
||||
assert request_body["max_gen_len"] == 100
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error: {str(e)}")
|
||||
raise e
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue