diff --git a/litellm/llms/openai.py b/litellm/llms/openai.py
index e541b1089..4fdecce10 100644
--- a/litellm/llms/openai.py
+++ b/litellm/llms/openai.py
@@ -272,8 +272,8 @@ class OpenAIChatCompletion(BaseLLM):
             ## LOGGING
             logging_obj.pre_call(
                 input=data['messages'],
-                api_key=api_key,
-                additional_args={"headers": headers, "api_base": api_base, "acompletion": True, "complete_input_dict": data},
+                api_key=openai_aclient.api_key,
+                additional_args={"headers": {"Authorization": f"Bearer {openai_aclient.api_key}"}, "api_base": openai_aclient._base_url._uri_reference, "acompletion": True, "complete_input_dict": data},
             )
             response = await openai_aclient.chat.completions.create(**data)
             stringified_response = response.model_dump_json()
diff --git a/litellm/router.py b/litellm/router.py
index cc2b832b7..54c3f0ebe 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -889,25 +889,23 @@ class Router:
             model["model_info"] = model_info
             #### for OpenAI / Azure we need to initalize the Client for High Traffic ########
             custom_llm_provider = litellm_params.get("custom_llm_provider")
-            if custom_llm_provider is None:
-                custom_llm_provider = model_name.split("/",1)[0]
+            custom_llm_provider = custom_llm_provider or model_name.split("/",1)[0] or ""
+            default_api_base = None
+            default_api_key = None
+            if custom_llm_provider in litellm.openai_compatible_providers:
+                _, custom_llm_provider, api_key, api_base = litellm.get_llm_provider(model=model_name)
+                default_api_base = api_base
+                default_api_key = api_key
             if (
                 model_name in litellm.open_ai_chat_completion_models
-                or custom_llm_provider == "custom_openai"
-                or custom_llm_provider == "deepinfra"
-                or custom_llm_provider == "perplexity"
-                or custom_llm_provider == "anyscale"
-                or custom_llm_provider == "mistral"
-                or custom_llm_provider == "openai"
-                or custom_llm_provider == "azure"
-                or custom_llm_provider == "mistral"
+                or custom_llm_provider in litellm.openai_compatible_providers
                 or "ft:gpt-3.5-turbo" in model_name
                 or model_name in litellm.open_ai_embedding_models
             ):
                 # glorified / complicated reading of configs
                 # user can pass vars directly or they can pas os.environ/AZURE_API_KEY, in which case we will read the env
                 # we do this here because we init clients for Azure, OpenAI and we need to set the right key
-                api_key = litellm_params.get("api_key")
+                api_key = litellm_params.get("api_key") or default_api_key
                 if api_key and api_key.startswith("os.environ/"):
                     api_key_env_name = api_key.replace("os.environ/", "")
                     api_key = litellm.get_secret(api_key_env_name)
@@ -915,7 +913,7 @@ class Router:
 
                 api_base = litellm_params.get("api_base")
                 base_url = litellm_params.get("base_url")
-                api_base = api_base or base_url # allow users to pass in `api_base` or `base_url` for azure
+                api_base = api_base or base_url or default_api_base # allow users to pass in `api_base` or `base_url` for azure
                 if api_base and api_base.startswith("os.environ/"):
                     api_base_env_name = api_base.replace("os.environ/", "")
                     api_base = litellm.get_secret(api_base_env_name)
diff --git a/litellm/tests/test_router.py b/litellm/tests/test_router.py
index ef7474b92..b89748208 100644
--- a/litellm/tests/test_router.py
+++ b/litellm/tests/test_router.py
@@ -431,7 +431,7 @@ def test_acompletion_on_router():
         traceback.print_exc()
         pytest.fail(f"Error occurred: {e}")
 
-test_acompletion_on_router()
+# test_acompletion_on_router()
 
 def test_function_calling_on_router():
     try:
@@ -593,6 +593,30 @@ def test_bedrock_on_router():
         pytest.fail(f"Error occurred: {e}")
 # test_bedrock_on_router()
 
+# test openai-compatible endpoint
+@pytest.mark.asyncio
+async def test_mistral_on_router():
+    litellm.set_verbose = True
+    model_list = [
+        {
+            "model_name": "gpt-3.5-turbo",
+            "litellm_params": {
+                "model": "mistral/mistral-medium",
+            },
+        },
+    ]
+    router = Router(model_list=model_list)
+    response = await router.acompletion(
+        model="gpt-3.5-turbo",
+        messages=[
+            {
+                "role": "user",
+                "content": "hello from litellm test",
+            }
+        ]
+    )
+    print(response)
+asyncio.run(test_mistral_on_router())
 
 def test_openai_completion_on_router():
     # [PROD Use Case] - Makes an acompletion call + async acompletion call, and sync acompletion call, sync completion + stream
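
Note: the `router.py` hunks above make client initialization fall back to the `api_key`/`api_base` defaults that `litellm.get_llm_provider` returns for any provider in `litellm.openai_compatible_providers`, so a deployment like the `mistral/mistral-medium` entry in the new test no longer needs explicit credentials in `litellm_params`. A minimal standalone sketch of that usage, assuming `MISTRAL_API_KEY` is exported so the provider defaults can be resolved (not part of this PR):

```python
# Sketch: rely on the new provider-default fallback instead of passing
# api_key / api_base in litellm_params. Assumes MISTRAL_API_KEY is set
# in the environment.
import asyncio
from litellm import Router

model_list = [
    {
        "model_name": "gpt-3.5-turbo",
        "litellm_params": {
            # openai-compatible provider; key and base URL are resolved
            # via litellm.get_llm_provider() at client-init time
            "model": "mistral/mistral-medium",
        },
    },
]

async def main():
    router = Router(model_list=model_list)
    response = await router.acompletion(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "hello from litellm test"}],
    )
    print(response)

asyncio.run(main())
```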