diff --git a/litellm/utils.py b/litellm/utils.py index d033f6545..b900a1424 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -1253,6 +1253,8 @@ def get_optional_params( # use the openai defaults optional_params["presence_penalty"] = presence_penalty if stop: optional_params["stop_sequences"] = stop + elif custom_llm_provider == "perplexity": + optional_params[""] elif custom_llm_provider == "replicate": ## check if unsupported param passed in supported_params = ["stream", "temperature", "max_tokens", "top_p", "stop", "seed"] @@ -1554,7 +1556,13 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ if model.split("/",1)[0] in litellm.provider_list: custom_llm_provider = model.split("/", 1)[0] model = model.split("/", 1)[1] - return model, custom_llm_provider, dynamic_api_key + if custom_llm_provider == "perplexity": + # perplexity is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.perplexity.ai + api_base = "https://api.perplexity.ai" + dynamic_api_key = os.getenv("PERPLEXITYAI_API_KEY") + custom_llm_provider = "custom_openai" + + return model, custom_llm_provider, dynamic_api_key, api_base # check if api base is a known openai compatible endpoint if api_base: @@ -1563,7 +1571,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ custom_llm_provider = "custom_openai" if endpoint == "api.perplexity.ai": dynamic_api_key = os.getenv("PERPLEXITYAI_API_KEY") - return model, custom_llm_provider, dynamic_api_key + return model, custom_llm_provider, dynamic_api_key, api_base # check if model in known model provider list -> for huggingface models, raise exception as they don't have a fixed provider (can be togetherai, anyscale, baseten, runpod, et.) ## openai - chatcompletion + text completion @@ -1620,7 +1628,7 @@ def get_llm_provider(model: str, custom_llm_provider: Optional[str] = None, api_ print("\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m") print() raise ValueError(f"LLM Provider NOT provided. Pass in the LLM provider you are trying to call. E.g. For 'Huggingface' inference endpoints pass in `completion(model='huggingface/{model}',..)` Learn more: https://docs.litellm.ai/docs/providers") - return model, custom_llm_provider, dynamic_api_key + return model, custom_llm_provider, dynamic_api_key, api_base except Exception as e: raise e