refactor(bedrock.py): take model names from model cost dict

This commit is contained in:
Krrish Dholakia 2023-10-10 07:34:55 -07:00
parent 36bfc0809f
commit 22ee0c6931
6 changed files with 21 additions and 15 deletions

View file

@@ -29,7 +29,7 @@ os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = "" os.environ["AWS_REGION_NAME"] = ""
response = completion( response = completion(
model="bedrock/anthropic.claude-instant-v1", model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}] messages=[{ "content": "Hello, how are you?","role": "user"}]
) )
``` ```
@@ -41,7 +41,7 @@ import os
from litellm import completion from litellm import completion
response = completion( response = completion(
model="bedrock/anthropic.claude-instant-v1", model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}], messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_access_key_id="", aws_access_key_id="",
aws_secret_access_key="", aws_secret_access_key="",
@@ -66,7 +66,7 @@ bedrock = boto3.client(
) )
response = completion( response = completion(
model="bedrock/anthropic.claude-instant-v1", model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}], messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock, aws_bedrock_client=bedrock,
) )
@@ -84,7 +84,7 @@ bedrock = dev_session.client(
) )
response = completion( response = completion(
model="bedrock/anthropic.claude-instant-v1", model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}], messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock, aws_bedrock_client=bedrock,
) )
@@ -95,11 +95,14 @@ Here's an example of using a bedrock model with LiteLLM
| Model Name | Command | Environment Variables | | Model Name | Command | Environment Variables |
|--------------------------|------------------------------------------------------------------|---------------------------------------------------------------------| |--------------------------|------------------------------------------------------------------|---------------------------------------------------------------------|
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` | | Anthropic Claude-V2 | `completion(model='anthropic.claude-v2', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` | | Anthropic Claude-Instant V1 | `completion(model='anthropic.claude-instant-v1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V1 | `completion(model='bedrock/anthropic.claude-v1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` | | Anthropic Claude-V1 | `completion(model='anthropic.claude-v1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| AI21 J2-Ultra | `completion(model='bedrock/ai21.j2-ultra', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` | | Amazon Titan Lite | `completion(model='amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Mid | `completion(model='bedrock/ai21.j2-mid', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` | | Amazon Titan Express | `completion(model='amazon.titan-text-express-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Cohere Command | `completion(model='cohere.command-text-v14', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Mid | `completion(model='ai21.j2-mid-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Ultra | `completion(model='ai21.j2-ultra-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
## Streaming ## Streaming

View file

@@ -94,6 +94,7 @@ vertex_code_text_models: List = []
ai21_models: List = [] ai21_models: List = []
nlp_cloud_models: List = [] nlp_cloud_models: List = []
aleph_alpha_models: List = [] aleph_alpha_models: List = []
bedrock_models: List = []
for key, value in model_cost.items(): for key, value in model_cost.items():
if value.get('litellm_provider') == 'openai': if value.get('litellm_provider') == 'openai':
open_ai_chat_completion_models.append(key) open_ai_chat_completion_models.append(key)
@@ -120,6 +121,8 @@ for key, value in model_cost.items():
nlp_cloud_models.append(key) nlp_cloud_models.append(key)
elif value.get('litellm_provider') == 'aleph_alpha': elif value.get('litellm_provider') == 'aleph_alpha':
aleph_alpha_models.append(key) aleph_alpha_models.append(key)
elif value.get('litellm_provider') == 'bedrock':
bedrock_models.append(key)
# well supported replicate llms # well supported replicate llms
replicate_models: List = [ replicate_models: List = [
@@ -196,11 +199,6 @@ petals_models = [
"petals-team/StableBeluga2", "petals-team/StableBeluga2",
] ]
bedrock_models: List = [
"amazon.titan-tg1-large",
"ai21.j2-grande-instruct"
]
ollama_models = [ ollama_models = [
"llama2" "llama2"
] ]

View file

@@ -806,7 +806,7 @@ def test_completion_bedrock_claude():
print("calling claude") print("calling claude")
try: try:
response = completion( response = completion(
model="bedrock/anthropic.claude-instant-v1", model="anthropic.claude-instant-v1",
messages=messages, messages=messages,
max_tokens=10, max_tokens=10,
temperature=0.1, temperature=0.1,

View file

@@ -1409,8 +1409,13 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
## petals ## petals
elif model in litellm.petals_models: elif model in litellm.petals_models:
custom_llm_provider = "petals" custom_llm_provider = "petals"
## bedrock
elif model in litellm.bedrock_models:
custom_llm_provider = "bedrock"
# openai embeddings
elif model in litellm.open_ai_embedding_models: elif model in litellm.open_ai_embedding_models:
custom_llm_provider = "openai" custom_llm_provider = "openai"
# cohere embeddings
elif model in litellm.cohere_embedding_models: elif model in litellm.cohere_embedding_models:
custom_llm_provider = "cohere" custom_llm_provider = "cohere"