LiteLLM Minor Fixes & Improvements (10/24/2024) (#6421)

* fix(utils.py): support passing dynamic api base to validate_environment

Returns True when only an api base is required and one is passed
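
A minimal sketch of the new behavior, mirroring the test added in this commit (the ollama provider and example.com api base are illustrative):

import litellm

# ollama only needs an api base, so passing one dynamically now satisfies
# validate_environment even with no provider keys in the environment.
kv = litellm.validate_environment(
    model="ollama/mistral",
    api_base="https://example.com",
)
assert kv["keys_in_environment"]
assert kv["missing_keys"] == []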

* fix(litellm_pre_call_utils.py): feature flag sending client headers to llm api

Fixes https://github.com/BerriAI/litellm/issues/6410
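
A hedged sketch of enabling the new flag on the proxy's general_settings, using the key name this commit introduces (setting it via the proxy config file should be equivalent):

import litellm.proxy.proxy_server as proxy_server

# Client headers are only forwarded to the upstream LLM API when this flag
# is set; by default they are no longer sent (the fix for #6410).
general_settings = getattr(proxy_server, "general_settings")
general_settings["forward_client_headers_to_llm_api"] = True
setattr(proxy_server, "general_settings", general_settings)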

* fix(anthropic/chat/transformation.py): return correct error message

* fix(http_handler.py): add error response text in places where we expect it

* fix(factory.py): handle base case of no non-system messages to bedrock

Fixes https://github.com/BerriAI/litellm/issues/6411
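
A hedged sketch of the user-facing effect, assuming a bedrock converse model and a system-only message list:

import litellm

# With no non-system messages, bedrock now fails fast with a clear
# BadRequestError instead of an opaque provider-side error.
try:
    litellm.completion(
        model="bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
        messages=[{"role": "system", "content": "You are a helpful assistant."}],
    )
except litellm.BadRequestError as e:
    print(e)  # mentions "bedrock requires at least one non-system message"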

* feat(cohere/embed): Support cohere image embeddings

Closes https://github.com/BerriAI/litellm/issues/6413
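
A hedged usage sketch, assuming images are passed in the embedding input as base64 data URIs (the image path and API key are placeholders):

import base64
import os

import litellm

os.environ["COHERE_API_KEY"] = "my-cohere-api-key"  # placeholder

# embed-english-v3.0 can now embed an image alongside (or instead of) text.
with open("image.png", "rb") as f:  # placeholder image path
    encoded = base64.b64encode(f.read()).decode("utf-8")

response = litellm.embedding(
    model="cohere/embed-english-v3.0",
    input=[f"data:image/png;base64,{encoded}"],
)
print(response.data[0]["embedding"][:5])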

* fix(__init__.py): fix linting error

* docs(supported_embedding.md): add image embedding example to docs

* feat(cohere/embed): use cohere embedding returned usage for cost calc

* build(model_prices_and_context_window.json): add embed-english-v3.0 details (image cost + 'supports_image_input' flag)
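
A hedged sketch of the resulting cost flow, using litellm's existing completion_cost helper (the input string is illustrative):

import litellm

response = litellm.embedding(
    model="cohere/embed-english-v3.0",
    input=["hello world"],
)

# Cost is now derived from the usage block Cohere returns in the embed
# response, priced via the new embed-english-v3.0 model-map entry
# (which also carries the image cost and 'supports_image_input' flag).
cost = litellm.completion_cost(completion_response=response)
print(f"embedding cost: ${cost:.8f}")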

* fix(cohere_transformation.py): fix linting error

* test(test_proxy_server.py): cleanup test

* test: cleanup test

* fix: fix linting errors
Krish Dholakia, 2024-10-25 15:55:56 -07:00 (committed by GitHub)
parent 38708a355a
commit c03e5da41f
23 changed files with 417 additions and 150 deletions


@@ -160,3 +160,12 @@ def test_get_llm_provider_jina_ai():
     assert custom_llm_provider == "openai_like"
     assert api_base == "https://api.jina.ai/v1"
     assert model == "jina-embeddings-v3"
+
+
+def test_get_llm_provider_hosted_vllm():
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        model="hosted_vllm/llama-3.1-70b-instruct",
+    )
+    assert custom_llm_provider == "hosted_vllm"
+    assert model == "llama-3.1-70b-instruct"
+    assert dynamic_api_key == ""


@@ -675,3 +675,15 @@ def test_alternating_roles_e2e():
             "stream": False,
         }
     )
+
+
+def test_just_system_message():
+    from litellm.llms.prompt_templates.factory import _bedrock_converse_messages_pt
+
+    with pytest.raises(litellm.BadRequestError) as e:
+        _bedrock_converse_messages_pt(
+            messages=[],
+            model="anthropic.claude-3-sonnet-20240229-v1:0",
+            llm_provider="bedrock",
+        )
+    assert "bedrock requires at least one non-system message" in str(e.value)


@@ -225,12 +225,20 @@ def test_add_headers_to_request(litellm_key_header_name):
     "litellm_key_header_name",
     ["x-litellm-key", None],
 )
+@pytest.mark.parametrize(
+    "forward_headers",
+    [True, False],
+)
 @mock_patch_acompletion()
 def test_chat_completion_forward_headers(
-    mock_acompletion, client_no_auth, litellm_key_header_name
+    mock_acompletion, client_no_auth, litellm_key_header_name, forward_headers
 ):
     global headers
     try:
+        if forward_headers:
+            gs = getattr(litellm.proxy.proxy_server, "general_settings")
+            gs["forward_client_headers_to_llm_api"] = True
+            setattr(litellm.proxy.proxy_server, "general_settings", gs)
         if litellm_key_header_name is not None:
             gs = getattr(litellm.proxy.proxy_server, "general_settings")
             gs["litellm_key_header_name"] = litellm_key_header_name
@@ -260,23 +268,14 @@ def test_chat_completion_forward_headers(
         response = client_no_auth.post(
             "/v1/chat/completions", json=test_data, headers=received_headers
         )
-        mock_acompletion.assert_called_once_with(
-            model="gpt-3.5-turbo",
-            messages=[
-                {"role": "user", "content": "hi"},
-            ],
-            max_tokens=10,
-            litellm_call_id=mock.ANY,
-            litellm_logging_obj=mock.ANY,
-            request_timeout=mock.ANY,
-            specific_deployment=True,
-            metadata=mock.ANY,
-            proxy_server_request=mock.ANY,
-            headers={
+        if not forward_headers:
+            assert "headers" not in mock_acompletion.call_args.kwargs
+        else:
+            assert mock_acompletion.call_args.kwargs["headers"] == {
                 "x-custom-header": "Custom-Value",
                 "x-another-header": "Another-Value",
-            },
-        )
+            }
         print(f"response - {response.text}")
         assert response.status_code == 200
         result = response.json()


@@ -331,6 +331,13 @@ def test_validate_environment_api_key():
     ), f"Missing keys={response_obj['missing_keys']}"
+
+
+def test_validate_environment_api_base_dynamic():
+    for provider in ["ollama", "ollama_chat"]:
+        kv = validate_environment(provider + "/mistral", api_base="https://example.com")
+        assert kv["keys_in_environment"]
+        assert kv["missing_keys"] == []
 
 
 @mock.patch.dict(os.environ, {"OLLAMA_API_BASE": "foo"}, clear=True)
 def test_validate_environment_ollama():
     for provider in ["ollama", "ollama_chat"]: