Add Llamafile chat config and tests

Peter Wilson 2025-04-22 16:39:33 +01:00
parent afaa3da3dd
commit 17a40696de
5 changed files with 208 additions and 1 deletion
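
End to end, the new provider routes through litellm's OpenAI-compatible path with sensible defaults. A minimal usage sketch, assuming a Llamafile server is running locally and no LLAMAFILE_* environment variables are set (the model name is illustrative, borrowed from the tests below):

import litellm

# Resolves to api_base="http://127.0.0.1:8080/v1" and api_key="fake-api-key"
# via LlamafileChatConfig in the first file of this diff.
response = litellm.completion(
    model="llamafile/mistralai/mistral-7b-instruct-v0.2",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
)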

litellm/llms/llamafile/chat/transformation.py

@@ -0,0 +1,46 @@
from typing import Optional, Tuple

from litellm.secret_managers.main import get_secret_str

from ...openai.chat.gpt_transformation import OpenAIGPTConfig


class LlamafileChatConfig(OpenAIGPTConfig):
    """LlamafileChatConfig is used to provide configuration for Llamafile's chat API."""

    @staticmethod
    def _resolve_api_key(api_key: Optional[str] = None) -> str:
        """Attempt to ensure that the API key is set, preferring the user-provided key
        over the secret manager key (``LLAMAFILE_API_KEY``).

        If both are None, a fake API key is returned.
        """
        return api_key or get_secret_str("LLAMAFILE_API_KEY") or "fake-api-key"  # llamafile does not require an API key

    @staticmethod
    def _resolve_api_base(api_base: Optional[str] = None) -> Optional[str]:
        """Attempt to ensure that the API base is set, preferring the user-provided value
        over the secret manager value (``LLAMAFILE_API_BASE``).

        If both are None, a default Llamafile server URL is returned.
        See: https://github.com/Mozilla-Ocho/llamafile/blob/bd1bbe9aabb1ee12dbdcafa8936db443c571eb9d/README.md#L61
        """
        return api_base or get_secret_str("LLAMAFILE_API_BASE") or "http://127.0.0.1:8080/v1"  # type: ignore

    def _get_openai_compatible_provider_info(
        self,
        api_base: Optional[str],
        api_key: Optional[str]
    ) -> Tuple[Optional[str], Optional[str]]:
        """Attempts to ensure that the API base and key are set, preferring user-provided values,
        before falling back to secret manager values (``LLAMAFILE_API_BASE`` and ``LLAMAFILE_API_KEY``
        respectively).

        If an API key cannot be resolved via either method, a fake key is returned. Llamafile
        does not require an API key, but the underlying OpenAI library may expect one anyway.
        """
        api_base = LlamafileChatConfig._resolve_api_base(api_base)
        dynamic_api_key = LlamafileChatConfig._resolve_api_key(api_key)

        return api_base, dynamic_api_key
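
The fallback chain above behaves as in this small sketch (values are the defaults defined in the class; assumes no LLAMAFILE_API_BASE / LLAMAFILE_API_KEY in the environment):

config = LlamafileChatConfig()

# Nothing provided anywhere: fall through to the defaults.
api_base, api_key = config._get_openai_compatible_provider_info(api_base=None, api_key=None)
assert api_base == "http://127.0.0.1:8080/v1"  # default Llamafile server URL
assert api_key == "fake-api-key"  # placeholder; llamafile does not require a key

# A user-provided value always wins over the secret manager.
api_base, _ = config._get_openai_compatible_provider_info(api_base="http://my-host:9090/v1", api_key=None)
assert api_base == "http://my-host:9090/v1"  # hypothetical server address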


@@ -0,0 +1,128 @@
from typing import Optional
from unittest.mock import patch

import pytest

import litellm
from litellm.llms.llamafile.chat.transformation import LlamafileChatConfig


@pytest.mark.parametrize(
    "input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called",
    [
        ("user-provided-key", "secret-key", "user-provided-key", False),
        (None, "secret-key", "secret-key", True),
        (None, None, "fake-api-key", True),
        ("", "secret-key", "secret-key", True),  # Empty string should fall back to secret
        ("", None, "fake-api-key", True),  # Empty string with no secret should use the fake key
    ],
)
def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called):
    with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
        mock_get_secret.return_value = api_key_from_secret_manager

        result = LlamafileChatConfig._resolve_api_key(input_api_key)

        if secret_manager_called:
            mock_get_secret.assert_called_once_with("LLAMAFILE_API_KEY")
        else:
            mock_get_secret.assert_not_called()
        assert result == expected_api_key


@pytest.mark.parametrize(
    "input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called",
    [
        ("https://user-api.example.com", "https://secret-api.example.com", "https://user-api.example.com", False),
        (None, "https://secret-api.example.com", "https://secret-api.example.com", True),
        (None, None, "http://127.0.0.1:8080/v1", True),
        ("", "https://secret-api.example.com", "https://secret-api.example.com", True),  # Empty string should fall back
    ],
)
def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called):
    with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
        mock_get_secret.return_value = api_base_from_secret_manager

        result = LlamafileChatConfig._resolve_api_base(input_api_base)

        if secret_manager_called:
            mock_get_secret.assert_called_once_with("LLAMAFILE_API_BASE")
        else:
            mock_get_secret.assert_not_called()
        assert result == expected_api_base


@pytest.mark.parametrize(
    "api_base, api_key, secret_base, secret_key, expected_base, expected_key",
    [
        # User-provided values
        ("https://user-api.example.com", "user-key", "https://secret-api.example.com", "secret-key", "https://user-api.example.com", "user-key"),
        # Fallback to secrets
        (None, None, "https://secret-api.example.com", "secret-key", "https://secret-api.example.com", "secret-key"),
        # Nothing provided, use defaults
        (None, None, None, None, "http://127.0.0.1:8080/v1", "fake-api-key"),
        # Mixed scenarios
        ("https://user-api.example.com", None, None, "secret-key", "https://user-api.example.com", "secret-key"),
        (None, "user-key", "https://secret-api.example.com", None, "https://secret-api.example.com", "user-key"),
    ],
)
def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, secret_key, expected_base, expected_key):
    config = LlamafileChatConfig()

    def fake_get_secret(key: str) -> Optional[str]:
        return {
            "LLAMAFILE_API_BASE": secret_base,
            "LLAMAFILE_API_KEY": secret_key,
        }.get(key)

    patch_secret = patch("litellm.llms.llamafile.chat.transformation.get_secret_str", side_effect=fake_get_secret)
    patch_base = patch.object(LlamafileChatConfig, "_resolve_api_base", wraps=LlamafileChatConfig._resolve_api_base)
    patch_key = patch.object(LlamafileChatConfig, "_resolve_api_key", wraps=LlamafileChatConfig._resolve_api_key)

    with patch_secret as mock_secret, patch_base as mock_base, patch_key as mock_key:
        result_base, result_key = config._get_openai_compatible_provider_info(api_base, api_key)

        assert result_base == expected_base
        assert result_key == expected_key
        mock_base.assert_called_once_with(api_base)
        mock_key.assert_called_once_with(api_key)

        # Ensure get_secret_str was used as expected within the methods
        if api_base and api_key:
            mock_secret.assert_not_called()
        elif api_base or api_key:
            mock_secret.assert_called_once()
        else:
            assert mock_secret.call_count == 2


def test_completion_with_custom_llamafile_model():
    with patch("litellm.main.openai_chat_completions.completion") as mock_llamafile_completion_func:
        mock_llamafile_completion_func.return_value = {}  # Return an empty dictionary for the mocked response

        provider = "llamafile"
        model_name = "my-custom-test-model"
        model = f"{provider}/{model_name}"
        messages = [{"role": "user", "content": "Hey, how's it going?"}]

        _ = litellm.completion(
            model=model,
            messages=messages,
            max_retries=2,
            max_tokens=100,
        )

        mock_llamafile_completion_func.assert_called_once()
        _, call_kwargs = mock_llamafile_completion_func.call_args

        assert call_kwargs.get("custom_llm_provider") == provider
        assert call_kwargs.get("model") == model_name
        assert call_kwargs.get("messages") == messages
        assert call_kwargs.get("api_base") == "http://127.0.0.1:8080/v1"
        assert call_kwargs.get("api_key") == "fake-api-key"

        optional_params = call_kwargs.get("optional_params")
        assert optional_params
        assert optional_params.get("max_retries") == 2
        assert optional_params.get("max_tokens") == 100
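
Outside of tests, the same fallbacks can be driven through the environment, since get_secret_str falls back to environment variables (the embedding test further down relies on exactly this via monkeypatch.setenv). A sketch with placeholder values:

import os

os.environ["LLAMAFILE_API_BASE"] = "http://localhost:8080/v1"  # placeholder host/port
os.environ["LLAMAFILE_API_KEY"] = "any-value"  # optional; if omitted, "fake-api-key" is used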


@@ -1497,7 +1497,7 @@ HF Tests we should pass
 @pytest.mark.parametrize(
-    "provider", ["openai", "hosted_vllm", "lm_studio"]
+    "provider", ["openai", "hosted_vllm", "lm_studio", "llamafile"]
 )  # "vertex_ai",
 @pytest.mark.asyncio
 async def test_openai_compatible_custom_api_base(provider):
@@ -1539,6 +1539,7 @@ async def test_openai_compatible_custom_api_base(provider):
     [
         "openai",
         "hosted_vllm",
+        "llamafile",
     ],
 )  # "vertex_ai",
 @pytest.mark.asyncio


@@ -1029,6 +1029,28 @@ def test_hosted_vllm_embedding(monkeypatch):
         assert json_data["model"] == "jina-embeddings-v3"
+
+
+def test_llamafile_embedding(monkeypatch):
+    monkeypatch.setenv("LLAMAFILE_API_BASE", "http://localhost:8080/v1")
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            embedding(
+                model="llamafile/jina-embeddings-v3",
+                input=["Hello world"],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+        assert json_data["input"] == ["Hello world"]
+        assert json_data["model"] == "jina-embeddings-v3"
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("sync_mode", [True, False])
 async def test_lm_studio_embedding(monkeypatch, sync_mode):


@@ -185,6 +185,16 @@ def test_get_llm_provider_hosted_vllm():
     assert dynamic_api_key == "fake-api-key"
+
+
+def test_get_llm_provider_llamafile():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="llamafile/mistralai/mistral-7b-instruct-v0.2",
+    )
+    assert custom_llm_provider == "llamafile"
+    assert model == "mistralai/mistral-7b-instruct-v0.2"
+    assert dynamic_api_key == "fake-api-key"
+    assert api_base == "http://127.0.0.1:8080/v1"
+
 def test_get_llm_provider_watson_text():
     model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
         model="watsonx_text/watson-text-to-speech",