fix(utils.py): fix get_llm_provider to handle the ':' in anthropic/bedrock calls

Krrish Dholakia 2023-12-07 14:19:11 -08:00
parent 2b04dc310a
commit 3846ec6124
2 changed files with 22 additions and 9 deletions
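Context for the change: "anthropic.claude-v2:1" contains a ":", so the old replicate shortcut (`":" in model`) claimed the elif chain in get_llm_provider before the bedrock branch was ever evaluated; the inner 64-character version-id check then failed, and no provider was set. The new guard only treats ":" as a replicate version pin when the full model string is longer than 64 characters. A minimal standalone sketch of the two conditions (illustrative, not the litellm source):

# Standalone sketch of the old vs. new replicate shortcut. Both example
# strings contain ":", but only the replicate-style one carries a 64-char
# version sha, which also pushes the whole string past 64 characters.
replicate_style = "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
bedrock_style = "anthropic.claude-v2:1"

for model in (replicate_style, bedrock_style):
    old_match = ":" in model                      # old shortcut: claims both strings
    new_match = ":" in model and len(model) > 64  # new guard: replicate-style only
    print(model, "-> old:", old_match, "new:", new_match)

# Prints old=True for both, but new=True only for the replicate-style string,
# so "anthropic.claude-v2:1" now falls through to the bedrock check.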

New test file:

@@ -0,0 +1,19 @@
+import sys, os
+import traceback
+from dotenv import load_dotenv
+
+load_dotenv()
+import os, io
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import pytest
+import litellm
+
+
+def test_get_llm_provider():
+    _, response, _, _ = litellm.get_llm_provider(model="anthropic.claude-v2:1")
+    assert response == "bedrock"
+
+test_get_llm_provider()

utils.py:

@@ -2580,10 +2580,10 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
         elif model in litellm.anthropic_models:
             custom_llm_provider = "anthropic"
         ## cohere
-        elif model in litellm.cohere_models:
+        elif model in litellm.cohere_models or model in litellm.cohere_embedding_models:
             custom_llm_provider = "cohere"
         ## replicate
-        elif model in litellm.replicate_models or ":" in model:
+        elif model in litellm.replicate_models or (":" in model and len(model)>64):
             model_parts = model.split(":")
             if len(model_parts) > 1 and len(model_parts[1])==64: ## checks if model name has a 64 digit code - e.g. "meta/llama-2-70b-chat:02e509c789964a7ea8736978a43525956ef40397be9033abf9fd2badfe68c9e3"
                 custom_llm_provider = "replicate"
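The 64-character threshold works because replicate version pins have the shape owner/model:<64-hex-char id>, so any pinned replicate string is necessarily longer than 64 characters, while short bedrock revision suffixes such as ":1" no longer claim this branch and instead fall through to the bedrock check below.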
@@ -2619,17 +2619,11 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_
         elif model in litellm.petals_models:
             custom_llm_provider = "petals"
         ## bedrock
-        elif model in litellm.bedrock_models:
+        elif model in litellm.bedrock_models or model in litellm.bedrock_embedding_models:
             custom_llm_provider = "bedrock"
         # openai embeddings
         elif model in litellm.open_ai_embedding_models:
             custom_llm_provider = "openai"
-        # cohere embeddings
-        elif model in litellm.cohere_embedding_models:
-            custom_llm_provider = "cohere"
-        elif model in litellm.bedrock_embedding_models:
-            custom_llm_provider = "bedrock"
-
         if custom_llm_provider is None or custom_llm_provider=="":
             print() # noqa
             print("\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m") # noqa