diff --git a/litellm/llms/llamafile/chat/transformation.py b/litellm/llms/llamafile/chat/transformation.py
new file mode 100644
index 0000000000..b0f8cd3fc3
--- /dev/null
+++ b/litellm/llms/llamafile/chat/transformation.py
@@ -0,0 +1,46 @@
+from typing import Optional, Tuple
+
+from litellm.secret_managers.main import get_secret_str
+
+from ...openai.chat.gpt_transformation import OpenAIGPTConfig
+
+
+class LlamafileChatConfig(OpenAIGPTConfig):
+    """LlamafileChatConfig provides configuration for Llamafile's chat API."""
+
+    @staticmethod
+    def _resolve_api_key(api_key: Optional[str] = None) -> str:
+        """Attempt to ensure that the API key is set, preferring the user-provided key
+        over the secret manager key (``LLAMAFILE_API_KEY``).
+
+        If both are None, a fake API key is returned.
+        """
+        return api_key or get_secret_str("LLAMAFILE_API_KEY") or "fake-api-key"  # llamafile does not require an API key
+
+    @staticmethod
+    def _resolve_api_base(api_base: Optional[str] = None) -> Optional[str]:
+        """Attempt to ensure that the API base is set, preferring the user-provided base
+        over the secret manager value (``LLAMAFILE_API_BASE``).
+
+        If both are None, a default Llamafile server URL is returned.
+        See: https://github.com/Mozilla-Ocho/llamafile/blob/bd1bbe9aabb1ee12dbdcafa8936db443c571eb9d/README.md#L61
+        """
+        return api_base or get_secret_str("LLAMAFILE_API_BASE") or "http://127.0.0.1:8080/v1"  # type: ignore
+
+
+    def _get_openai_compatible_provider_info(
+        self,
+        api_base: Optional[str],
+        api_key: Optional[str]
+    ) -> Tuple[Optional[str], Optional[str]]:
+        """Attempt to ensure that the API base and key are set, preferring user-provided values,
+        before falling back to secret manager values (``LLAMAFILE_API_BASE`` and ``LLAMAFILE_API_KEY``
+        respectively).
+
+        If an API key cannot be resolved via either method, a fake key is returned. Llamafile
+        does not require an API key, but the underlying OpenAI library may expect one anyway.
+ """ + api_base = LlamafileChatConfig._resolve_api_base(api_base) + dynamic_api_key = LlamafileChatConfig._resolve_api_key(api_key) + + return api_base, dynamic_api_key diff --git a/tests/litellm/llms/llamafile/chat/test_llamafile_chat_transformation.py b/tests/litellm/llms/llamafile/chat/test_llamafile_chat_transformation.py new file mode 100644 index 0000000000..3dfd0d6628 --- /dev/null +++ b/tests/litellm/llms/llamafile/chat/test_llamafile_chat_transformation.py @@ -0,0 +1,128 @@ +from typing import Optional + +from unittest.mock import patch +import pytest + +import litellm +from litellm.llms.llamafile.chat.transformation import LlamafileChatConfig + + +@pytest.mark.parametrize( + "input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called", + [ + ("user-provided-key", "secret-key", "user-provided-key", False), + (None, "secret-key", "secret-key", True), + (None, None, "fake-api-key", True), + ("", "secret-key", "secret-key", True), # Empty string should fall back to secret + ("", None, "fake-api-key", True), # Empty string with no secret should use the fake key + ] +) +def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called): + with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret: + mock_get_secret.return_value = api_key_from_secret_manager + + result = LlamafileChatConfig._resolve_api_key(input_api_key) + + if secret_manager_called: + mock_get_secret.assert_called_once_with("LLAMAFILE_API_KEY") + else: + mock_get_secret.assert_not_called() + + assert result == expected_api_key + + +@pytest.mark.parametrize( + "input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called", + [ + ("https://user-api.example.com", "https://secret-api.example.com", "https://user-api.example.com", False), + (None, "https://secret-api.example.com", "https://secret-api.example.com", True), + (None, None, "http://127.0.0.1:8080/v1", True), + ("", "https://secret-api.example.com", "https://secret-api.example.com", True), # Empty string should fall back + ] +) +def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called): + with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret: + mock_get_secret.return_value = api_base_from_secret_manager + + result = LlamafileChatConfig._resolve_api_base(input_api_base) + + if secret_manager_called: + mock_get_secret.assert_called_once_with("LLAMAFILE_API_BASE") + else: + mock_get_secret.assert_not_called() + + assert result == expected_api_base + + +@pytest.mark.parametrize( + "api_base, api_key, secret_base, secret_key, expected_base, expected_key", + [ + # User-provided values + ("https://user-api.example.com", "user-key", "https://secret-api.example.com", "secret-key", "https://user-api.example.com", "user-key"), + # Fallback to secrets + (None, None, "https://secret-api.example.com", "secret-key", "https://secret-api.example.com", "secret-key"), + # Nothing provided, use defaults + (None, None, None, None, "http://127.0.0.1:8080/v1", "fake-api-key"), + # Mixed scenarios + ("https://user-api.example.com", None, None, "secret-key", "https://user-api.example.com", "secret-key"), + (None, "user-key", "https://secret-api.example.com", None, "https://secret-api.example.com", "user-key"), + ] +) +def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, secret_key, expected_base, expected_key): + config = LlamafileChatConfig() + + 
+    def fake_get_secret(key: str) -> Optional[str]:
+        return {
+            "LLAMAFILE_API_BASE": secret_base,
+            "LLAMAFILE_API_KEY": secret_key
+        }.get(key)
+
+    patch_secret = patch("litellm.llms.llamafile.chat.transformation.get_secret_str", side_effect=fake_get_secret)
+    patch_base = patch.object(LlamafileChatConfig, "_resolve_api_base", wraps=LlamafileChatConfig._resolve_api_base)
+    patch_key = patch.object(LlamafileChatConfig, "_resolve_api_key", wraps=LlamafileChatConfig._resolve_api_key)
+
+    with patch_secret as mock_secret, patch_base as mock_base, patch_key as mock_key:
+        result_base, result_key = config._get_openai_compatible_provider_info(api_base, api_key)
+
+        assert result_base == expected_base
+        assert result_key == expected_key
+
+        mock_base.assert_called_once_with(api_base)
+        mock_key.assert_called_once_with(api_key)
+
+        # Ensure get_secret_str was used as expected within the methods
+        if api_base and api_key:
+            mock_secret.assert_not_called()
+        elif api_base or api_key:
+            mock_secret.assert_called_once()
+        else:
+            assert mock_secret.call_count == 2
+
+
+def test_completion_with_custom_llamafile_model():
+    with patch("litellm.main.openai_chat_completions.completion") as mock_llamafile_completion_func:
+        mock_llamafile_completion_func.return_value = {}  # Return an empty dictionary for the mocked response
+
+        provider = "llamafile"
+        model_name = "my-custom-test-model"
+        model = f"{provider}/{model_name}"
+        messages = [{"role": "user", "content": "Hey, how's it going?"}]
+
+        _ = litellm.completion(
+            model=model,
+            messages=messages,
+            max_retries=2,
+            max_tokens=100,
+        )
+
+        mock_llamafile_completion_func.assert_called_once()
+        _, call_kwargs = mock_llamafile_completion_func.call_args
+        assert call_kwargs.get("custom_llm_provider") == provider
+        assert call_kwargs.get("model") == model_name
+        assert call_kwargs.get("messages") == messages
+        assert call_kwargs.get("api_base") == "http://127.0.0.1:8080/v1"
+        assert call_kwargs.get("api_key") == "fake-api-key"
+        optional_params = call_kwargs.get("optional_params")
+        assert optional_params
+        assert optional_params.get("max_retries") == 2
+        assert optional_params.get("max_tokens") == 100
diff --git a/tests/local_testing/test_completion.py b/tests/local_testing/test_completion.py
index 064fe7f736..0362e8dbcd 100644
--- a/tests/local_testing/test_completion.py
+++ b/tests/local_testing/test_completion.py
@@ -1497,7 +1497,7 @@ HF Tests we should pass
 
 
 @pytest.mark.parametrize(
-    "provider", ["openai", "hosted_vllm", "lm_studio"]
+    "provider", ["openai", "hosted_vllm", "lm_studio", "llamafile"]
 )  # "vertex_ai",
 @pytest.mark.asyncio
 async def test_openai_compatible_custom_api_base(provider):
@@ -1539,6 +1539,7 @@ async def test_openai_compatible_custom_api_base(provider):
     [
         "openai",
         "hosted_vllm",
+        "llamafile",
     ],
 )  # "vertex_ai",
 @pytest.mark.asyncio
diff --git a/tests/local_testing/test_embedding.py b/tests/local_testing/test_embedding.py
index 9146944aa2..75e0c35283 100644
--- a/tests/local_testing/test_embedding.py
+++ b/tests/local_testing/test_embedding.py
@@ -1029,6 +1029,28 @@ def test_hosted_vllm_embedding(monkeypatch):
         assert json_data["model"] == "jina-embeddings-v3"
 
 
+def test_llamafile_embedding(monkeypatch):
+    monkeypatch.setenv("LLAMAFILE_API_BASE", "http://localhost:8080/v1")
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            embedding(
+                model="llamafile/jina-embeddings-v3",
+                input=["Hello world"],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+        assert json_data["input"] == ["Hello world"]
+        assert json_data["model"] == "jina-embeddings-v3"
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("sync_mode", [True, False])
 async def test_lm_studio_embedding(monkeypatch, sync_mode):
diff --git a/tests/local_testing/test_get_llm_provider.py b/tests/local_testing/test_get_llm_provider.py
index fa27a8378c..b5ea128971 100644
--- a/tests/local_testing/test_get_llm_provider.py
+++ b/tests/local_testing/test_get_llm_provider.py
@@ -185,6 +185,16 @@ def test_get_llm_provider_hosted_vllm():
     assert dynamic_api_key == "fake-api-key"
 
 
+def test_get_llm_provider_llamafile():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="llamafile/mistralai/mistral-7b-instruct-v0.2",
+    )
+    assert custom_llm_provider == "llamafile"
+    assert model == "mistralai/mistral-7b-instruct-v0.2"
+    assert dynamic_api_key == "fake-api-key"
+    assert api_base == "http://127.0.0.1:8080/v1"
+
+
 def test_get_llm_provider_watson_text():
     model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
         model="watsonx_text/watson-text-to-speech",
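
For reference (not part of the patch): a minimal usage sketch of the provider wired up above, assuming a local llamafile server is reachable at the default base URL resolved by LlamafileChatConfig. The model name is illustrative; any model served by the llamafile instance works.

# Usage sketch only -- not part of the diff above. Assumes a llamafile server is
# running locally; the "llamafile/" prefix, the default api_base
# (http://127.0.0.1:8080/v1), and the placeholder "fake-api-key" all come from
# LlamafileChatConfig in this patch.
import litellm

response = litellm.completion(
    model="llamafile/mistralai/mistral-7b-instruct-v0.2",  # illustrative model name
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    # api_base/api_key may be omitted: LLAMAFILE_API_BASE / LLAMAFILE_API_KEY are
    # read from the environment, then the defaults above are used.
)
print(response.choices[0].message.content)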