forked from phoenix/litellm-mirror
Litellm dev 11 11 2024 (#6693)
* fix(__init__.py): add 'watsonx_text' as mapped llm api route Fixes https://github.com/BerriAI/litellm/issues/6663 * fix(opentelemetry.py): fix passing parallel tool calls to otel Fixes https://github.com/BerriAI/litellm/issues/6677 * refactor(test_opentelemetry_unit_tests.py): create a base set of unit tests for all logging integrations - test for parallel tool call handling reduces bugs in repo * fix(__init__.py): update provider-model mapping to include all known provider-model mappings Fixes https://github.com/BerriAI/litellm/issues/6669 * feat(anthropic): support passing document in llm api call * docs(anthropic.md): add pdf anthropic call to docs + expose new 'supports_pdf_input' function * fix(factory.py): fix linting error
This commit is contained in:
parent
b8ae08b8eb
commit
f59cb46e71
21 changed files with 533 additions and 2264 deletions
|
@ -44,3 +44,30 @@ class BaseLLMChatTest(ABC):
|
|||
messages=messages,
|
||||
)
|
||||
assert response is not None
|
||||
|
||||
@pytest.fixture
|
||||
def pdf_messages(self):
|
||||
import base64
|
||||
|
||||
import requests
|
||||
|
||||
# URL of the file
|
||||
url = "https://storage.googleapis.com/cloud-samples-data/generative-ai/pdf/2403.05530.pdf"
|
||||
|
||||
response = requests.get(url)
|
||||
file_data = response.content
|
||||
|
||||
encoded_file = base64.b64encode(file_data).decode("utf-8")
|
||||
url = f"data:application/pdf;base64,{encoded_file}"
|
||||
|
||||
image_content = [
|
||||
{"type": "text", "text": "What's this file about?"},
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": url},
|
||||
},
|
||||
]
|
||||
|
||||
image_messages = [{"role": "user", "content": image_content}]
|
||||
|
||||
return image_messages
|
||||
|
|
|
@ -36,6 +36,7 @@ from litellm.types.llms.anthropic import AnthropicResponse
|
|||
|
||||
from litellm.llms.anthropic.common_utils import process_anthropic_headers
|
||||
from httpx import Headers
|
||||
from base_llm_unit_tests import BaseLLMChatTest
|
||||
|
||||
|
||||
def test_anthropic_completion_messages_translation():
|
||||
|
@ -624,3 +625,40 @@ def test_anthropic_tool_helper(cache_control_location):
|
|||
tool = AnthropicConfig()._map_tool_helper(tool=tool)
|
||||
|
||||
assert tool["cache_control"] == {"type": "ephemeral"}
|
||||
|
||||
|
||||
from litellm import completion
|
||||
|
||||
|
||||
class TestAnthropicCompletion(BaseLLMChatTest):
|
||||
def get_base_completion_call_args(self) -> dict:
|
||||
return {"model": "claude-3-haiku-20240307"}
|
||||
|
||||
def test_pdf_handling(self, pdf_messages):
|
||||
from litellm.llms.custom_httpx.http_handler import HTTPHandler
|
||||
from litellm.types.llms.anthropic import AnthropicMessagesDocumentParam
|
||||
import json
|
||||
|
||||
client = HTTPHandler()
|
||||
|
||||
with patch.object(client, "post", new=MagicMock()) as mock_client:
|
||||
response = completion(
|
||||
model="claude-3-5-sonnet-20241022",
|
||||
messages=pdf_messages,
|
||||
client=client,
|
||||
)
|
||||
|
||||
mock_client.assert_called_once()
|
||||
|
||||
json_data = json.loads(mock_client.call_args.kwargs["data"])
|
||||
headers = mock_client.call_args.kwargs["headers"]
|
||||
|
||||
assert headers["anthropic-beta"] == "pdfs-2024-09-25"
|
||||
|
||||
json_data["messages"][0]["role"] == "user"
|
||||
_document_validation = AnthropicMessagesDocumentParam(
|
||||
**json_data["messages"][0]["content"][1]
|
||||
)
|
||||
assert _document_validation["type"] == "document"
|
||||
assert _document_validation["source"]["media_type"] == "application/pdf"
|
||||
assert _document_validation["source"]["type"] == "base64"
|
||||
|
|
|
@ -169,3 +169,11 @@ def test_get_llm_provider_hosted_vllm():
|
|||
assert custom_llm_provider == "hosted_vllm"
|
||||
assert model == "llama-3.1-70b-instruct"
|
||||
assert dynamic_api_key == ""
|
||||
|
||||
|
||||
def test_get_llm_provider_watson_text():
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
|
||||
model="watsonx_text/watson-text-to-speech",
|
||||
)
|
||||
assert custom_llm_provider == "watsonx_text"
|
||||
assert model == "watson-text-to-speech"
|
||||
|
|
|
@ -1,11 +0,0 @@
|
|||
import os, sys, traceback
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm import get_model_list
|
||||
|
||||
print(get_model_list())
|
||||
print(get_model_list())
|
||||
# print(litellm.model_list)
|
|
@ -1,41 +0,0 @@
|
|||
# What is this?
|
||||
## Unit tests for opentelemetry integration
|
||||
|
||||
# What is this?
|
||||
## Unit test for presidio pii masking
|
||||
import sys, os, asyncio, time, random
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_opentelemetry_integration():
|
||||
"""
|
||||
Unit test to confirm the parent otel span is ended
|
||||
"""
|
||||
|
||||
parent_otel_span = MagicMock()
|
||||
litellm.callbacks = ["otel"]
|
||||
|
||||
await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
mock_response="Hey!",
|
||||
metadata={"litellm_parent_otel_span": parent_otel_span},
|
||||
)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
parent_otel_span.end.assert_called_once()
|
|
@ -943,3 +943,24 @@ def test_validate_chat_completion_user_messages(messages, expected_bool):
|
|||
## Invalid message
|
||||
with pytest.raises(Exception):
|
||||
validate_chat_completion_user_messages(messages=messages)
|
||||
|
||||
|
||||
def test_models_by_provider():
|
||||
"""
|
||||
Make sure all providers from model map are in the valid providers list
|
||||
"""
|
||||
from litellm import models_by_provider
|
||||
|
||||
providers = set()
|
||||
for k, v in litellm.model_cost.items():
|
||||
if "_" in v["litellm_provider"] and "-" in v["litellm_provider"]:
|
||||
continue
|
||||
elif k == "sample_spec":
|
||||
continue
|
||||
elif v["litellm_provider"] == "sagemaker":
|
||||
continue
|
||||
else:
|
||||
providers.add(v["litellm_provider"])
|
||||
|
||||
for provider in providers:
|
||||
assert provider in models_by_provider.keys()
|
||||
|
|
100
tests/logging_callback_tests/base_test.py
Normal file
100
tests/logging_callback_tests/base_test.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
import asyncio
|
||||
import httpx
|
||||
import json
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import os
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import litellm
|
||||
from litellm.exceptions import BadRequestError
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
# test_example.py
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class BaseLoggingCallbackTest(ABC):
|
||||
"""
|
||||
Abstract base test class that enforces a common test across all test classes.
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_response_obj(self):
|
||||
from litellm.types.utils import (
|
||||
ModelResponse,
|
||||
Choices,
|
||||
Message,
|
||||
ChatCompletionMessageToolCall,
|
||||
Function,
|
||||
Usage,
|
||||
CompletionTokensDetailsWrapper,
|
||||
PromptTokensDetailsWrapper,
|
||||
)
|
||||
|
||||
# Create a mock response object with the structure you need
|
||||
return ModelResponse(
|
||||
id="chatcmpl-ASId3YJWagBpBskWfoNEMPFSkmrEw",
|
||||
created=1731308157,
|
||||
model="gpt-4o-mini-2024-07-18",
|
||||
object="chat.completion",
|
||||
system_fingerprint="fp_0ba0d124f1",
|
||||
choices=[
|
||||
Choices(
|
||||
finish_reason="tool_calls",
|
||||
index=0,
|
||||
message=Message(
|
||||
content=None,
|
||||
role="assistant",
|
||||
tool_calls=[
|
||||
ChatCompletionMessageToolCall(
|
||||
function=Function(
|
||||
arguments='{"city": "New York"}', name="get_weather"
|
||||
),
|
||||
id="call_PngsQS5YGmIZKnswhnUOnOVb",
|
||||
type="function",
|
||||
),
|
||||
ChatCompletionMessageToolCall(
|
||||
function=Function(
|
||||
arguments='{"city": "New York"}', name="get_news"
|
||||
),
|
||||
id="call_1zsDThBu0VSK7KuY7eCcJBnq",
|
||||
type="function",
|
||||
),
|
||||
],
|
||||
function_call=None,
|
||||
),
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
completion_tokens=46,
|
||||
prompt_tokens=86,
|
||||
total_tokens=132,
|
||||
completion_tokens_details=CompletionTokensDetailsWrapper(
|
||||
accepted_prediction_tokens=0,
|
||||
audio_tokens=0,
|
||||
reasoning_tokens=0,
|
||||
rejected_prediction_tokens=0,
|
||||
text_tokens=None,
|
||||
),
|
||||
prompt_tokens_details=PromptTokensDetailsWrapper(
|
||||
audio_tokens=0, cached_tokens=0, text_tokens=None, image_tokens=None
|
||||
),
|
||||
),
|
||||
service_tier=None,
|
||||
)
|
||||
|
||||
@abstractmethod
|
||||
def test_parallel_tool_calls(self, mock_response_obj: ModelResponse):
|
||||
"""
|
||||
Check if parallel tool calls are correctly logged by Logging callback
|
||||
|
||||
Relevant issue - https://github.com/BerriAI/litellm/issues/6677
|
||||
"""
|
||||
pass
|
|
@ -0,0 +1,58 @@
|
|||
# What is this?
|
||||
## Unit tests for opentelemetry integration
|
||||
|
||||
# What is this?
|
||||
## Unit test for presidio pii masking
|
||||
import sys, os, asyncio, time, random
|
||||
from datetime import datetime
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from unittest.mock import patch, MagicMock, AsyncMock
|
||||
from base_test import BaseLoggingCallbackTest
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
class TestOpentelemetryUnitTests(BaseLoggingCallbackTest):
|
||||
def test_parallel_tool_calls(self, mock_response_obj: ModelResponse):
|
||||
tool_calls = mock_response_obj.choices[0].message.tool_calls
|
||||
from litellm.integrations.opentelemetry import OpenTelemetry
|
||||
from litellm.proxy._types import SpanAttributes
|
||||
|
||||
kv_pair_dict = OpenTelemetry._tool_calls_kv_pair(tool_calls)
|
||||
|
||||
assert kv_pair_dict == {
|
||||
f"{SpanAttributes.LLM_COMPLETIONS}.0.function_call.arguments": '{"city": "New York"}',
|
||||
f"{SpanAttributes.LLM_COMPLETIONS}.0.function_call.name": "get_weather",
|
||||
f"{SpanAttributes.LLM_COMPLETIONS}.1.function_call.arguments": '{"city": "New York"}',
|
||||
f"{SpanAttributes.LLM_COMPLETIONS}.1.function_call.name": "get_news",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_opentelemetry_integration(self):
|
||||
"""
|
||||
Unit test to confirm the parent otel span is ended
|
||||
"""
|
||||
|
||||
parent_otel_span = MagicMock()
|
||||
litellm.callbacks = ["otel"]
|
||||
|
||||
await litellm.acompletion(
|
||||
model="gpt-3.5-turbo",
|
||||
messages=[{"role": "user", "content": "Hello, world!"}],
|
||||
mock_response="Hey!",
|
||||
metadata={"litellm_parent_otel_span": parent_otel_span},
|
||||
)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
parent_otel_span.end.assert_called_once()
|
Loading…
Add table
Add a link
Reference in a new issue