forked from phoenix/litellm-mirror
raise a better exception if the LLM provider isn't passed in or can't be inferred
commit baa69734b0
parent 4acca3d4d9
8 changed files with 63 additions and 1 deletion
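Before the file-by-file diff: a minimal sketch (not part of the commit) of the behavior this change introduces, assuming this litellm version is installed and that "cerebras" is neither in litellm.provider_list nor a model in any known provider's model list:

from litellm import completion

messages = [{"role": "user", "content": "Hey, how's it going?"}]

try:
    # no recognizable provider prefix and not in any known model list,
    # so the new get_llm_provider() check raises before any provider call is attempted
    completion(model="cerebras/btlm-3b-8k-base", messages=messages)
except Exception as e:
    # the message points the caller at the provider-prefix syntax and the providers docs
    print(f"error occurred: {e}")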
@@ -238,6 +238,7 @@ from .utils import (
     register_prompt_template,
     validate_environment,
     check_valid_key,
+    get_llm_provider
 )
 from .main import *  # type: ignore
 from .integrations import *
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -17,6 +17,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     read_config_args,
     completion_with_fallbacks,
+    get_llm_provider
 )
 from .llms import anthropic
 from .llms import together_ai
@@ -168,6 +169,7 @@ def completion(
             completion_call_id=id
         )
         logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
+        get_llm_provider(model=model, custom_llm_provider=custom_llm_provider)
         if custom_llm_provider == "azure":
             # azure configs
             openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
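The call added inside completion() runs only for its side effect of raising early when no provider can be resolved. A hedged sketch of how a caller can satisfy the check by naming the provider explicitly (assumes a Together AI key is configured in the environment and that "togethercomputer/llama-2-70b-chat" is in litellm.together_ai_models, as it was at the time):

from litellm import completion

messages = [{"role": "user", "content": "Hello"}]

# custom_llm_provider short-circuits inference: get_llm_provider() returns it immediately
response = completion(
    model="togethercomputer/llama-2-70b-chat",
    messages=messages,
    custom_llm_provider="together_ai",
)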
@@ -32,6 +32,16 @@ def test_completion_with_empty_model():
         pass
 
 
+def test_completion_with_no_provider():
+    # test on empty
+    try:
+        model = "cerebras/btlm-3b-8k-base"
+        response = completion(model=model, messages=messages)
+    except Exception as e:
+        print(f"error occurred: {e}")
+        pass
+
+test_completion_with_no_provider()
 # # bad key
 # temp_key = os.environ.get("OPENAI_API_KEY")
 # os.environ["OPENAI_API_KEY"] = "bad-key"
@@ -931,6 +931,55 @@ def get_optional_params( # use the openai defaults
     return optional_params
     return optional_params
 
+def get_llm_provider(model: str, custom_llm_provider: str = None):
+    try:
+        # check if llm provider provided
+        if custom_llm_provider:
+            return model, custom_llm_provider
+
+        # check if llm provider part of model name
+        if model.split("/",1)[0] in litellm.provider_list:
+            custom_llm_provider = model.split("/", 1)[0]
+            model = model.split("/", 1)[1]
+            return model, custom_llm_provider
+
+        # check if model in known model provider list
+        ## openai - chatcompletion + text completion
+        if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models:
+            custom_llm_provider = "openai"
+        ## cohere
+        elif model in litellm.cohere_models:
+            custom_llm_provider = "cohere"
+        ## replicate
+        elif model in litellm.replicate_models:
+            custom_llm_provider = "replicate"
+        ## openrouter
+        elif model in litellm.openrouter_models:
+            custom_llm_provider = "openrouter"
+        ## vertex - text + chat models
+        elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
+            custom_llm_provider = "vertex_ai"
+        ## huggingface
+        elif model in litellm.huggingface_models:
+            custom_llm_provider = "huggingface"
+        ## ai21
+        elif model in litellm.ai21_models:
+            custom_llm_provider = "ai21"
+        ## together_ai
+        elif model in litellm.together_ai_models:
+            custom_llm_provider = "together_ai"
+        ## aleph_alpha
+        elif model in litellm.aleph_alpha_models:
+            custom_llm_provider = "aleph_alpha"
+        ## baseten
+        elif model in litellm.baseten_models:
+            custom_llm_provider = "baseten"
+
+        if custom_llm_provider is None:
+            raise ValueError(f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/{model}',..)` Learn more: https://docs.litellm.ai/docs/providers")
+        return model, custom_llm_provider
+    except Exception as e:
+        raise e
 
 def get_max_tokens(model: str):
     try:
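A usage sketch of the new helper on its own (not part of the commit; assumes "huggingface" is in litellm.provider_list and "command-nightly" is in litellm.cohere_models, as both were at the time):

from litellm import get_llm_provider

# provider taken from a "provider/model" prefix; the prefix is stripped off the model name
model, provider = get_llm_provider("huggingface/bigcode/starcoder")
print(model, provider)   # bigcode/starcoder huggingface

# provider inferred from a known model list
model, provider = get_llm_provider("command-nightly")
print(provider)          # cohere

# neither a known prefix nor a known model -> the new ValueError
try:
    get_llm_provider("cerebras/btlm-3b-8k-base")
except ValueError as e:
    print(e)

Note that the returned model/provider pair is not yet consumed inside completion() in this commit; the call there exists purely to surface the clearer error as early as possible.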
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.601"
+version = "0.1.602"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"