mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
Add Llamafile chat config and tests
This commit is contained in:
parent
afaa3da3dd
commit
17a40696de
5 changed files with 208 additions and 1 deletions
46
litellm/llms/llamafile/chat/transformation.py
Normal file
46
litellm/llms/llamafile/chat/transformation.py
Normal file
|
@ -0,0 +1,46 @@
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
|
||||
from ...openai.chat.gpt_transformation import OpenAIGPTConfig
|
||||
|
||||
|
||||
class LlamafileChatConfig(OpenAIGPTConfig):
|
||||
"""LlamafileChatConfig is used to provide configuration for the LlamaFile's chat API."""
|
||||
|
||||
@staticmethod
|
||||
def _resolve_api_key(api_key: Optional[str] = None) -> str:
|
||||
"""Attempt to ensure that the API key is set, preferring the user-provided key
|
||||
over the secret manager key (``LLAMAFILE_API_KEY``).
|
||||
|
||||
If both are None, a fake API key is returned.
|
||||
"""
|
||||
return api_key or get_secret_str("LLAMAFILE_API_KEY") or "fake-api-key" # llamafile does not require an API key
|
||||
|
||||
@staticmethod
|
||||
def _resolve_api_base(api_base: Optional[str] = None) -> Optional[str]:
|
||||
"""Attempt to ensure that the API base is set, preferring the user-provided key
|
||||
over the secret manager key (``LLAMAFILE_API_BASE``).
|
||||
|
||||
If both are None, a default Llamafile server URL is returned.
|
||||
See: https://github.com/Mozilla-Ocho/llamafile/blob/bd1bbe9aabb1ee12dbdcafa8936db443c571eb9d/README.md#L61
|
||||
"""
|
||||
return api_base or get_secret_str("LLAMAFILE_API_BASE") or "http://127.0.0.1:8080/v1" # type: ignore
|
||||
|
||||
|
||||
def _get_openai_compatible_provider_info(
|
||||
self,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str]
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Attempts to ensure that the API base and key are set, preferring user-provided values,
|
||||
before falling back to secret manager values (``LLAMAFILE_API_BASE`` and ``LLAMAFILE_API_KEY``
|
||||
respectively).
|
||||
|
||||
If an API key cannot be resolved via either method, a fake key is returned. Llamafile
|
||||
does not require an API key, but the underlying OpenAI library may expect one anyway.
|
||||
"""
|
||||
api_base = LlamafileChatConfig._resolve_api_base(api_base)
|
||||
dynamic_api_key = LlamafileChatConfig._resolve_api_key(api_key)
|
||||
|
||||
return api_base, dynamic_api_key
|
|
@ -0,0 +1,128 @@
|
|||
from typing import Optional
|
||||
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
import litellm
|
||||
from litellm.llms.llamafile.chat.transformation import LlamafileChatConfig
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called",
|
||||
[
|
||||
("user-provided-key", "secret-key", "user-provided-key", False),
|
||||
(None, "secret-key", "secret-key", True),
|
||||
(None, None, "fake-api-key", True),
|
||||
("", "secret-key", "secret-key", True), # Empty string should fall back to secret
|
||||
("", None, "fake-api-key", True), # Empty string with no secret should use the fake key
|
||||
]
|
||||
)
|
||||
def test_resolve_api_key(input_api_key, api_key_from_secret_manager, expected_api_key, secret_manager_called):
|
||||
with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
|
||||
mock_get_secret.return_value = api_key_from_secret_manager
|
||||
|
||||
result = LlamafileChatConfig._resolve_api_key(input_api_key)
|
||||
|
||||
if secret_manager_called:
|
||||
mock_get_secret.assert_called_once_with("LLAMAFILE_API_KEY")
|
||||
else:
|
||||
mock_get_secret.assert_not_called()
|
||||
|
||||
assert result == expected_api_key
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called",
|
||||
[
|
||||
("https://user-api.example.com", "https://secret-api.example.com", "https://user-api.example.com", False),
|
||||
(None, "https://secret-api.example.com", "https://secret-api.example.com", True),
|
||||
(None, None, "http://127.0.0.1:8080/v1", True),
|
||||
("", "https://secret-api.example.com", "https://secret-api.example.com", True), # Empty string should fall back
|
||||
]
|
||||
)
|
||||
def test_resolve_api_base(input_api_base, api_base_from_secret_manager, expected_api_base, secret_manager_called):
|
||||
with patch("litellm.llms.llamafile.chat.transformation.get_secret_str") as mock_get_secret:
|
||||
mock_get_secret.return_value = api_base_from_secret_manager
|
||||
|
||||
result = LlamafileChatConfig._resolve_api_base(input_api_base)
|
||||
|
||||
if secret_manager_called:
|
||||
mock_get_secret.assert_called_once_with("LLAMAFILE_API_BASE")
|
||||
else:
|
||||
mock_get_secret.assert_not_called()
|
||||
|
||||
assert result == expected_api_base
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"api_base, api_key, secret_base, secret_key, expected_base, expected_key",
|
||||
[
|
||||
# User-provided values
|
||||
("https://user-api.example.com", "user-key", "https://secret-api.example.com", "secret-key", "https://user-api.example.com", "user-key"),
|
||||
# Fallback to secrets
|
||||
(None, None, "https://secret-api.example.com", "secret-key", "https://secret-api.example.com", "secret-key"),
|
||||
# Nothing provided, use defaults
|
||||
(None, None, None, None, "http://127.0.0.1:8080/v1", "fake-api-key"),
|
||||
# Mixed scenarios
|
||||
("https://user-api.example.com", None, None, "secret-key", "https://user-api.example.com", "secret-key"),
|
||||
(None, "user-key", "https://secret-api.example.com", None, "https://secret-api.example.com", "user-key"),
|
||||
]
|
||||
)
|
||||
def test_get_openai_compatible_provider_info(api_base, api_key, secret_base, secret_key, expected_base, expected_key):
|
||||
config = LlamafileChatConfig()
|
||||
|
||||
def fake_get_secret(key: str) -> Optional[str]:
|
||||
return {
|
||||
"LLAMAFILE_API_BASE": secret_base,
|
||||
"LLAMAFILE_API_KEY": secret_key
|
||||
}.get(key)
|
||||
|
||||
patch_secret = patch("litellm.llms.llamafile.chat.transformation.get_secret_str", side_effect=fake_get_secret)
|
||||
patch_base = patch.object(LlamafileChatConfig, "_resolve_api_base", wraps=LlamafileChatConfig._resolve_api_base)
|
||||
patch_key = patch.object(LlamafileChatConfig, "_resolve_api_key", wraps=LlamafileChatConfig._resolve_api_key)
|
||||
|
||||
with patch_secret as mock_secret, patch_base as mock_base, patch_key as mock_key:
|
||||
result_base, result_key = config._get_openai_compatible_provider_info(api_base, api_key)
|
||||
|
||||
assert result_base == expected_base
|
||||
assert result_key == expected_key
|
||||
|
||||
mock_base.assert_called_once_with(api_base)
|
||||
mock_key.assert_called_once_with(api_key)
|
||||
|
||||
# Ensure get_secret_str was used as expected within the methods
|
||||
if api_base and api_key:
|
||||
mock_secret.assert_not_called()
|
||||
elif api_base or api_key:
|
||||
mock_secret.assert_called_once()
|
||||
else:
|
||||
assert mock_secret.call_count == 2
|
||||
|
||||
|
||||
def test_completion_with_custom_llamafile_model():
|
||||
with patch("litellm.main.openai_chat_completions.completion") as mock_llamafile_completion_func:
|
||||
mock_llamafile_completion_func.return_value = {} # Return an empty dictionary for the mocked response
|
||||
|
||||
provider = "llamafile"
|
||||
model_name = "my-custom-test-model"
|
||||
model = f"{provider}/{model_name}"
|
||||
messages = [{"role": "user", "content": "Hey, how's it going?"}]
|
||||
|
||||
_ = litellm.completion(
|
||||
model=model,
|
||||
messages=messages,
|
||||
max_retries=2,
|
||||
max_tokens=100,
|
||||
)
|
||||
|
||||
mock_llamafile_completion_func.assert_called_once()
|
||||
_, call_kwargs = mock_llamafile_completion_func.call_args
|
||||
assert call_kwargs.get("custom_llm_provider") == provider
|
||||
assert call_kwargs.get("model") == model_name
|
||||
assert call_kwargs.get("messages") == messages
|
||||
assert call_kwargs.get("api_base") == "http://127.0.0.1:8080/v1"
|
||||
assert call_kwargs.get("api_key") == "fake-api-key"
|
||||
optional_params = call_kwargs.get("optional_params")
|
||||
assert optional_params
|
||||
assert optional_params.get("max_retries") == 2
|
||||
assert optional_params.get("max_tokens") == 100
|
|
@ -1497,7 +1497,7 @@ HF Tests we should pass
|
|||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"provider", ["openai", "hosted_vllm", "lm_studio"]
|
||||
"provider", ["openai", "hosted_vllm", "lm_studio", "llamafile"]
|
||||
) # "vertex_ai",
|
||||
@pytest.mark.asyncio
|
||||
async def test_openai_compatible_custom_api_base(provider):
|
||||
|
@ -1539,6 +1539,7 @@ async def test_openai_compatible_custom_api_base(provider):
|
|||
[
|
||||
"openai",
|
||||
"hosted_vllm",
|
||||
"llamafile",
|
||||
],
|
||||
) # "vertex_ai",
|
||||
@pytest.mark.asyncio
|
||||
|
|
|
@ -1029,6 +1029,28 @@ def test_hosted_vllm_embedding(monkeypatch):
|
|||
assert json_data["model"] == "jina-embeddings-v3"
|
||||
|
||||
|
||||
def test_llamafile_embedding(monkeypatch):
|
||||
monkeypatch.setenv("LLAMAFILE_API_BASE", "http://localhost:8080/v1")
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
|
||||
client = HTTPHandler()
|
||||
with patch.object(client, "post") as mock_post:
|
||||
try:
|
||||
embedding(
|
||||
model="llamafile/jina-embeddings-v3",
|
||||
input=["Hello world"],
|
||||
client=client,
|
||||
)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
|
||||
mock_post.assert_called_once()
|
||||
|
||||
json_data = json.loads(mock_post.call_args.kwargs["data"])
|
||||
assert json_data["input"] == ["Hello world"]
|
||||
assert json_data["model"] == "jina-embeddings-v3"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
async def test_lm_studio_embedding(monkeypatch, sync_mode):
|
||||
|
|
|
@ -185,6 +185,16 @@ def test_get_llm_provider_hosted_vllm():
|
|||
assert dynamic_api_key == "fake-api-key"
|
||||
|
||||
|
||||
def test_get_llm_provider_llamafile():
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="llamafile/mistralai/mistral-7b-instruct-v0.2",
|
||||
)
|
||||
assert custom_llm_provider == "llamafile"
|
||||
assert model == "mistralai/mistral-7b-instruct-v0.2"
|
||||
assert dynamic_api_key == "fake-api-key"
|
||||
assert api_base == "http://127.0.0.1:8080/v1"
|
||||
|
||||
|
||||
def test_get_llm_provider_watson_text():
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="watsonx_text/watson-text-to-speech",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue