refactor(bedrock.py): take model names from model cost dict

This commit is contained in:
Krrish Dholakia 2023-10-10 07:34:55 -07:00
parent 152ffca815
commit 0d863f00ad
6 changed files with 21 additions and 15 deletions

View file

@ -29,7 +29,7 @@ os.environ["AWS_SECRET_ACCESS_KEY"] = ""
os.environ["AWS_REGION_NAME"] = ""
response = completion(
model="bedrock/anthropic.claude-instant-v1",
model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}]
)
```
@ -41,7 +41,7 @@ import os
from litellm import completion
response = completion(
model="bedrock/anthropic.claude-instant-v1",
model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_access_key_id="",
aws_secret_access_key="",
@ -66,7 +66,7 @@ bedrock = boto3.client(
)
response = completion(
model="bedrock/anthropic.claude-instant-v1",
model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock,
)
@ -84,7 +84,7 @@ bedrock = dev_session.client(
)
response = completion(
model="bedrock/anthropic.claude-instant-v1",
model="anthropic.claude-instant-v1",
messages=[{ "content": "Hello, how are you?","role": "user"}],
aws_bedrock_client=bedrock,
)
@ -95,11 +95,14 @@ Here's an example of using a bedrock model with LiteLLM
| Model Name | Command | Environment Variables |
|--------------------------|------------------------------------------------------------------|---------------------------------------------------------------------|
| Anthropic Claude-V2 | `completion(model='bedrock/anthropic.claude-v2', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| Anthropic Claude-Instant V1 | `completion(model='bedrock/anthropic.claude-instant-v1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| Anthropic Claude-V1 | `completion(model='bedrock/anthropic.claude-v1', messages=messages)` | `os.environ['ANTHROPIC_ACCESS_KEY_ID']`, `os.environ['ANTHROPIC_SECRET_ACCESS_KEY']` |
| AI21 J2-Ultra | `completion(model='bedrock/ai21.j2-ultra', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Mid | `completion(model='bedrock/ai21.j2-mid', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Anthropic Claude-V2 | `completion(model='anthropic.claude-v2', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Anthropic Claude-Instant V1 | `completion(model='anthropic.claude-instant-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Anthropic Claude-V1 | `completion(model='anthropic.claude-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Amazon Titan Lite | `completion(model='amazon.titan-text-lite-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Amazon Titan Express | `completion(model='amazon.titan-text-express-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| Cohere Command | `completion(model='cohere.command-text-v14', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Mid | `completion(model='ai21.j2-mid-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
| AI21 J2-Ultra | `completion(model='ai21.j2-ultra-v1', messages=messages)` | `os.environ['AWS_ACCESS_KEY_ID']`, `os.environ['AWS_SECRET_ACCESS_KEY']`, `os.environ['AWS_REGION_NAME']` |
## Streaming

View file

@ -94,6 +94,7 @@ vertex_code_text_models: List = []
ai21_models: List = []
nlp_cloud_models: List = []
aleph_alpha_models: List = []
bedrock_models: List = []
for key, value in model_cost.items():
if value.get('litellm_provider') == 'openai':
open_ai_chat_completion_models.append(key)
@ -120,6 +121,8 @@ for key, value in model_cost.items():
nlp_cloud_models.append(key)
elif value.get('litellm_provider') == 'aleph_alpha':
aleph_alpha_models.append(key)
elif value.get('litellm_provider') == 'bedrock':
bedrock_models.append(key)
# well supported replicate llms
replicate_models: List = [
@ -196,11 +199,6 @@ petals_models = [
"petals-team/StableBeluga2",
]
bedrock_models: List = [
"amazon.titan-tg1-large",
"ai21.j2-grande-instruct"
]
ollama_models = [
"llama2"
]

View file

@ -806,7 +806,7 @@ def test_completion_bedrock_claude():
print("calling claude")
try:
response = completion(
model="bedrock/anthropic.claude-instant-v1",
model="anthropic.claude-instant-v1",
messages=messages,
max_tokens=10,
temperature=0.1,

View file

@ -1409,8 +1409,13 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None):
## petals
elif model in litellm.petals_models:
custom_llm_provider = "petals"
## bedrock
elif model in litellm.bedrock_models:
custom_llm_provider = "bedrock"
# openai embeddings
elif model in litellm.open_ai_embedding_models:
custom_llm_provider = "openai"
# cohere embeddings
elif model in litellm.cohere_embedding_models:
custom_llm_provider = "cohere"