diff --git a/docs/my-website/docs/providers/watsonx.md b/docs/my-website/docs/providers/watsonx.md index 8665611fa7..23d8d259ac 100644 --- a/docs/my-website/docs/providers/watsonx.md +++ b/docs/my-website/docs/providers/watsonx.md @@ -14,6 +14,7 @@ os.environ["WATSONX_TOKEN"] = "" # IAM auth token # optional - can also be passed as params to completion() or embedding() os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models +os.environ["WATSONX_ZENAPIKEY"] = "" # Zen API key (use for a long-term API token) ``` See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai. diff --git a/docs/my-website/docs/set_keys.md b/docs/my-website/docs/set_keys.md index 26784ce1be..7e63b5a888 100644 --- a/docs/my-website/docs/set_keys.md +++ b/docs/my-website/docs/set_keys.md @@ -179,6 +179,22 @@ assert(valid_models == expected_models) os.environ = old_environ ``` +### `get_valid_models(check_provider_endpoint=True)` + +When `check_provider_endpoint=True` is passed, this helper queries the provider's live endpoint for its list of valid models (the flag defaults to `False`). + +Currently implemented for: +- OpenAI (if OPENAI_API_KEY is set) +- Fireworks AI (if FIREWORKS_AI_API_KEY is set) +- LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set) + +```python +from litellm import get_valid_models + +valid_models = get_valid_models(check_provider_endpoint=True) +print(valid_models) +``` + ### `validate_environment(model: str)` This helper tells you if you have all the required environment variables for a model, and if not - what's missing. diff --git a/litellm/__init__.py b/litellm/__init__.py index e0f9b59c9e..83e6dc8c53 100644 --- a/litellm/__init__.py +++ b/litellm/__init__.py @@ -1168,6 +1168,7 @@ from .llms.azure.azure import ( from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig from .llms.azure.completion.transformation import AzureOpenAITextConfig from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig +from .llms.litellm_proxy.chat.transformation import LiteLLMProxyChatConfig from .llms.vllm.completion.transformation import VLLMConfig from .llms.deepseek.chat.transformation import DeepSeekChatConfig from .llms.lm_studio.chat.transformation import LMStudioChatConfig diff --git a/litellm/litellm_core_utils/get_llm_provider_logic.py b/litellm/litellm_core_utils/get_llm_provider_logic.py index 4583dc2107..cb26955fb1 100644 --- a/litellm/litellm_core_utils/get_llm_provider_logic.py +++ b/litellm/litellm_core_utils/get_llm_provider_logic.py @@ -488,11 +488,10 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915 elif custom_llm_provider == "fireworks_ai": # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1 ( - model, api_base, dynamic_api_key, ) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info( - model=model, api_base=api_base, api_key=api_key + api_base=api_base, api_key=api_key ) elif custom_llm_provider == "azure_ai": ( diff --git a/litellm/llms/base_llm/base_utils.py b/litellm/llms/base_llm/base_utils.py index dca8c2504c..b7587727fe 100644 --- a/litellm/llms/base_llm/base_utils.py +++ b/litellm/llms/base_llm/base_utils.py @@ -1,9 +1,18 @@ from abc import ABC, abstractmethod +from typing import List, Optional from litellm.types.utils import ModelInfoBase class BaseLLMModelInfo(ABC): @abstractmethod - def get_model_info(self, model: str) -> 
ModelInfoBase: + def get_model_info( + self, + model: str, + existing_model_info: Optional[ModelInfoBase] = None, + ) -> Optional[ModelInfoBase]: + pass + + @abstractmethod + def get_models(self) -> List[str]: pass diff --git a/litellm/llms/fireworks_ai/chat/transformation.py b/litellm/llms/fireworks_ai/chat/transformation.py index 0879d2579f..28b5d8045e 100644 --- a/litellm/llms/fireworks_ai/chat/transformation.py +++ b/litellm/llms/fireworks_ai/chat/transformation.py @@ -1,7 +1,6 @@ from typing import List, Literal, Optional, Tuple, Union, cast import litellm -from litellm.llms.base_llm.base_utils import BaseLLMModelInfo from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo @@ -9,7 +8,7 @@ from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo from ...openai.chat.gpt_transformation import OpenAIGPTConfig -class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig): +class FireworksAIConfig(OpenAIGPTConfig): """ Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions @@ -209,8 +208,8 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig): ) def _get_openai_compatible_provider_info( - self, model: str, api_base: Optional[str], api_key: Optional[str] - ) -> Tuple[str, Optional[str], Optional[str]]: + self, api_base: Optional[str], api_key: Optional[str] + ) -> Tuple[Optional[str], Optional[str]]: api_base = ( api_base or get_secret_str("FIREWORKS_API_BASE") @@ -222,4 +221,32 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig): or get_secret_str("FIREWORKSAI_API_KEY") or get_secret_str("FIREWORKS_AI_TOKEN") ) - return model, api_base, dynamic_api_key + return api_base, dynamic_api_key + + def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None): + api_base, api_key = self._get_openai_compatible_provider_info( + api_base=api_base, api_key=api_key + ) + if api_base is None or api_key is None: + raise ValueError( + "FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint." + ) + + account_id = get_secret_str("FIREWORKS_ACCOUNT_ID") + if account_id is None: + raise ValueError( + "FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint." + ) + + response = litellm.module_level_client.get( + url=f"{api_base}/v1/accounts/{account_id}/models", + headers={"Authorization": f"Bearer {api_key}"}, + ) + + if response.status_code != 200: + raise ValueError( + f"Failed to fetch models from Fireworks AI. 
Status code: {response.status_code}, Response: {response.json()}" + ) + + models = response.json()["models"] + return ["fireworks_ai/" + model["name"] for model in models] diff --git a/litellm/llms/litellm_proxy/chat/transformation.py b/litellm/llms/litellm_proxy/chat/transformation.py new file mode 100644 index 0000000000..cb3a9a5eeb --- /dev/null +++ b/litellm/llms/litellm_proxy/chat/transformation.py @@ -0,0 +1,29 @@ +""" +Translate from OpenAI's `/v1/chat/completions` to the LiteLLM Proxy's `/v1/chat/completions` +""" + +from typing import List, Optional, Tuple + +from litellm.secret_managers.main import get_secret_str + +from ...openai.chat.gpt_transformation import OpenAIGPTConfig + + +class LiteLLMProxyChatConfig(OpenAIGPTConfig): + def _get_openai_compatible_provider_info( + self, api_base: Optional[str], api_key: Optional[str] + ) -> Tuple[Optional[str], Optional[str]]: + api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE") # type: ignore + dynamic_api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY") + return api_base, dynamic_api_key + + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> List[str]: + api_base, api_key = self._get_openai_compatible_provider_info(api_base, api_key) + if api_base is None: + raise ValueError( + "api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`" + ) + models = super().get_models(api_key=api_key, api_base=api_base) + return [f"litellm_proxy/{model}" for model in models] diff --git a/litellm/llms/openai/chat/gpt_transformation.py b/litellm/llms/openai/chat/gpt_transformation.py index 599250ab6b..b8ef78ae98 100644 --- a/litellm/llms/openai/chat/gpt_transformation.py +++ b/litellm/llms/openai/chat/gpt_transformation.py @@ -7,9 +7,11 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast import httpx import litellm +from litellm.llms.base_llm.base_utils import BaseLLMModelInfo from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException +from litellm.secret_managers.main import get_secret_str from litellm.types.llms.openai import AllMessageValues -from litellm.types.utils import ModelResponse +from litellm.types.utils import ModelInfoBase, ModelResponse from ..common_utils import OpenAIError @@ -21,7 +23,7 @@ else: LiteLLMLoggingObj = Any -class OpenAIGPTConfig(BaseConfig): +class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig): """ Reference: https://platform.openai.com/docs/api-reference/chat/create @@ -229,3 +231,43 @@ class OpenAIGPTConfig(BaseConfig): api_base: Optional[str] = None, ) -> dict: raise NotImplementedError + + def get_models( + self, api_key: Optional[str] = None, api_base: Optional[str] = None + ) -> List[str]: + """ + Calls OpenAI's `/v1/models` endpoint and returns the list of models. 
+ """ + + if api_base is None: + api_base = "https://api.openai.com" + if api_key is None: + api_key = get_secret_str("OPENAI_API_KEY") + + response = litellm.module_level_client.get( + url=f"{api_base}/v1/models", + headers={"Authorization": f"Bearer {api_key}"}, + ) + + if response.status_code != 200: + raise Exception(f"Failed to get models: {response.text}") + + models = response.json()["data"] + return [model["id"] for model in models] + + def get_model_info( + self, model: str, existing_model_info: Optional[ModelInfoBase] = None + ) -> ModelInfoBase: + + if existing_model_info is not None: + return existing_model_info + return ModelInfoBase( + key=model, + litellm_provider="openai", + mode="chat", + input_cost_per_token=0.0, + output_cost_per_token=0.0, + max_tokens=None, + max_input_tokens=None, + max_output_tokens=None, + ) diff --git a/litellm/llms/watsonx/common_utils.py b/litellm/llms/watsonx/common_utils.py index b8340503d3..d5bbe3fc14 100644 --- a/litellm/llms/watsonx/common_utils.py +++ b/litellm/llms/watsonx/common_utils.py @@ -175,7 +175,12 @@ class IBMWatsonXMixin: if "Authorization" in headers: return {**default_headers, **headers} - token = cast(Optional[str], optional_params.get("token")) + token = cast( + Optional[str], + optional_params.get("token") + or get_secret_str("WATSONX_ZENAPIKEY") + or get_secret_str("WATSONX_TOKEN"), + ) if token: headers["Authorization"] = f"Bearer {token}" else: @@ -245,6 +250,7 @@ class IBMWatsonXMixin: ) token: Optional[str] = None + if wx_credentials is not None: api_base = wx_credentials.get("url", api_base) api_key = wx_credentials.get( diff --git a/litellm/utils.py b/litellm/utils.py index 902e20fdf6..f3789fe129 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -4223,6 +4223,7 @@ def _get_model_info_helper( # noqa: PLR0915 _model_info: Optional[Dict[str, Any]] = None key: Optional[str] = None + provider_config: Optional[BaseLLMModelInfo] = None if combined_model_name in litellm.model_cost: key = combined_model_name _model_info = _get_model_info_from_model_cost(key=key) @@ -4261,16 +4262,20 @@ def _get_model_info_helper( # noqa: PLR0915 model_info=_model_info, custom_llm_provider=custom_llm_provider ): _model_info = None - if _model_info is None and ProviderConfigManager.get_provider_model_info( - model=model, provider=LlmProviders(custom_llm_provider) - ): + + if custom_llm_provider: provider_config = ProviderConfigManager.get_provider_model_info( model=model, provider=LlmProviders(custom_llm_provider) ) - if provider_config is not None: - _model_info = cast( - dict, provider_config.get_model_info(model=model) - ) + + if _model_info is None and provider_config is not None: + _model_info = cast( + Optional[Dict], + provider_config.get_model_info( + model=model, existing_model_info=_model_info + ), + ) + if key is None: key = "provider_specific_model_info" if _model_info is None or key is None: raise ValueError( @@ -5706,12 +5711,12 @@ def trim_messages( return messages -def get_valid_models() -> List[str]: +def get_valid_models(check_provider_endpoint: bool = False) -> List[str]: """ Returns a list of valid LLMs based on the set environment variables Args: - None + check_provider_endpoint: If True, will check the provider's endpoint for valid models. 
Returns: A list of valid LLMs @@ -5725,22 +5730,36 @@ def get_valid_models() -> List[str]: for provider in litellm.provider_list: # edge case litellm has together_ai as a provider, it should be togetherai - provider = provider.replace("_", "") + env_provider_1 = provider.replace("_", "") + env_provider_2 = provider # litellm standardizes expected provider keys to # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY - expected_provider_key = f"{provider.upper()}_API_KEY" - if expected_provider_key in environ_keys: + expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY" + expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY" + if ( + expected_provider_key_1 in environ_keys + or expected_provider_key_2 in environ_keys + ): # key is set valid_providers.append(provider) + for provider in valid_providers: + provider_config = ProviderConfigManager.get_provider_model_info( + model=None, + provider=LlmProviders(provider), + ) + if provider == "azure": valid_models.append("Azure-LLM") + elif provider_config is not None and check_provider_endpoint: + valid_models.extend(provider_config.get_models()) else: models_for_provider = litellm.models_by_provider.get(provider, []) valid_models.extend(models_for_provider) return valid_models - except Exception: + except Exception as e: + verbose_logger.debug(f"Error getting valid models: {e}") return [] # NON-Blocking @@ -6291,11 +6310,14 @@ class ProviderConfigManager: @staticmethod def get_provider_model_info( - model: str, + model: Optional[str], provider: LlmProviders, ) -> Optional[BaseLLMModelInfo]: if LlmProviders.FIREWORKS_AI == provider: return litellm.FireworksAIConfig() + elif LlmProviders.LITELLM_PROXY == provider: + return litellm.LiteLLMProxyChatConfig() + return None diff --git a/tests/llm_translation/test_watsonx.py b/tests/llm_translation/test_watsonx.py index 9efe28ceec..41b3f099e5 100644 --- a/tests/llm_translation/test_watsonx.py +++ b/tests/llm_translation/test_watsonx.py @@ -23,31 +23,43 @@ def watsonx_chat_completion_call(): api_key="test_api_key", headers=None, client=None, + patch_token_call=True, ): if messages is None: messages = [{"role": "user", "content": "Hello, how are you?"}] if client is None: client = HTTPHandler() - mock_response = Mock() - mock_response.json.return_value = { - "access_token": "mock_access_token", - "expires_in": 3600, - } - mock_response.raise_for_status = Mock() # No-op to simulate no exception + if patch_token_call: + mock_response = Mock() + mock_response.json.return_value = { + "access_token": "mock_access_token", + "expires_in": 3600, + } + mock_response.raise_for_status = Mock() # No-op to simulate no exception - with patch.object(client, "post") as mock_post, patch.object( - litellm.module_level_client, "post", return_value=mock_response - ) as mock_get: - completion( - model=model, - messages=messages, - api_key=api_key, - headers=headers or {}, - client=client, - ) + with patch.object(client, "post") as mock_post, patch.object( + litellm.module_level_client, "post", return_value=mock_response + ) as mock_get: + completion( + model=model, + messages=messages, + api_key=api_key, + headers=headers or {}, + client=client, + ) - return mock_post, mock_get + return mock_post, mock_get + else: + with patch.object(client, "post") as mock_post: + completion( + model=model, + messages=messages, + api_key=api_key, + headers=headers or {}, + client=client, + ) + return mock_post, None return _call @@ -77,6 +89,20 @@ def test_watsonx_custom_auth_header( ) 
+@pytest.mark.parametrize("env_var_key", ["WATSONX_ZENAPIKEY", "WATSONX_TOKEN"]) +def test_watsonx_token_in_env_var( + monkeypatch, watsonx_chat_completion_call, env_var_key +): + monkeypatch.setenv(env_var_key, "my-custom-token") + + mock_post, _ = watsonx_chat_completion_call(patch_token_call=False) + + assert mock_post.call_count == 1 + assert ( + mock_post.call_args[1]["headers"]["Authorization"] == "Bearer my-custom-token" + ) + + def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call): model = "watsonx/another-model" messages = [{"role": "user", "content": "Test message"}] diff --git a/tests/local_testing/test_utils.py b/tests/local_testing/test_utils.py index d07ca29b1c..6b654ca76e 100644 --- a/tests/local_testing/test_utils.py +++ b/tests/local_testing/test_utils.py @@ -1270,6 +1270,8 @@ def test_fireworks_ai_document_inlining(): """ from litellm.utils import supports_pdf_input, supports_vision + litellm._turn_on_debug() + assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True @@ -1288,3 +1290,131 @@ def test_logprobs_type(): assert logprobs.token_logprobs is None assert logprobs.tokens is None assert logprobs.top_logprobs is None + + +def test_get_valid_models_openai_proxy(monkeypatch): + from litellm.utils import get_valid_models + import litellm + + litellm._turn_on_debug() + + monkeypatch.setenv("LITELLM_PROXY_API_KEY", "sk-1234") + monkeypatch.setenv("LITELLM_PROXY_API_BASE", "https://litellm-api.up.railway.app/") + monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", None) + monkeypatch.delenv("FIREWORKS_AI_API_KEY", None) + + mock_response_data = { + "object": "list", + "data": [ + { + "id": "gpt-4o", + "object": "model", + "created": 1686935002, + "owned_by": "organization-owner", + }, + ], + } + + # Create a mock response object + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + + with patch.object( + litellm.module_level_client, "get", return_value=mock_response + ) as mock_post: + valid_models = get_valid_models(check_provider_endpoint=True) + assert "litellm_proxy/gpt-4o" in valid_models + + +def test_get_valid_models_fireworks_ai(monkeypatch): + from litellm.utils import get_valid_models + import litellm + + monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234") + monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "1234") + + mock_response_data = { + "models": [ + { + "name": "accounts/fireworks/models/llama-3.1-8b-instruct", + "displayName": "", + "description": "", + "createTime": "2023-11-07T05:31:56Z", + "createdBy": "", + "state": "STATE_UNSPECIFIED", + "status": {"code": "OK", "message": ""}, + "kind": "KIND_UNSPECIFIED", + "githubUrl": "", + "huggingFaceUrl": "", + "baseModelDetails": { + "worldSize": 123, + "checkpointFormat": "CHECKPOINT_FORMAT_UNSPECIFIED", + "parameterCount": "", + "moe": True, + "tunable": True, + }, + "peftDetails": { + "baseModel": "", + "r": 123, + "targetModules": [""], + }, + "teftDetails": {}, + "public": True, + "conversationConfig": { + "style": "", + "system": "", + "template": "", + }, + "contextLength": 123, + "supportsImageInput": True, + "supportsTools": True, + "importedFrom": "", + "fineTuningJob": "", + "defaultDraftModel": "", + "defaultDraftTokenCount": 123, + "precisions": ["PRECISION_UNSPECIFIED"], + "deployedModelRefs": [ + { + "name": "", + "deployment": "", + "state": "STATE_UNSPECIFIED", + "default": True, + "public": True, + } + ], + "cluster": "", + 
"deprecationDate": {"year": 123, "month": 123, "day": 123}, + } + ], + "nextPageToken": "", + "totalSize": 123, + } + + # Create a mock response object + mock_response = MagicMock() + mock_response.status_code = 200 + mock_response.json.return_value = mock_response_data + + with patch.object( + litellm.module_level_client, "get", return_value=mock_response + ) as mock_post: + valid_models = get_valid_models(check_provider_endpoint=True) + assert ( + "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct" + in valid_models + ) + + +def test_get_valid_models_default(monkeypatch): + """ + Ensure that the default models is used when error retrieving from model api. + + Prevent regression for existing usage. + """ + from litellm.utils import get_valid_models + import litellm + + monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234") + valid_models = get_valid_models() + assert len(valid_models) > 0