diff --git a/litellm/__pycache__/__init__.cpython-311.pyc b/litellm/__pycache__/__init__.cpython-311.pyc index a441050f7..79e48505e 100644 Binary files a/litellm/__pycache__/__init__.cpython-311.pyc and b/litellm/__pycache__/__init__.cpython-311.pyc differ diff --git a/litellm/__pycache__/main.cpython-311.pyc b/litellm/__pycache__/main.cpython-311.pyc index d5fda4c2a..e14558129 100644 Binary files a/litellm/__pycache__/main.cpython-311.pyc and b/litellm/__pycache__/main.cpython-311.pyc differ diff --git a/litellm/__pycache__/utils.cpython-311.pyc b/litellm/__pycache__/utils.cpython-311.pyc index 5a8696af9..6a0dcbc33 100644 Binary files a/litellm/__pycache__/utils.cpython-311.pyc and b/litellm/__pycache__/utils.cpython-311.pyc differ diff --git a/litellm/main.py b/litellm/main.py index 4085c6387..fa88e3ca8 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -44,10 +44,11 @@ def completion( presence_penalty=0, frequency_penalty=0, logit_bias={}, user="", deployment_id=None, # Optional liteLLM function params *, return_async=False, api_key=None, force_timeout=60, azure=False, logger_fn=None, verbose=False, - hugging_face = False, replicate=False,together_ai = False, llm_provider=None, custom_api_base=None + hugging_face = False, replicate=False,together_ai = False, custom_llm_provider=None, custom_api_base=None ): try: global new_response + args = locals() model_response = deepcopy(new_response) # deep copy the default response format so we can mutate it and it's thread-safe. # check if user passed in any of the OpenAI optional params optional_params = get_optional_params( @@ -85,7 +86,7 @@ def completion( messages = messages, **optional_params ) - elif model in litellm.open_ai_chat_completion_models or llm_provider == "custom": # allow user to make an openai call with a custom base + elif model in litellm.open_ai_chat_completion_models or custom_llm_provider == "custom_openai": # allow user to make an openai call with a custom base openai.api_type = "openai" # note: if a user sets a custom base - we should ensure this works api_base = custom_api_base if custom_api_base is not None else litellm.api_base # allow for the setting of dynamic and stateful api-bases @@ -100,9 +101,7 @@ def completion( else: openai.api_key = get_secret("OPENAI_API_KEY") ## LOGGING - logging(model=model, input=messages, additional_args=optional_params, azure=azure, logger_fn=logger_fn) - if custom_api_base: # reset after call, if a dynamic api base was passsed - openai.api_base = "https://api.openai.com/v1" + logging(model=model, input=messages, additional_args=args, azure=azure, logger_fn=logger_fn) ## COMPLETION CALL if litellm.headers: response = openai.ChatCompletion.create( @@ -117,6 +116,8 @@ def completion( messages = messages, **optional_params ) + if custom_api_base: # reset after call, if a dynamic api base was passsed + openai.api_base = "https://api.openai.com/v1" elif model in litellm.open_ai_text_completion_models: openai.api_type = "openai" openai.api_base = litellm.api_base if litellm.api_base is not None else "https://api.openai.com/v1" diff --git a/litellm/tests/test_custom_api_base.py b/litellm/tests/test_custom_api_base.py new file mode 100644 index 000000000..966fff954 --- /dev/null +++ b/litellm/tests/test_custom_api_base.py @@ -0,0 +1,20 @@ +import sys, os +import traceback +from dotenv import load_dotenv +load_dotenv() +import os +sys.path.insert(0, os.path.abspath('../..')) # Adds the parent directory to the system path +import litellm +from litellm import completion + +def logging_fn(model_call_dict): + print(f"model call details: {model_call_dict}") +models = ["gorilla-7b-hf-v1", "gpt-4"] +custom_llm_provider = None +messages = [{"role": "user", "content": "Hey, how's it going?"}] +for model in models: # iterate through list + custom_api_base = None + if model == "gorilla-7b-hf-v1": + custom_llm_provider = "custom_openai" + custom_api_base = "http://zanino.millennium.berkeley.edu:8000/v1" + completion(model=model, messages=messages, custom_llm_provider=custom_llm_provider, custom_api_base=custom_api_base, logger_fn=logging_fn) diff --git a/pyproject.toml b/pyproject.toml index da235aabb..d5eb3177c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "litellm" -version = "0.1.379" +version = "0.1.380" description = "Library to easily interface with LLM API providers" authors = ["BerriAI"] license = "MIT License"