mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-25 10:44:24 +00:00)
Support checking provider-specific /models endpoints for available models based on key (#7538)
* test(test_utils.py): initial test for valid models. Addresses https://github.com/BerriAI/litellm/issues/7525
* fix: test
* feat(fireworks_ai/transformation.py): support retrieving valid models from fireworks ai endpoint
* refactor(fireworks_ai/): support checking model info on `/v1/models` route
* docs(set_keys.md): update docs to clarify check llm provider api usage
* fix(watsonx/common_utils.py): support 'WATSONX_ZENAPIKEY' for iam auth
* fix(watsonx): read in watsonx token from env var
* fix: fix linting errors
* fix(utils.py): fix provider config check
* style: cleanup unused imports
This commit is contained in:
parent cac06a32b8
commit f770dd0c95
12 changed files with 350 additions and 42 deletions
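In short: once a provider API key is present in the environment, `get_valid_models(check_provider_endpoint=True)` can ask that provider's `/models` endpoint for its live model list instead of relying only on litellm's static model map. A minimal sketch (the key below is a placeholder):

```python
import os

from litellm import get_valid_models

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder key

# Providers that implement the new interface are queried for their live
# model list; everything else falls back to litellm.models_by_provider.
valid_models = get_valid_models(check_provider_endpoint=True)
print(valid_models)
```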
@@ -14,6 +14,7 @@ os.environ["WATSONX_TOKEN"] = "" # IAM auth token
 # optional - can also be passed as params to completion() or embedding()
 os.environ["WATSONX_PROJECT_ID"] = "" # Project ID of your WatsonX instance
 os.environ["WATSONX_DEPLOYMENT_SPACE_ID"] = "" # ID of your deployment space to use deployed models
+os.environ["WATSONX_ZENAPIKEY"] = "" # Zen API key (use for long-term api token)
 ```

 See [here](https://cloud.ibm.com/apidocs/watsonx-ai#api-authentication) for more information on how to get an access token to authenticate to watsonx.ai.
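For context, a hedged sketch of using the new long-term Zen API key end to end; `WATSONX_URL` and the model id below are illustrative assumptions, not part of this hunk:

```python
import os

from litellm import completion

os.environ["WATSONX_URL"] = ""          # assumed: your watsonx.ai endpoint URL
os.environ["WATSONX_ZENAPIKEY"] = ""    # long-term Zen API key (new in this change)
os.environ["WATSONX_PROJECT_ID"] = ""

# Hypothetical model id, for illustration only.
response = completion(
    model="watsonx/ibm/granite-13b-chat-v2",
    messages=[{"role": "user", "content": "Hello"}],
)
```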
@@ -179,6 +179,22 @@ assert(valid_models == expected_models)
 os.environ = old_environ
 ```

+### `get_valid_models(check_provider_endpoint: True)`
+
+This helper will check the provider's endpoint for valid models.
+
+Currently implemented for:
+- OpenAI (if OPENAI_API_KEY is set)
+- Fireworks AI (if FIREWORKS_AI_API_KEY is set)
+- LiteLLM Proxy (if LITELLM_PROXY_API_KEY is set)
+
+```python
+from litellm import get_valid_models
+
+valid_models = get_valid_models(check_provider_endpoint=True)
+print(valid_models)
+```
+
 ### `validate_environment(model: str)`

 This helper tells you if you have all the required environment variables for a model, and if not - what's missing.
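As a quick illustration of the `validate_environment` helper referenced above, a minimal call might look like this (the printed shape is an example, not verbatim output):

```python
from litellm import validate_environment

# Reports whether the provider env vars required by this model are present.
result = validate_environment(model="gpt-3.5-turbo")
print(result)
# e.g. {"keys_in_environment": False, "missing_keys": ["OPENAI_API_KEY"]}
```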
@@ -1168,6 +1168,7 @@ from .llms.azure.azure import (
 from .llms.azure.chat.gpt_transformation import AzureOpenAIConfig
 from .llms.azure.completion.transformation import AzureOpenAITextConfig
 from .llms.hosted_vllm.chat.transformation import HostedVLLMChatConfig
+from .llms.litellm_proxy.chat.transformation import LiteLLMProxyChatConfig
 from .llms.vllm.completion.transformation import VLLMConfig
 from .llms.deepseek.chat.transformation import DeepSeekChatConfig
 from .llms.lm_studio.chat.transformation import LMStudioChatConfig
@@ -488,11 +488,10 @@ def _get_openai_compatible_provider_info(  # noqa: PLR0915
     elif custom_llm_provider == "fireworks_ai":
         # fireworks is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.fireworks.ai/inference/v1
         (
-            model,
             api_base,
             dynamic_api_key,
         ) = litellm.FireworksAIConfig()._get_openai_compatible_provider_info(
-            model=model, api_base=api_base, api_key=api_key
+            api_base=api_base, api_key=api_key
         )
     elif custom_llm_provider == "azure_ai":
         (
@@ -1,9 +1,18 @@
 from abc import ABC, abstractmethod
+from typing import List, Optional

 from litellm.types.utils import ModelInfoBase


 class BaseLLMModelInfo(ABC):
     @abstractmethod
-    def get_model_info(self, model: str) -> ModelInfoBase:
+    def get_model_info(
+        self,
+        model: str,
+        existing_model_info: Optional[ModelInfoBase] = None,
+    ) -> Optional[ModelInfoBase]:
+        pass
+
+    @abstractmethod
+    def get_models(self) -> List[str]:
         pass
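To make the new interface concrete, here is a minimal hypothetical implementation; the class name and model ids are illustrative, not part of this commit:

```python
from typing import List, Optional

from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
from litellm.types.utils import ModelInfoBase


class MyProviderModelInfo(BaseLLMModelInfo):
    def get_model_info(
        self,
        model: str,
        existing_model_info: Optional[ModelInfoBase] = None,
    ) -> Optional[ModelInfoBase]:
        # Prefer whatever litellm already resolved (e.g. from its cost map).
        return existing_model_info

    def get_models(self) -> List[str]:
        # A real config would call the provider's /models endpoint here.
        return ["my_provider/model-a", "my_provider/model-b"]
```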
@@ -1,7 +1,6 @@
 from typing import List, Literal, Optional, Tuple, Union, cast

 import litellm
-from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues, ChatCompletionImageObject
 from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo
@@ -9,7 +8,7 @@ from litellm.types.utils import ModelInfoBase, ProviderSpecificModelInfo
 from ...openai.chat.gpt_transformation import OpenAIGPTConfig


-class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
+class FireworksAIConfig(OpenAIGPTConfig):
     """
     Reference: https://docs.fireworks.ai/api-reference/post-chatcompletions

@@ -209,8 +208,8 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
             )

     def _get_openai_compatible_provider_info(
-        self, model: str, api_base: Optional[str], api_key: Optional[str]
-    ) -> Tuple[str, Optional[str], Optional[str]]:
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
         api_base = (
             api_base
             or get_secret_str("FIREWORKS_API_BASE")
@@ -222,4 +221,32 @@ class FireworksAIConfig(BaseLLMModelInfo, OpenAIGPTConfig):
             or get_secret_str("FIREWORKSAI_API_KEY")
             or get_secret_str("FIREWORKS_AI_TOKEN")
         )
-        return model, api_base, dynamic_api_key
+        return api_base, dynamic_api_key
+
+    def get_models(self, api_key: Optional[str] = None, api_base: Optional[str] = None):
+        api_base, api_key = self._get_openai_compatible_provider_info(
+            api_base=api_base, api_key=api_key
+        )
+        if api_base is None or api_key is None:
+            raise ValueError(
+                "FIREWORKS_API_BASE or FIREWORKS_API_KEY is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
+            )
+
+        account_id = get_secret_str("FIREWORKS_ACCOUNT_ID")
+        if account_id is None:
+            raise ValueError(
+                "FIREWORKS_ACCOUNT_ID is not set. Please set the environment variable, to query Fireworks AI's `/models` endpoint."
+            )
+
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/accounts/{account_id}/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        if response.status_code != 200:
+            raise ValueError(
+                f"Failed to fetch models from Fireworks AI. Status code: {response.status_code}, Response: {response.json()}"
+            )
+
+        models = response.json()["models"]
+        return ["fireworks_ai/" + model["name"] for model in models]
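A hedged usage sketch for the new Fireworks AI model listing; both values below are placeholders:

```python
import os

import litellm

os.environ["FIREWORKS_API_KEY"] = "fw-..."        # placeholder key
os.environ["FIREWORKS_ACCOUNT_ID"] = "fireworks"  # placeholder account id

# Queries {api_base}/v1/accounts/{account_id}/models and prefixes each
# returned model name with "fireworks_ai/".
print(litellm.FireworksAIConfig().get_models())
```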
litellm/llms/litellm_proxy/chat/transformation.py (new file, 29 lines)
@@ -0,0 +1,29 @@
+"""
+Translate from OpenAI's `/v1/chat/completions` to VLLM's `/v1/chat/completions`
+"""
+
+from typing import List, Optional, Tuple
+
+from litellm.secret_managers.main import get_secret_str
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class LiteLLMProxyChatConfig(OpenAIGPTConfig):
+    def _get_openai_compatible_provider_info(
+        self, api_base: Optional[str], api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        api_base = api_base or get_secret_str("LITELLM_PROXY_API_BASE")  # type: ignore
+        dynamic_api_key = api_key or get_secret_str("LITELLM_PROXY_API_KEY")
+        return api_base, dynamic_api_key
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        api_base, api_key = self._get_openai_compatible_provider_info(api_base, api_key)
+        if api_base is None:
+            raise ValueError(
+                "api_base not set for LiteLLM Proxy route. Set in env via `LITELLM_PROXY_API_BASE`"
+            )
+        models = super().get_models(api_key=api_key, api_base=api_base)
+        return [f"litellm_proxy/{model}" for model in models]
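A hedged usage sketch, assuming a LiteLLM proxy is reachable at the base URL below (both values are placeholders):

```python
import os

import litellm

os.environ["LITELLM_PROXY_API_BASE"] = "http://localhost:4000"  # placeholder
os.environ["LITELLM_PROXY_API_KEY"] = "sk-1234"                 # placeholder

# Lists the proxy's /v1/models, each id prefixed with "litellm_proxy/".
print(litellm.LiteLLMProxyChatConfig().get_models())
```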
@@ -7,9 +7,11 @@ from typing import TYPE_CHECKING, Any, List, Optional, Union, cast
 import httpx

 import litellm
+from litellm.llms.base_llm.base_utils import BaseLLMModelInfo
 from litellm.llms.base_llm.chat.transformation import BaseConfig, BaseLLMException
+from litellm.secret_managers.main import get_secret_str
 from litellm.types.llms.openai import AllMessageValues
-from litellm.types.utils import ModelResponse
+from litellm.types.utils import ModelInfoBase, ModelResponse

 from ..common_utils import OpenAIError

@@ -21,7 +23,7 @@ else:
     LiteLLMLoggingObj = Any


-class OpenAIGPTConfig(BaseConfig):
+class OpenAIGPTConfig(BaseLLMModelInfo, BaseConfig):
     """
     Reference: https://platform.openai.com/docs/api-reference/chat/create

@@ -229,3 +231,43 @@ class OpenAIGPTConfig(BaseConfig):
         api_base: Optional[str] = None,
     ) -> dict:
         raise NotImplementedError
+
+    def get_models(
+        self, api_key: Optional[str] = None, api_base: Optional[str] = None
+    ) -> List[str]:
+        """
+        Calls OpenAI's `/v1/models` endpoint and returns the list of models.
+        """
+
+        if api_base is None:
+            api_base = "https://api.openai.com"
+        if api_key is None:
+            api_key = get_secret_str("OPENAI_API_KEY")
+
+        response = litellm.module_level_client.get(
+            url=f"{api_base}/v1/models",
+            headers={"Authorization": f"Bearer {api_key}"},
+        )
+
+        if response.status_code != 200:
+            raise Exception(f"Failed to get models: {response.text}")
+
+        models = response.json()["data"]
+        return [model["id"] for model in models]
+
+    def get_model_info(
+        self, model: str, existing_model_info: Optional[ModelInfoBase] = None
+    ) -> ModelInfoBase:
+
+        if existing_model_info is not None:
+            return existing_model_info
+        return ModelInfoBase(
+            key=model,
+            litellm_provider="openai",
+            mode="chat",
+            input_cost_per_token=0.0,
+            output_cost_per_token=0.0,
+            max_tokens=None,
+            max_input_tokens=None,
+            max_output_tokens=None,
+        )
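A hedged sketch of calling the new methods directly; the import path mirrors the relative import used by the provider configs above, and the key is a placeholder:

```python
import os

from litellm.llms.openai.chat.gpt_transformation import OpenAIGPTConfig

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder

config = OpenAIGPTConfig()
print(config.get_models())              # model ids from OpenAI's /v1/models
print(config.get_model_info("gpt-4o"))  # stub ModelInfoBase when nothing is cached
```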
@@ -175,7 +175,12 @@ class IBMWatsonXMixin:

         if "Authorization" in headers:
             return {**default_headers, **headers}
-        token = cast(Optional[str], optional_params.get("token"))
+        token = cast(
+            Optional[str],
+            optional_params.get("token")
+            or get_secret_str("WATSONX_ZENAPIKEY")
+            or get_secret_str("WATSONX_TOKEN"),
+        )
         if token:
             headers["Authorization"] = f"Bearer {token}"
         else:
@@ -245,6 +250,7 @@ class IBMWatsonXMixin:
         )

         token: Optional[str] = None
+
         if wx_credentials is not None:
             api_base = wx_credentials.get("url", api_base)
             api_key = wx_credentials.get(
@@ -4223,6 +4223,7 @@ def _get_model_info_helper(  # noqa: PLR0915

     _model_info: Optional[Dict[str, Any]] = None
     key: Optional[str] = None
+    provider_config: Optional[BaseLLMModelInfo] = None
     if combined_model_name in litellm.model_cost:
         key = combined_model_name
         _model_info = _get_model_info_from_model_cost(key=key)
@@ -4261,16 +4262,20 @@ def _get_model_info_helper(  # noqa: PLR0915
         model_info=_model_info, custom_llm_provider=custom_llm_provider
     ):
         _model_info = None
-    if _model_info is None and ProviderConfigManager.get_provider_model_info(
-        model=model, provider=LlmProviders(custom_llm_provider)
-    ):
+    if custom_llm_provider:
         provider_config = ProviderConfigManager.get_provider_model_info(
             model=model, provider=LlmProviders(custom_llm_provider)
         )
-        if provider_config is not None:
-            _model_info = cast(
-                dict, provider_config.get_model_info(model=model)
-            )
+
+    if _model_info is None and provider_config is not None:
+        _model_info = cast(
+            Optional[Dict],
+            provider_config.get_model_info(
+                model=model, existing_model_info=_model_info
+            ),
+        )
+        if key is None:
             key = "provider_specific_model_info"
     if _model_info is None or key is None:
         raise ValueError(
@@ -5706,12 +5711,12 @@ def trim_messages(
     return messages


-def get_valid_models() -> List[str]:
+def get_valid_models(check_provider_endpoint: bool = False) -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables

     Args:
-        None
+        check_provider_endpoint: If True, will check the provider's endpoint for valid models.

     Returns:
         A list of valid LLMs
@@ -5725,22 +5730,36 @@ def get_valid_models() -> List[str]:

         for provider in litellm.provider_list:
             # edge case litellm has together_ai as a provider, it should be togetherai
-            provider = provider.replace("_", "")
+            env_provider_1 = provider.replace("_", "")
+            env_provider_2 = provider
+
             # litellm standardizes expected provider keys to
             # PROVIDER_API_KEY. Example: OPENAI_API_KEY, COHERE_API_KEY
-            expected_provider_key = f"{provider.upper()}_API_KEY"
-            if expected_provider_key in environ_keys:
+            expected_provider_key_1 = f"{env_provider_1.upper()}_API_KEY"
+            expected_provider_key_2 = f"{env_provider_2.upper()}_API_KEY"
+            if (
+                expected_provider_key_1 in environ_keys
+                or expected_provider_key_2 in environ_keys
+            ):
                 # key is set
                 valid_providers.append(provider)

         for provider in valid_providers:
+            provider_config = ProviderConfigManager.get_provider_model_info(
+                model=None,
+                provider=LlmProviders(provider),
+            )
+
             if provider == "azure":
                 valid_models.append("Azure-LLM")
+            elif provider_config is not None and check_provider_endpoint:
+                valid_models.extend(provider_config.get_models())
             else:
                 models_for_provider = litellm.models_by_provider.get(provider, [])
                 valid_models.extend(models_for_provider)
         return valid_models
-    except Exception:
+    except Exception as e:
+        verbose_logger.debug(f"Error getting valid models: {e}")
         return []  # NON-Blocking


@@ -6291,11 +6310,14 @@ class ProviderConfigManager:

     @staticmethod
     def get_provider_model_info(
-        model: str,
+        model: Optional[str],
         provider: LlmProviders,
     ) -> Optional[BaseLLMModelInfo]:
         if LlmProviders.FIREWORKS_AI == provider:
             return litellm.FireworksAIConfig()
+        elif LlmProviders.LITELLM_PROXY == provider:
+            return litellm.LiteLLMProxyChatConfig()
+
         return None

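A hedged sketch of the provider-config lookup this change adds; the import path for `LlmProviders` is an assumption (it is the enum `utils.py` already uses above):

```python
from litellm.types.utils import LlmProviders  # assumed import path
from litellm.utils import ProviderConfigManager

cfg = ProviderConfigManager.get_provider_model_info(
    model=None, provider=LlmProviders.FIREWORKS_AI
)
print(type(cfg).__name__)  # FireworksAIConfig

cfg = ProviderConfigManager.get_provider_model_info(
    model=None, provider=LlmProviders.LITELLM_PROXY
)
print(type(cfg).__name__)  # LiteLLMProxyChatConfig
```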
@@ -23,31 +23,43 @@ def watsonx_chat_completion_call():
         api_key="test_api_key",
         headers=None,
         client=None,
+        patch_token_call=True,
     ):
         if messages is None:
             messages = [{"role": "user", "content": "Hello, how are you?"}]
         if client is None:
             client = HTTPHandler()

-        mock_response = Mock()
-        mock_response.json.return_value = {
-            "access_token": "mock_access_token",
-            "expires_in": 3600,
-        }
-        mock_response.raise_for_status = Mock()  # No-op to simulate no exception
+        if patch_token_call:
+            mock_response = Mock()
+            mock_response.json.return_value = {
+                "access_token": "mock_access_token",
+                "expires_in": 3600,
+            }
+            mock_response.raise_for_status = Mock()  # No-op to simulate no exception

-        with patch.object(client, "post") as mock_post, patch.object(
-            litellm.module_level_client, "post", return_value=mock_response
-        ) as mock_get:
-            completion(
-                model=model,
-                messages=messages,
-                api_key=api_key,
-                headers=headers or {},
-                client=client,
-            )
+            with patch.object(client, "post") as mock_post, patch.object(
+                litellm.module_level_client, "post", return_value=mock_response
+            ) as mock_get:
+                completion(
+                    model=model,
+                    messages=messages,
+                    api_key=api_key,
+                    headers=headers or {},
+                    client=client,
+                )

-            return mock_post, mock_get
+                return mock_post, mock_get
+        else:
+            with patch.object(client, "post") as mock_post:
+                completion(
+                    model=model,
+                    messages=messages,
+                    api_key=api_key,
+                    headers=headers or {},
+                    client=client,
+                )
+                return mock_post, None

     return _call

@@ -77,6 +89,20 @@ def test_watsonx_custom_auth_header(
     )


+@pytest.mark.parametrize("env_var_key", ["WATSONX_ZENAPIKEY", "WATSONX_TOKEN"])
+def test_watsonx_token_in_env_var(
+    monkeypatch, watsonx_chat_completion_call, env_var_key
+):
+    monkeypatch.setenv(env_var_key, "my-custom-token")
+
+    mock_post, _ = watsonx_chat_completion_call(patch_token_call=False)
+
+    assert mock_post.call_count == 1
+    assert (
+        mock_post.call_args[1]["headers"]["Authorization"] == "Bearer my-custom-token"
+    )
+
+
 def test_watsonx_chat_completions_endpoint(watsonx_chat_completion_call):
     model = "watsonx/another-model"
     messages = [{"role": "user", "content": "Test message"}]
@@ -1270,6 +1270,8 @@ def test_fireworks_ai_document_inlining():
     """
     from litellm.utils import supports_pdf_input, supports_vision

+    litellm._turn_on_debug()
+
     assert supports_pdf_input("fireworks_ai/llama-3.1-8b-instruct") is True
     assert supports_vision("fireworks_ai/llama-3.1-8b-instruct") is True

@@ -1288,3 +1290,131 @@ def test_logprobs_type():
     assert logprobs.token_logprobs is None
     assert logprobs.tokens is None
     assert logprobs.top_logprobs is None
+
+
+def test_get_valid_models_openai_proxy(monkeypatch):
+    from litellm.utils import get_valid_models
+    import litellm
+
+    litellm._turn_on_debug()
+
+    monkeypatch.setenv("LITELLM_PROXY_API_KEY", "sk-1234")
+    monkeypatch.setenv("LITELLM_PROXY_API_BASE", "https://litellm-api.up.railway.app/")
+    monkeypatch.delenv("FIREWORKS_AI_ACCOUNT_ID", None)
+    monkeypatch.delenv("FIREWORKS_AI_API_KEY", None)
+
+    mock_response_data = {
+        "object": "list",
+        "data": [
+            {
+                "id": "gpt-4o",
+                "object": "model",
+                "created": 1686935002,
+                "owned_by": "organization-owner",
+            },
+        ],
+    }
+
+    # Create a mock response object
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = mock_response_data
+
+    with patch.object(
+        litellm.module_level_client, "get", return_value=mock_response
+    ) as mock_post:
+        valid_models = get_valid_models(check_provider_endpoint=True)
+        assert "litellm_proxy/gpt-4o" in valid_models
+
+
+def test_get_valid_models_fireworks_ai(monkeypatch):
+    from litellm.utils import get_valid_models
+    import litellm
+
+    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
+    monkeypatch.setenv("FIREWORKS_ACCOUNT_ID", "1234")
+
+    mock_response_data = {
+        "models": [
+            {
+                "name": "accounts/fireworks/models/llama-3.1-8b-instruct",
+                "displayName": "<string>",
+                "description": "<string>",
+                "createTime": "2023-11-07T05:31:56Z",
+                "createdBy": "<string>",
+                "state": "STATE_UNSPECIFIED",
+                "status": {"code": "OK", "message": "<string>"},
+                "kind": "KIND_UNSPECIFIED",
+                "githubUrl": "<string>",
+                "huggingFaceUrl": "<string>",
+                "baseModelDetails": {
+                    "worldSize": 123,
+                    "checkpointFormat": "CHECKPOINT_FORMAT_UNSPECIFIED",
+                    "parameterCount": "<string>",
+                    "moe": True,
+                    "tunable": True,
+                },
+                "peftDetails": {
+                    "baseModel": "<string>",
+                    "r": 123,
+                    "targetModules": ["<string>"],
+                },
+                "teftDetails": {},
+                "public": True,
+                "conversationConfig": {
+                    "style": "<string>",
+                    "system": "<string>",
+                    "template": "<string>",
+                },
+                "contextLength": 123,
+                "supportsImageInput": True,
+                "supportsTools": True,
+                "importedFrom": "<string>",
+                "fineTuningJob": "<string>",
+                "defaultDraftModel": "<string>",
+                "defaultDraftTokenCount": 123,
+                "precisions": ["PRECISION_UNSPECIFIED"],
+                "deployedModelRefs": [
+                    {
+                        "name": "<string>",
+                        "deployment": "<string>",
+                        "state": "STATE_UNSPECIFIED",
+                        "default": True,
+                        "public": True,
+                    }
+                ],
+                "cluster": "<string>",
+                "deprecationDate": {"year": 123, "month": 123, "day": 123},
+            }
+        ],
+        "nextPageToken": "<string>",
+        "totalSize": 123,
+    }
+
+    # Create a mock response object
+    mock_response = MagicMock()
+    mock_response.status_code = 200
+    mock_response.json.return_value = mock_response_data
+
+    with patch.object(
+        litellm.module_level_client, "get", return_value=mock_response
+    ) as mock_post:
+        valid_models = get_valid_models(check_provider_endpoint=True)
+        assert (
+            "fireworks_ai/accounts/fireworks/models/llama-3.1-8b-instruct"
+            in valid_models
+        )
+
+
+def test_get_valid_models_default(monkeypatch):
+    """
+    Ensure that the default models is used when error retrieving from model api.
+
+    Prevent regression for existing usage.
+    """
+    from litellm.utils import get_valid_models
+    import litellm
+
+    monkeypatch.setenv("FIREWORKS_API_KEY", "sk-1234")
+    valid_models = get_valid_models()
+    assert len(valid_models) > 0