Support checking provider-specific /models endpoints for available models based on key (#7538)
* test(test_utils.py): initial test for valid models. Addresses https://github.com/BerriAI/litellm/issues/7525
* fix: test
* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint
* refactor(fireworks_ai/): support checking model info on `/v1/models` route
* docs(set_keys.md): update docs to clarify check llm provider api usage
* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for iam auth
* fix(watsonx): read in watsonx token from env var
* fix: fix linting errors
* fix(utils.py): fix provider config check
* style: cleanup unused imports
This commit is contained in: parent cac06a32b8, commit f770dd0c95
12 changed files with 350 additions and 42 deletions
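For orientation, a minimal usage sketch of the feature this commit adds (based on the docs change further down; the API key value is a placeholder and the returned list depends on which provider keys are set in your environment):

```python
import os

from litellm import get_valid_models

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key for illustration

# With check_provider_endpoint=True, litellm queries the provider's own
# /models endpoint (currently OpenAI, Fireworks AI, and LiteLLM Proxy)
# instead of relying only on the static models_by_provider map.
valid_models = get_valid_models(check_provider_endpoint=True)
print(valid_models)
```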
@@ -14,6 +14,7 @@ os.environ["WATSONX_TOKEN"] = "" # IAM auth token
 # optional - can also be passed as params to completion() or embedding()
 os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance
 os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models
+os.environ["WATSONX_ZENAPIKEY"] = "" # Zen API key (use for long-term api token)
 ```

 See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai.
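For context, a hedged sketch of how the new `WATSONX_ZENAPIKEY` variable is meant to be used; the region URL, project ID, and model ID below are placeholders, and litellm picks the key up from the environment when it builds the Authorization header:

```python
import os

import litellm

os.environ["WATSONX_ZENAPIKEY"] = "my-zen-api-key"  # placeholder long-lived Zen API key
os.environ["WATSONX_URL"] = "https://us-south.ml.cloud.ibm.com"  # placeholder region URL
os.environ["WATSONX_PROJECT_ID"] = "my-project-id"  # placeholder

# No explicit token is passed; the Zen API key (or WATSONX_TOKEN) is used
# as the Bearer token per the common_utils.py change in this commit.
response = litellm.completion(
    model="watsonx/ibm/granite-13b-chat-v2",  # placeholder model ID
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```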
@@ -179,6 +179,22 @@ assert(valid_models == expected_models)
 os.environ = old_environ
 ```

+### `get_valid_models(check_provider_endpoint: True)`
+
+This helper will check the provider's endpoint for valid models.
+
+Currently implemented for:
+- OpenAI (if OPENAI_API_KEY is set)
+- Fireworks AI (if FIREWORKS_AI_API_KEY is set)
+- LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set)
+
+```python
+from litellm import get_valid_models
+
+valid_models = get_valid_models(check_provider_endpoint=True)
+print(valid_models)
+```
+
 ### `validate_environment(model: str)`

 This helper tells you if you have all the required environment variables for a model, and if not - what's missing.
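The `validate_environment` section above stops at the description, so here is a minimal usage sketch; the return keys shown in the comment follow litellm's existing docs and may vary by version:

```python
import litellm

# Returns a dict describing whether the provider's required env vars are set,
# e.g. {"keys_in_environment": False, "missing_keys": ["OPENAI_API_KEY"]}
env_check = litellm.validate_environment(model="gpt-3.5-turbo")

if not env_check["keys_in_environment"]:
    print(f"Missing environment variables: {env_check['missing_keys']}")
```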
@@ -1168,6 +1168,7 @@ from .llms.azure.azure import (
 from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
 from .llms.azure.completion.transformation import AzureOpenAITextConfig
 from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
+from .llms.litellm_proxy.chat.transformation import LiteLLMProxyChatConfig
 from .llms.vllm.completion.transformation import VLLMConfig
 from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig

@@ -488,11 +488,10 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
     elif custom_llm_provider == "fireworks_ai":
         # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
         (
-            model,
             api_base,
             dynamic_api_key,
         ) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
-            model=model, api_base=api_base, api_key=api_key
+            api_base=api_base, api_key=api_key
         )
     elif custom_llm_provider == "azure_ai":
         (

@@ -1,9 +1,18 @@
 from abc import ABC, abstractmethod
 from typing import List, Optional

 from litellm.types.utils import ModelInfoBase


 class BaseLLMModelInfo(ABC):
     @abstractmethod
-    def get_model_info(self, model: str) -> ModelInfoBase:
+    def get_model_info(
+        self,
+        model: str,
+        existing_model_info: Optional[ModelInfoBase] = None,
+    ) -> Optional[ModelInfoBase]:
+        pass
+
+    @abstractmethod
+    def get_models(self) -> List[str]:
+        pass

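To make the widened interface concrete, a hypothetical minimal subclass is sketched below (the class name and model list are invented for illustration; the real implementations are the provider configs further down in this diff):

```python
from typing import List, Optional

from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.types.utils import ModelInfoBase


class MyProviderModelInfo(BaseLLMModelInfo):
    """Hypothetical provider config implementing the updated interface."""

    def get_model_info(
        self,
        model: str,
        existing_model_info: Optional[ModelInfoBase] = None,
    ) -> Optional[ModelInfoBase]:
        # Prefer whatever litellm already resolved from its model cost map.
        if existing_model_info is not None:
            return existing_model_info
        return None

    def get_models(self) -> List[str]:
        # Real configs query the provider's /models endpoint here.
        return ["my_provider/example-model"]
```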
@@ -1,7 +1,6 @@
 from typing import List, Literal, Optional, Tuple, Union, cast

 import litellm
-from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
 from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo
@@ -9,7 +8,7 @@ from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig


-class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
+class FireworksAIConfig(OpenAIGPTConfig):
     """
     Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions

@@ -209,8 +208,8 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
         )

     def _get_openai_compatible_provider_info(
-        self, model: str, api_base: Optional[str], api_key: Optional[str]
-    ) -> Tuple[str, Optional[str], Optional[str]]:
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
         api_base = (
             api_base
             or get_secret_str("FIREWORKS_API_BASE")
@@ -222,4 +221,32 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
             or get_secret_str("FIREWORKSAI_API_KEY")
             or get_secret_str("FIREWORKS_AI_TOKEN")
         )
-        return model, api_base, dynamic_api_key
+        return api_base, dynamic_api_key
+
+    def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
+        api_base, api_key = self._get_openai_compatible_provider_info(
+            api_base=api_base, api_key=api_key
+        )
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
+            )
+
+        account_id = get_secret_str("FIREWORKS_ACCOUNT_ID")
+        if account_id is None:
+            raise ValueError(
+                "FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
+            )
+
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/accounts/{account_id}/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        if response.status_code != 200:
+            raise ValueError(
+                f"Failed to fetch models from Fireworks AI. Status code: {response.status_code}, Response: {response.json()}"
+            )
+
+        models = response.json()["models"]
+        return ["fireworks_ai/" + model["name"] for model in models]

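A hedged sketch of exercising the new Fireworks AI model listing directly (the key and account ID values are placeholders; the call hits `{api_base}/v1/accounts/{account_id}/models` as added above and needs network access against a real account):

```python
import os

import litellm

os.environ["FIREWORKS_API_KEY"] = "fw-..."  # placeholder
os.environ["FIREWORKS_ACCOUNT_ID"] = "fireworks"  # placeholder account ID

# Returns names prefixed with "fireworks_ai/", e.g.
# "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
models = litellm.FireworksAIConfig().get_models()
print(models)
```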
litellm/llms/litellm_proxy/chat/transformation.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+"""
+Translate from OpenAI's `/v1/chat/completions` to VLLM's `/v1/chat/completions`
+"""
+
+from typing import List, Optional, Tuple
+
+from litellm.secret_managers.main import get_secret_str
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class LiteLLMProxyChatConfig(OpenAIGPTConfig):
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE")  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY")
+        return api_base, dynamic_api_key
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        api_base, api_key = self._get_openai_compatible_provider_info(api_base, api_key)
+        if api_base is None:
+            raise ValueError(
+                "api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`"
+            )
+        models = super().get_models(api_key=api_key, api_base=api_base)
+        return [f"litellm_proxy/{model}" for model in models]

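A quick sketch of the new proxy model listing (the proxy URL and key are placeholders; `get_models()` reuses the inherited OpenAI `/v1/models` call and prefixes each ID with `litellm_proxy/`):

```python
import os

import litellm

os.environ["LITELLM_PROXY_API_BASE"] = "http://localhost:4000"  # placeholder proxy URL
os.environ["LITELLM_PROXY_API_KEY"] = "sk-1234"  # placeholder key

# e.g. ["litellm_proxy/gpt-4o", ...] depending on what the proxy serves
models = litellm.LiteLLMProxyChatConfig().get_models()
print(models)
```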
@@ -7,9 +7,11 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
 import httpx

 import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import ModelInfoBase, ModelResponse

 from ..common_utils import OpenAIError

@@ -21,7 +23,7 @@ else:
     LiteLLMLoggingObj = Any


-class OpenAIGPTConfig(BaseConfig):
+class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
     """
     Reference: https://platform.openai.com/docs/api-reference/chat/create

@@ -229,3 +231,43 @@ class OpenAIGPTConfig(BaseConfig):
         api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        """
+        Calls OpenAI's `/v1/models` endpoint and returns the list of models.
+        """
+
+        if api_base is None:
+            api_base = "https://api.openai.com"
+        if api_key is None:
+            api_key = get_secret_str("OPENAI_API_KEY")
+
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Failed to get models: {response.text}")
+
+        models = response.json()["data"]
+        return [model["id"] for model in models]
+
+    def get_model_info(
+        self, model: str, existing_model_info: Optional[ModelInfoBase] = None
+    ) -> ModelInfoBase:
+
+        if existing_model_info is not None:
+            return existing_model_info
+        return ModelInfoBase(
+            key=model,
+            litellm_provider="openai",
+            mode="chat",
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+            max_tokens=None,
+            max_input_tokens=None,
+            max_output_tokens=None,
+        )

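For illustration, a hedged sketch of calling the new OpenAI methods directly (requires a real `OPENAI_API_KEY`; `get_valid_models(check_provider_endpoint=True)` goes through this same `get_models()` under the hood):

```python
import os

import litellm

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

config = litellm.OpenAIGPTConfig()

# GET https://api.openai.com/v1/models with a Bearer token; returns model IDs.
model_ids = config.get_models()
print(model_ids[:5])

# Falls back to a zero-cost ModelInfoBase stub when no existing info is passed.
info = config.get_model_info(model="gpt-4o")
print(info)
```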
@@ -175,7 +175,12 @@ class IBMWatsonXMixin:

         if "Authorization" in headers:
             return {**default_headers, **headers}
-        token = cast(Optional[str], optional_params.get("token"))
+        token = cast(
+            Optional[str],
+            optional_params.get("token")
+            or get_secret_str("WATSONX_ZENAPIKEY")
+            or get_secret_str("WATSONX_TOKEN"),
+        )
         if token:
             headers["Authorization"] = f"Bearer {token}"
         else:
@@ -245,6 +250,7 @@ class IBMWatsonXMixin:
         )

+        token: Optional[str] = None

         if wx_credentials is not None:
             api_base = wx_credentials.get("url", api_base)
             api_key = wx_credentials.get(

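To summarize the header change above, a standalone sketch of the token precedence it implements; the helper name and signature are illustrative only, not litellm API:

```python
import os
from typing import Dict, Optional


def resolve_watsonx_bearer_token(
    headers: Dict[str, str], explicit_token: Optional[str] = None
) -> Optional[str]:
    """Illustrative only: mirrors the precedence used by IBMWatsonXMixin."""
    # 1. An explicit Authorization header is passed through untouched.
    if "Authorization" in headers:
        return None
    # 2. Otherwise: token param, then WATSONX_ZENAPIKEY, then WATSONX_TOKEN.
    return (
        explicit_token
        or os.environ.get("WATSONX_ZENAPIKEY")
        or os.environ.get("WATSONX_TOKEN")
    )


print(resolve_watsonx_bearer_token({}, None))
```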
@@ -4223,6 +4223,7 @@ def _get_model_info_helper( # noqa: PLR0915

     _model_info: Optional[Dict[str, Any]] = None
     key: Optional[str] = None
+    provider_config: Optional[BaseLLMModelInfo] = None
     if combined_model_name in litellm.model_cost:
         key = combined_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
@@ -4261,16 +4262,20 @@ def _get_model_info_helper( # noqa: PLR0915
             model_info=_model_info, custom_llm_provider=custom_llm_provider
         ):
             _model_info = None
-        if _model_info is None and ProviderConfigManager.get_provider_model_info(
-            model=model, provider=LlmProviders(custom_llm_provider)
-        ):
+
+        if custom_llm_provider:
             provider_config = ProviderConfigManager.get_provider_model_info(
                 model=model, provider=LlmProviders(custom_llm_provider)
             )
-            if provider_config is not None:
-                _model_info = cast(
-                    dict, provider_config.get_model_info(model=model)
-                )
+
+        if _model_info is None and provider_config is not None:
+            _model_info = cast(
+                Optional[Dict],
+                provider_config.get_model_info(
+                    model=model, existing_model_info=_model_info
+                ),
+            )
+            if key is None:
                 key = "provider_specific_model_info"
         if _model_info is None or key is None:
             raise ValueError(
@@ -5706,12 +5711,12 @@ def trim_messages(
     return messages


-def get_valid_models() -> List[str]:
+def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables

     Args:
-        None
+        check_provider_endpoint: If True, will check the provider's endpoint for valid models.

     Returns:
         A list of valid LLMs
@@ -5725,22 +5730,36 @@ def get_valid_models() -> List[str]:

         for provider in litellm.provider_list:
             # edge case litellm has together_ai as a provider, it should be togetherai
-            provider = provider.replace("_", "")
+            env_provider_1 = provider.replace("_", "")
+            env_provider_2 = provider
+
             # litellm standardizes expected provider keys to
             # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
-            expected_provider_key = f"{provider.upper()}_API_KEY"
-            if expected_provider_key in environ_keys:
+            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
+            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
+            if (
+                expected_provider_key_1 in environ_keys
+                or expected_provider_key_2 in environ_keys
+            ):
                 # key is set
                 valid_providers.append(provider)

         for provider in valid_providers:
+            provider_config = ProviderConfigManager.get_provider_model_info(
+                model=None,
+                provider=LlmProviders(provider),
+            )
+
             if provider == "azure":
                 valid_models.append("Azure-LLM")
+            elif provider_config is not None and check_provider_endpoint:
+                valid_models.extend(provider_config.get_models())
             else:
                 models_for_provider = litellm.models_by_provider.get(provider, [])
                 valid_models.extend(models_for_provider)
         return valid_models
-    except Exception:
+    except Exception as e:
+        verbose_logger.debug(f"Error getting valid models: {e}")
         return [] # NON-Blocking


@@ -6291,11 +6310,14 @@ class ProviderConfigManager:

     @staticmethod
     def get_provider_model_info(
-        model: str,
+        model: Optional[str],
         provider: LlmProviders,
     ) -> Optional[BaseLLMModelInfo]:
         if LlmProviders.FIREWORKS_AI == provider:
             return litellm.FireworksAIConfig()
+        elif LlmProviders.LITELLM_PROXY == provider:
+            return litellm.LiteLLMProxyChatConfig()

         return None


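Putting the utils.py pieces together, a hedged sketch of the dispatch that `get_valid_models(check_provider_endpoint=True)` now performs per provider; the import paths are assumptions based on where these names appear in this diff:

```python
from litellm.types.utils import LlmProviders
from litellm.utils import ProviderConfigManager

# Returns FireworksAIConfig() for fireworks_ai, LiteLLMProxyChatConfig() for
# litellm_proxy, and None for providers without a /models implementation yet.
provider_config = ProviderConfigManager.get_provider_model_info(
    model=None, provider=LlmProviders.FIREWORKS_AI
)

if provider_config is not None:
    # Raises if FIREWORKS_API_KEY / FIREWORKS_ACCOUNT_ID are not set.
    models = provider_config.get_models()
    print(models)
```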
@@ -23,31 +23,43 @@ def watsonx_chat_completion_call():
         api_key="test_api_key",
         headers=None,
         client=None,
+        patch_token_call=True,
     ):
         if messages is None:
             messages = [{"role": "user", "content": "Hello, how are you?"}]
         if client is None:
             client = HTTPHandler()

-        mock_response = Mock()
-        mock_response.json.return_value = {
-            "access_token": "mock_access_token",
-            "expires_in": 3600,
-        }
-        mock_response.raise_for_status = Mock()  # No-op to simulate no exception
+        if patch_token_call:
+            mock_response = Mock()
+            mock_response.json.return_value = {
+                "access_token": "mock_access_token",
+                "expires_in": 3600,
+            }
+            mock_response.raise_for_status = Mock()  # No-op to simulate no exception

-        with patch.object(client, "post") as mock_post, patch.object(
-            litellm.module_level_client, "post", return_value=mock_response
-        ) as mock_get:
-            completion(
-                model=model,
-                messages=messages,
-                api_key=api_key,
-                headers=headers or {},
-                client=client,
-            )
+            with patch.object(client, "post") as mock_post, patch.object(
+                litellm.module_level_client, "post", return_value=mock_response
+            ) as mock_get:
+                completion(
+                    model=model,
+                    messages=messages,
+                    api_key=api_key,
+                    headers=headers or {},
+                    client=client,
+                )

-        return mock_post, mock_get
+            return mock_post, mock_get
+        else:
+            with patch.object(client, "post") as mock_post:
+                completion(
+                    model=model,
+                    messages=messages,
+                    api_key=api_key,
+                    headers=headers or {},
+                    client=client,
+                )
+            return mock_post, None

     return _call

@@ -77,6 +89,20 @@ def test_watsonx_custom_auth_header(
     )


+@pytest.mark.parametrize("env_var_key", ["WATSONX_ZENAPIKEY", "WATSONX_TOKEN"])
+def test_watsonx_token_in_env_var(
+    monkeypatch, watsonx_chat_completion_call, env_var_key
+):
+    monkeypatch.setenv(env_var_key, "my-custom-token")
+
+    mock_post, _ = watsonx_chat_completion_call(patch_token_call=False)
+
+    assert mock_post.call_count == 1
+    assert (
+        mock_post.call_args[1]["headers"]["Authorization"] == "Bearer my-custom-token"
+    )
+
+
 def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
     model = "watsonx/another-model"
     messages = [{"role": "user", "content": "Test message"}]

@@ -1270,6 +1270,8 @@ def test_fireworks_ai_document_inlining():
     """
     from litellm.utils import supports_pdf_input, supports_vision

+    litellm._turn_on_debug()
+
     assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
     assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True

@@ -1288,3 +1290,131 @@ def test_logprobs_type():
     assert logprobs.token_logprobs is None
     assert logprobs.tokens is None
     assert logprobs.top_logprobs is None
+
+
+def test_get_valid_models_openai_proxy(monkeypatch):
+    from litellm.utils import get_valid_models
+    import litellm
+
+    litellm._turn_on_debug()
+
+    monkeypatch.setenv("LITELLM_PROXY_API_KEY", "sk-1234")
+    monkeypatch.setenv("LITELLM_PROXY_API_BASE", "https://litellm-api.up.railway.app/")
+    monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", None)
+    monkeypatch.delenv("FIREWORKS_AI_API_KEY", None)
+
+    mock_response_data = {
+        "object": "list",
+        "data": [
+            {
+                "id": "gpt-4o",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "organization-owner",
+            },
+        ],
+    }
+
+    # Create a mock response object
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = mock_response_data
+
+    with patch.object(
+        litellm.module_level_client, "get", return_value=mock_response
+    ) as mock_post:
+        valid_models = get_valid_models(check_provider_endpoint=True)
+        assert "litellm_proxy/gpt-4o" in valid_models
+
+
+def test_get_valid_models_fireworks_ai(monkeypatch):
+    from litellm.utils import get_valid_models
+    import litellm
+
+    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "1234")
+
+    mock_response_data = {
+        "models": [
+            {
+                "name": "accounts/fireworks/models/llama-3.1-8b-instruct",
+                "displayName": "<string>",
+                "description": "<string>",
+                "createTime": "2023-11-07T05:31:56Z",
+                "createdBy": "<string>",
+                "state": "STATE_UNSPECIFIED",
+                "status": {"code": "OK", "message": "<string>"},
+                "kind": "KIND_UNSPECIFIED",
+                "githubUrl": "<string>",
+                "huggingFaceUrl": "<string>",
+                "baseModelDetails": {
+                    "worldSize": 123,
+                    "checkpointFormat": "CHECKPOINT_FORMAT_UNSPECIFIED",
+                    "parameterCount": "<string>",
+                    "moe": True,
+                    "tunable": True,
+                },
+                "peftDetails": {
+                    "baseModel": "<string>",
+                    "r": 123,
+                    "targetModules": ["<string>"],
+                },
+                "teftDetails": {},
+                "public": True,
+                "conversationConfig": {
+                    "style": "<string>",
+                    "system": "<string>",
+                    "template": "<string>",
+                },
+                "contextLength": 123,
+                "supportsImageInput": True,
+                "supportsTools": True,
+                "importedFrom": "<string>",
+                "fineTuningJob": "<string>",
+                "defaultDraftModel": "<string>",
+                "defaultDraftTokenCount": 123,
+                "precisions": ["PRECISION_UNSPECIFIED"],
+                "deployedModelRefs": [
+                    {
+                        "name": "<string>",
+                        "deployment": "<string>",
+                        "state": "STATE_UNSPECIFIED",
+                        "default": True,
+                        "public": True,
+                    }
+                ],
+                "cluster": "<string>",
+                "deprecationDate": {"year": 123, "month": 123, "day": 123},
+            }
+        ],
+        "nextPageToken": "<string>",
+        "totalSize": 123,
+    }
+
+    # Create a mock response object
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = mock_response_data
+
+    with patch.object(
+        litellm.module_level_client, "get", return_value=mock_response
+    ) as mock_post:
+        valid_models = get_valid_models(check_provider_endpoint=True)
+        assert (
+            "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
+            in valid_models
+        )
+
+
+def test_get_valid_models_default(monkeypatch):
+    """
+    Ensure that the default models is used when error retrieving from model api.
+
+    Prevent regression for existing usage.
+    """
+    from litellm.utils import get_valid_models
+    import litellm
+
+    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
+    valid_models = get_valid_models()
+    assert len(valid_models) > 0