mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
Support checking provider-specific /models
endpoints for available models based on key (#7538)
* test(test_utils.py): initial test for valid models
  Addresses https://github.com/BerriAI/litellm/issues/7525
* fix: test
* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint
* refactor(fireworks_ai/): support checking model info on `/v1/models` route
* docs(set_keys.md): update docs to clarify check llm provider api usage
* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for iam auth
* fix(watsonx): read in watsonx token from env var
* fix: fix linting errors
* fix(utils.py): fix provider config check
* style: cleanup unused imports
This commit is contained in:
parent
cac06a32b8
commit
f770dd0c95
12 changed files with 350 additions and 42 deletions
|
@ -1270,6 +1270,8 @@ def test_fireworks_ai_document_inlining():
|
|||
"""
|
||||
from litellm.utils import supports_pdf_input, supports_vision
|
||||
|
||||
litellm._turn_on_debug()
|
||||
|
||||
assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
|
||||
assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True
|
||||
|
||||
|
@ -1288,3 +1290,131 @@ def test_logprobs_type():
|
|||
assert logprobs.token_logprobs is None
|
||||
assert logprobs.tokens is None
|
||||
assert logprobs.top_logprobs is None
|
||||
|
||||
|
||||
def test_get_valid_models_openai_proxy(monkeypatch):
    """
    Check that `get_valid_models(check_provider_endpoint=True)` reports a model
    listed by a LiteLLM proxy under the `litellm_proxy/` prefix.

    Proxy credentials are supplied through environment variables, and the
    `get` method of `litellm.module_level_client` is patched to return a
    canned OpenAI-style model list containing `gpt-4o`.
    """
    from litellm.utils import get_valid_models
    import litellm

    litellm._turn_on_debug()

    monkeypatch.setenv("LITELLM_PROXY_API_KEY", "sk-1234")
    monkeypatch.setenv("LITELLM_PROXY_API_BASE", "https://litellm-api.up.railway.app/")

    # Remove Fireworks AI credentials from the environment. The Fireworks test
    # in this file sets FIREWORKS_API_KEY / FIREWORKS_ACCOUNT_ID, so those
    # spellings are cleared alongside the FIREWORKS_AI_* ones.
    # NOTE(review): presumably this keeps the Fireworks provider from parsing
    # the proxy-shaped mock response -- confirm against get_valid_models.
    for env_var in (
        "FIREWORKS_AI_ACCOUNT_ID",
        "FIREWORKS_AI_API_KEY",
        "FIREWORKS_ACCOUNT_ID",
        "FIREWORKS_API_KEY",
    ):
        # raising=False: it is fine if the variable was never set.
        monkeypatch.delenv(env_var, raising=False)

    mock_response_data = {
        "object": "list",
        "data": [
            {
                "id": "gpt-4o",
                "object": "model",
                "created": 1686935002,
                "owned_by": "organization-owner",
            },
        ],
    }

    # Create a mock response object
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_response.json.return_value = mock_response_data

    with patch.object(
        litellm.module_level_client, "get", return_value=mock_response
    ):
        valid_models = get_valid_models(check_provider_endpoint=True)
        assert "litellm_proxy/gpt-4o" in valid_models
|
||||
|
||||
|
||||
def test_get_valid_models_fireworks_ai(monkeypatch):
    """
    Check that `get_valid_models(check_provider_endpoint=True)` reports a model
    listed by the Fireworks AI models endpoint under the `fireworks_ai/` prefix.

    Fireworks credentials are supplied through environment variables, and the
    `get` method of `litellm.module_level_client` is patched to return a
    canned response containing a single model entry.
    """
    import litellm
    from litellm.utils import get_valid_models

    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "1234")

    # Single model record in the shape of the mocked list-models response.
    model_entry = {
        "name": "accounts/fireworks/models/llama-3.1-8b-instruct",
        "displayName": "<string>",
        "description": "<string>",
        "createTime": "2023-11-07T05:31:56Z",
        "createdBy": "<string>",
        "state": "STATE_UNSPECIFIED",
        "status": {"code": "OK", "message": "<string>"},
        "kind": "KIND_UNSPECIFIED",
        "githubUrl": "<string>",
        "huggingFaceUrl": "<string>",
        "baseModelDetails": {
            "worldSize": 123,
            "checkpointFormat": "CHECKPOINT_FORMAT_UNSPECIFIED",
            "parameterCount": "<string>",
            "moe": True,
            "tunable": True,
        },
        "peftDetails": {
            "baseModel": "<string>",
            "r": 123,
            "targetModules": ["<string>"],
        },
        "teftDetails": {},
        "public": True,
        "conversationConfig": {
            "style": "<string>",
            "system": "<string>",
            "template": "<string>",
        },
        "contextLength": 123,
        "supportsImageInput": True,
        "supportsTools": True,
        "importedFrom": "<string>",
        "fineTuningJob": "<string>",
        "defaultDraftModel": "<string>",
        "defaultDraftTokenCount": 123,
        "precisions": ["PRECISION_UNSPECIFIED"],
        "deployedModelRefs": [
            {
                "name": "<string>",
                "deployment": "<string>",
                "state": "STATE_UNSPECIFIED",
                "default": True,
                "public": True,
            }
        ],
        "cluster": "<string>",
        "deprecationDate": {"year": 123, "month": 123, "day": 123},
    }
    listing_payload = {
        "models": [model_entry],
        "nextPageToken": "<string>",
        "totalSize": 123,
    }

    # Stand-in for the HTTP response returned by the patched client.
    fake_response = MagicMock(status_code=200)
    fake_response.json.return_value = listing_payload

    with patch.object(litellm.module_level_client, "get", return_value=fake_response):
        available = get_valid_models(check_provider_endpoint=True)

    expected_model = "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
    assert expected_model in available
|
||||
|
||||
|
||||
def test_get_valid_models_default(monkeypatch):
    """
    Ensure that the default models are used when there is an error retrieving
    them from the model API.

    Regression guard for existing usage: `get_valid_models()` called with its
    default arguments must return a non-empty result.
    """
    import litellm
    from litellm.utils import get_valid_models

    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")

    models = get_valid_models()
    assert len(models) > 0
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue