diff --git a/litellm/utils.py b/litellm/utils.py index e3823e446..b498bea04 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -5311,8 +5311,31 @@ def get_optional_params( def get_api_base(model: str, optional_params: dict) -> Optional[str]: - _optional_params = LiteLLM_Params(**optional_params) # convert to pydantic object + """ + Returns the api base used for calling the model. + Parameters: + - model: str - the model passed to litellm.completion() + - optional_params - the additional params passed to litellm.completion - eg. api_base, api_key, etc. See `LiteLLM_Params` - https://github.com/BerriAI/litellm/blob/f09e6ba98d65e035a79f73bc069145002ceafd36/litellm/router.py#L67 + + Returns: + - string (api_base) or None + + Example: + ``` + from litellm import get_api_base + + get_api_base(model="gemini/gemini-pro") + ``` + """ + _optional_params = LiteLLM_Params(**optional_params) # convert to pydantic object + # get llm provider + try: + model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( + model=model + ) + except: + custom_llm_provider = None if _optional_params.api_base is not None: return _optional_params.api_base @@ -5328,6 +5351,11 @@ def get_api_base(model: str, optional_params: dict) -> Optional[str]: ) return _api_base + if custom_llm_provider is not None and custom_llm_provider == "gemini": + _api_base = "https://generativelanguage.googleapis.com/v1beta/models/{}:generateContent".format( + model + ) + return _api_base return None