diff --git a/litellm/__init__.py b/litellm/__init__.py
index 084931eb0..ec3ea3165 100644
--- a/litellm/__init__.py
+++ b/litellm/__init__.py
@@ -238,6 +238,7 @@ from .utils import (
     register_prompt_template,
     validate_environment,
     check_valid_key,
+    get_llm_provider
 )
 from .main import *  # type: ignore
 from .integrations import *
diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc
index 9690c652a..b5b19892b 100644
Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ
diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc
index f5bd73faa..518b9d26e 100644
Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ
diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc
index 5849ed2ea..d8e7e219c 100644
Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ
diff --git a/litellm/main.py b/litellm/main.py
index 5c90f276a..a0b233861 100644
--- a/litellm/main.py
+++ b/litellm/main.py
@@ -17,6 +17,7 @@ from litellm.utils import (
     CustomStreamWrapper,
     read_config_args,
     completion_with_fallbacks,
+    get_llm_provider
 )
 from .llms import anthropic
 from .llms import together_ai
@@ -168,6 +169,7 @@ def completion(
             completion_call_id=id
         )
         logging.update_environment_variables(model=model, user=user, optional_params=optional_params, litellm_params=litellm_params)
+        get_llm_provider(model=model, custom_llm_provider=custom_llm_provider)
         if custom_llm_provider == "azure":
             # azure configs
             openai.api_type = get_secret("AZURE_API_TYPE") or "azure"
diff --git a/litellm/tests/test_bad_params.py b/litellm/tests/test_bad_params.py
index 0173099f9..4296d997b 100644
--- a/litellm/tests/test_bad_params.py
+++ b/litellm/tests/test_bad_params.py
@@ -32,6 +32,16 @@ def test_completion_with_empty_model():
         pass
 
 
+def test_completion_with_no_provider():
+    # test on empty
+    try:
+        model = "cerebras/btlm-3b-8k-base"
+        response = completion(model=model, messages=messages)
+    except Exception as e:
+        print(f"error occurred: {e}")
+        pass
+
+test_completion_with_no_provider()
 # # bad key
 # temp_key = os.environ.get("OPENAI_API_KEY")
 # os.environ["OPENAI_API_KEY"] = "bad-key"
diff --git a/litellm/utils.py b/litellm/utils.py
index 1a36fd918..5479cca7c 100644
--- a/litellm/utils.py
+++ b/litellm/utils.py
@@ -931,6 +931,55 @@ def get_optional_params(  # use the openai defaults
         return optional_params
     return optional_params
 
+def get_llm_provider(model: str, custom_llm_provider: str = None):
+    try:
+        # check if llm provider provided
+        if custom_llm_provider:
+            return model, custom_llm_provider
+
+        # check if llm provider part of model name
+        if model.split("/", 1)[0] in litellm.provider_list:
+            custom_llm_provider = model.split("/", 1)[0]
+            model = model.split("/", 1)[1]
+            return model, custom_llm_provider
+
+        # check if model in known model provider list
+        ## openai - chatcompletion + text completion
+        if model in litellm.open_ai_chat_completion_models or model in litellm.open_ai_text_completion_models:
+            custom_llm_provider = "openai"
+        ## cohere
+        elif model in litellm.cohere_models:
+            custom_llm_provider = "cohere"
+        ## replicate
+        elif model in litellm.replicate_models:
+            custom_llm_provider = "replicate"
+        ## openrouter
+        elif model in litellm.openrouter_models:
+            custom_llm_provider = "openrouter"
+        ## vertex - text + chat models
+        elif model in litellm.vertex_chat_models or model in litellm.vertex_text_models:
+            custom_llm_provider = "vertex_ai"
+        ## huggingface
+        elif model in litellm.huggingface_models:
+            custom_llm_provider = "huggingface"
+        ## ai21
+        elif model in litellm.ai21_models:
+            custom_llm_provider = "ai21"
+        ## together_ai
+        elif model in litellm.together_ai_models:
+            custom_llm_provider = "together_ai"
+        ## aleph_alpha
+        elif model in litellm.aleph_alpha_models:
+            custom_llm_provider = "aleph_alpha"
+        ## baseten
+        elif model in litellm.baseten_models:
+            custom_llm_provider = "baseten"
+
+        if custom_llm_provider is None:
+            raise ValueError(f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/{model}',..)` Learn more: https://docs.litellm.ai/docs/providers")
+        return model, custom_llm_provider
+    except Exception as e:
+        raise e
 
 def get_max_tokens(model: str):
     try:
diff --git a/pyproject.toml b/pyproject.toml
index 8ac50966c..92abc5df9 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "litellm"
-version = "0.1.601"
+version = "0.1.602"
 description = "Library to easily interface with LLM API providers"
 authors = ["BerriAI"]
 license = "MIT License"
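For reference, a minimal usage sketch of the new get_llm_provider helper, based only on the logic in the litellm/utils.py hunk above. The model strings are illustrative and assume "huggingface" appears in litellm.provider_list and "command-nightly" in litellm.cohere_models.

# sketch: exercising the helper added in this diff
from litellm import get_llm_provider

# provider prefix in the model name -> prefix is stripped and returned as the provider
model, provider = get_llm_provider(model="huggingface/bigcode/starcoder")
print(model, provider)  # expected: bigcode/starcoder huggingface

# bare model name found in litellm's known model lists -> provider is inferred
model, provider = get_llm_provider(model="command-nightly")
print(model, provider)  # expected: command-nightly cohere

# unknown model with no provider (same string as the new test) -> ValueError
try:
    get_llm_provider(model="cerebras/btlm-3b-8k-base")
except ValueError as e:
    print(f"error occurred: {e}")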