mirror of https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00

Add Llamafile chat config and tests

This commit is contained in:
parent afaa3da3dd
commit 17a40696de

5 changed files with 208 additions and 1 deletion
litellm/llms/llamafile/chat/transformation.py (new file, 46 lines)
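For context, here is a minimal usage sketch of what this change enables, based on the tests added below; the model name is a placeholder, and a local llamafile server is assumed to be running at the default address:

import litellm

# Assumes a llamafile server is listening on the default http://127.0.0.1:8080/v1;
# point LLAMAFILE_API_BASE at a different server if needed.
response = litellm.completion(
    model="llamafile/my-local-model",  # the "llamafile/" prefix routes to the new provider
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    max_tokens=100,
)
print(response)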
@@ -0,0 +1,46 @@
from typing import Optional, Tuple

from litellm.secret_managers.main import get_secret_str

from ...openai.chat.gpt_transformation import OpenAIGPTConfig


class LlamafileChatConfig(OpenAIGPTConfig):
    """LlamafileChatConfig is used to provide configuration for the Llamafile chat API."""

    @staticmethod
    def _resolve_api_key(api_key: Optional[str] = None) -> str:
        """Attempt to ensure that the API key is set, preferring the user-provided key
        over the secret manager key (``LLAMAFILE_API_KEY``).

        If both are None, a fake API key is returned.
        """
        return api_key or get_secret_str("LLAMAFILE_API_KEY") or "fake-api-key"  # llamafile does not require an API key

    @staticmethod
    def _resolve_api_base(api_base: Optional[str] = None) -> Optional[str]:
        """Attempt to ensure that the API base is set, preferring the user-provided value
        over the secret manager value (``LLAMAFILE_API_BASE``).

        If both are None, a default Llamafile server URL is returned.
        See: https://github.com/Mozilla-Ocho/llamafile/blob/bd1bbe9aabb1ee12dbdcafa8936db443c571eb9d/README.md#L61
        """
        return api_base or get_secret_str("LLAMAFILE_API_BASE") or "http://127.0.0.1:8080/v1"  # type: ignore

    def _get_openai_compatible_provider_info(
        self,
        api_base: Optional[str],
        api_key: Optional[str],
    ) -> Tuple[Optional[str], Optional[str]]:
        """Attempts to ensure that the API base and key are set, preferring user-provided values
        before falling back to secret manager values (``LLAMAFILE_API_BASE`` and ``LLAMAFILE_API_KEY``,
        respectively).

        If an API key cannot be resolved via either method, a fake key is returned. Llamafile
        does not require an API key, but the underlying OpenAI library may expect one anyway.
        """
        api_base = LlamafileChatConfig._resolve_api_base(api_base)
        dynamic_api_key = LlamafileChatConfig._resolve_api_key(api_key)

        return api_base, dynamic_api_key
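As a quick illustration of the resolution order implemented above (a sketch only, assuming neither LLAMAFILE_API_BASE nor LLAMAFILE_API_KEY is set in the environment or secret manager):

from litellm.llms.llamafile.chat.transformation import LlamafileChatConfig

config = LlamafileChatConfig()

# Nothing provided anywhere: the defaults documented above are returned.
api_base, api_key = config._get_openai_compatible_provider_info(api_base=None, api_key=None)
assert api_base == "http://127.0.0.1:8080/v1"
assert api_key == "fake-api-key"

# Explicit arguments always win over secrets and defaults.
api_base, api_key = config._get_openai_compatible_provider_info(
    api_base="http://my-llamafile-host:9090/v1",  # hypothetical server address
    api_key="my-key",
)
assert (api_base, api_key) == ("http://my-llamafile-host:9090/v1", "my-key")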
@@ -0,0 +1,128 @@
from typing import Optional
from unittest.mock import patch

import pytest

import litellm
from litellm.llms.llamafile.chat.transformation import LlamafileChatConfig


@pytest.mark.parametrize(
    "input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called",
    [
        ("user-provided-key", "secret-key", "user-provided-key", False),
        (None, "secret-key", "secret-key", True),
        (None, None, "fake-api-key", True),
        ("", "secret-key", "secret-key", True),  # Empty string should fall back to secret
        ("", None, "fake-api-key", True),  # Empty string with no secret should use the fake key
    ]
)
def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called):
    with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
        mock_get_secret.return_value = api_key_from_secret_manager

        result = LlamafileChatConfig._resolve_api_key(input_api_key)

        if secret_manager_called:
            mock_get_secret.assert_called_once_with("LLAMAFILE_API_KEY")
        else:
            mock_get_secret.assert_not_called()

        assert result == expected_api_key


@pytest.mark.parametrize(
    "input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called",
    [
        ("https://user-api.example.com", "https://secret-api.example.com", "https://user-api.example.com", False),
        (None, "https://secret-api.example.com", "https://secret-api.example.com", True),
        (None, None, "http://127.0.0.1:8080/v1", True),
        ("", "https://secret-api.example.com", "https://secret-api.example.com", True),  # Empty string should fall back
    ]
)
def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called):
    with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
        mock_get_secret.return_value = api_base_from_secret_manager

        result = LlamafileChatConfig._resolve_api_base(input_api_base)

        if secret_manager_called:
            mock_get_secret.assert_called_once_with("LLAMAFILE_API_BASE")
        else:
            mock_get_secret.assert_not_called()

        assert result == expected_api_base


@pytest.mark.parametrize(
    "api_base, api_key, secret_base, secret_key, expected_base, expected_key",
    [
        # User-provided values
        ("https://user-api.example.com", "user-key", "https://secret-api.example.com", "secret-key", "https://user-api.example.com", "user-key"),
        # Fallback to secrets
        (None, None, "https://secret-api.example.com", "secret-key", "https://secret-api.example.com", "secret-key"),
        # Nothing provided, use defaults
        (None, None, None, None, "http://127.0.0.1:8080/v1", "fake-api-key"),
        # Mixed scenarios
        ("https://user-api.example.com", None, None, "secret-key", "https://user-api.example.com", "secret-key"),
        (None, "user-key", "https://secret-api.example.com", None, "https://secret-api.example.com", "user-key"),
    ]
)
def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, secret_key, expected_base, expected_key):
    config = LlamafileChatConfig()

    def fake_get_secret(key: str) -> Optional[str]:
        return {
            "LLAMAFILE_API_BASE": secret_base,
            "LLAMAFILE_API_KEY": secret_key
        }.get(key)

    patch_secret = patch("litellm.llms.llamafile.chat.transformation.get_secret_str", side_effect=fake_get_secret)
    patch_base = patch.object(LlamafileChatConfig, "_resolve_api_base", wraps=LlamafileChatConfig._resolve_api_base)
    patch_key = patch.object(LlamafileChatConfig, "_resolve_api_key", wraps=LlamafileChatConfig._resolve_api_key)

    with patch_secret as mock_secret, patch_base as mock_base, patch_key as mock_key:
        result_base, result_key = config._get_openai_compatible_provider_info(api_base, api_key)

        assert result_base == expected_base
        assert result_key == expected_key

        mock_base.assert_called_once_with(api_base)
        mock_key.assert_called_once_with(api_key)

        # Ensure get_secret_str was used as expected within the methods
        if api_base and api_key:
            mock_secret.assert_not_called()
        elif api_base or api_key:
            mock_secret.assert_called_once()
        else:
            assert mock_secret.call_count == 2


def test_completion_with_custom_llamafile_model():
    with patch("litellm.main.openai_chat_completions.completion") as mock_llamafile_completion_func:
        mock_llamafile_completion_func.return_value = {}  # Return an empty dictionary for the mocked response

        provider = "llamafile"
        model_name = "my-custom-test-model"
        model = f"{provider}/{model_name}"
        messages = [{"role": "user", "content": "Hey, how's it going?"}]

        _ = litellm.completion(
            model=model,
            messages=messages,
            max_retries=2,
            max_tokens=100,
        )

        mock_llamafile_completion_func.assert_called_once()
        _, call_kwargs = mock_llamafile_completion_func.call_args
        assert call_kwargs.get("custom_llm_provider") == provider
        assert call_kwargs.get("model") == model_name
        assert call_kwargs.get("messages") == messages
        assert call_kwargs.get("api_base") == "http://127.0.0.1:8080/v1"
        assert call_kwargs.get("api_key") == "fake-api-key"
        optional_params = call_kwargs.get("optional_params")
        assert optional_params
        assert optional_params.get("max_retries") == 2
        assert optional_params.get("max_tokens") == 100
@@ -1497,7 +1497,7 @@ HF Tests we should pass
 @pytest.mark.parametrize(
-    "provider", ["openai", "hosted_vllm", "lm_studio"]
+    "provider", ["openai", "hosted_vllm", "lm_studio", "llamafile"]
 ) # "vertex_ai",
 @pytest.mark.asyncio
 async def test_openai_compatible_custom_api_base(provider):
@@ -1539,6 +1539,7 @@ async def test_openai_compatible_custom_api_base(provider):
     [
         "openai",
         "hosted_vllm",
+        "llamafile",
     ],
 ) # "vertex_ai",
 @pytest.mark.asyncio
@@ -1029,6 +1029,28 @@ def test_hosted_vllm_embedding(monkeypatch):
     assert json_data["model"] == "jina-embeddings-v3"


+def test_llamafile_embedding(monkeypatch):
+    monkeypatch.setenv("LLAMAFILE_API_BASE", "http://localhost:8080/v1")
+    from litellm.llms.custom_httpx.http_handler import HTTPHandler
+
+    client = HTTPHandler()
+    with patch.object(client, "post") as mock_post:
+        try:
+            embedding(
+                model="llamafile/jina-embeddings-v3",
+                input=["Hello world"],
+                client=client,
+            )
+        except Exception as e:
+            print(e)
+
+        mock_post.assert_called_once()
+
+        json_data = json.loads(mock_post.call_args.kwargs["data"])
+        assert json_data["input"] == ["Hello world"]
+        assert json_data["model"] == "jina-embeddings-v3"
+
+
 @pytest.mark.asyncio
 @pytest.mark.parametrize("sync_mode", [True, False])
 async def test_lm_studio_embedding(monkeypatch, sync_mode):
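For completeness, a minimal sketch of the call path the new embedding test exercises; it assumes a llamafile server exposing an OpenAI-compatible embeddings endpoint, and the model name simply mirrors the one used in the test above:

import litellm

# Hypothetical local setup: point LLAMAFILE_API_BASE at the running server first,
# e.g. export LLAMAFILE_API_BASE="http://localhost:8080/v1"
response = litellm.embedding(
    model="llamafile/jina-embeddings-v3",  # model name taken from the test above
    input=["Hello world"],
)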
@@ -185,6 +185,16 @@ def test_get_llm_provider_hosted_vllm():
     assert dynamic_api_key == "fake-api-key"


+def test_get_llm_provider_llamafile():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="llamafile/mistralai/mistral-7b-instruct-v0.2",
+    )
+    assert custom_llm_provider == "llamafile"
+    assert model == "mistralai/mistral-7b-instruct-v0.2"
+    assert dynamic_api_key == "fake-api-key"
+    assert api_base == "http://127.0.0.1:8080/v1"
+
+
 def test_get_llm_provider_watson_text():
     model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
         model="watsonx_text/watson-text-to-speech",