Mirror of https://github.com/BerriAI/litellm.git (synced 2025-04-26 03:04:13 +00:00)

Commit 57e5c94360: Merge branch 'main' into litellm_arize_dynamic_logging

13 changed files with 225 additions and 135 deletions
@@ -1,4 +1,3 @@
-import json
 from typing import TYPE_CHECKING, Any, Optional

 from litellm._logging import verbose_logger
@@ -129,17 +129,15 @@ def get_llm_provider( # noqa: PLR0915
             model, custom_llm_provider
         )

-        if custom_llm_provider:
-            if (
-                model.split("/")[0] == custom_llm_provider
-            ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
-                model = model.replace("{}/".format(custom_llm_provider), "")
-
-            return model, custom_llm_provider, dynamic_api_key, api_base
+        if custom_llm_provider and (
+            model.split("/")[0] != custom_llm_provider
+        ):  # handle scenario where model="azure/*" and custom_llm_provider="azure"
+            model = custom_llm_provider + "/" + model

         if api_key and api_key.startswith("os.environ/"):
             dynamic_api_key = get_secret_str(api_key)
         # check if llm provider part of model name
+
         if (
             model.split("/", 1)[0] in litellm.provider_list
             and model.split("/", 1)[0] not in litellm.model_list_set
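For illustration, a minimal sketch of what the reworked branch enables, assuming the public litellm.get_llm_provider helper exercised by the tests further down (values are made up):

import litellm

# With custom_llm_provider given but no "xai/" prefix on the model, the new branch
# prepends the provider, so the OpenAI-compatible lookup that follows can resolve
# the api_base and key instead of returning early.
model, provider, dynamic_api_key, api_base = litellm.get_llm_provider(
    model="grok-2-vision-latest",
    custom_llm_provider="xai",
    api_base=None,
    api_key="xai-my-specialkey",
)
# Per the test added in this commit: provider == "xai", model == "grok-2-vision-latest",
# api_base == "https://api.x.ai/v1", dynamic_api_key == "xai-my-specialkey".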
@@ -573,11 +571,11 @@ def _get_openai_compatible_provider_info( # noqa: PLR0915
         dynamic_api_key = api_key or get_secret_str("GALADRIEL_API_KEY")
     elif custom_llm_provider == "snowflake":
         api_base = (
             api_base
-            or get_secret("SNOWFLAKE_API_BASE")
+            or get_secret_str("SNOWFLAKE_API_BASE")
             or f"https://{get_secret('SNOWFLAKE_ACCOUNT_ID')}.snowflakecomputing.com/api/v2/cortex/inference:complete"
         )  # type: ignore
-        dynamic_api_key = api_key or get_secret("SNOWFLAKE_JWT")
+        dynamic_api_key = api_key or get_secret_str("SNOWFLAKE_JWT")

     if api_base is not None and not isinstance(api_base, str):
         raise Exception("api base needs to be a string. api_base={}".format(api_base))
@@ -3257,6 +3257,7 @@ class StandardLoggingPayloadSetup:
                 additional_headers=None,
                 litellm_overhead_time_ms=None,
                 batch_models=None,
+                litellm_model_name=None,
             )
         if hidden_params is not None:
             for key in StandardLoggingHiddenParams.__annotations__.keys():
@@ -3371,6 +3372,7 @@ def get_standard_logging_object_payload(
             response_cost=None,
             litellm_overhead_time_ms=None,
             batch_models=None,
+            litellm_model_name=None,
         )
     )

@@ -3656,6 +3658,7 @@ def create_dummy_standard_logging_payload() -> StandardLoggingPayload:
         additional_headers=None,
         litellm_overhead_time_ms=None,
         batch_models=None,
+        litellm_model_name=None,
     )

     # Convert numeric values to appropriate types
@@ -44,6 +44,7 @@ class ResponseMetadata:
             "additional_headers": process_response_headers(
                 self._get_value_from_hidden_params("additional_headers") or {}
             ),
+            "litellm_model_name": model,
         }
         self._update_hidden_params(new_params)

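A loose illustration of where this surfaces for callers; it assumes the response object's _hidden_params dict, which is what ResponseMetadata updates here (mock_response keeps the call offline):

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
    mock_response="hello there!",
)
# After this change the hidden params should also carry the model name litellm
# used for the provider call, next to existing fields like "model_id".
print(response._hidden_params.get("litellm_model_name"))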
@@ -65,10 +65,13 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             # ------------
             # Setup values
             # ------------
+
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
             model_id = deployment.get("model_info", {}).get("id")
-            rpm_key = f"{model_id}:rpm:{current_minute}"
+            deployment_name = deployment.get("litellm_params", {}).get("model")
+
+            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
             local_result = self.router_cache.get_cache(
                 key=rpm_key, local_only=True
             )  # check local result first
@@ -151,7 +154,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             dt = get_utc_datetime()
             current_minute = dt.strftime("%H-%M")
             model_id = deployment.get("model_info", {}).get("id")
-            rpm_key = f"{model_id}:rpm:{current_minute}"
+            deployment_name = deployment.get("litellm_params", {}).get("model")
+
+            rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
             local_result = await self.router_cache.async_get_cache(
                 key=rpm_key, local_only=True
             )  # check local result first
@@ -228,8 +233,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         if standard_logging_object is None:
             raise ValueError("standard_logging_object not passed in.")
         model_group = standard_logging_object.get("model_group")
+        model = standard_logging_object["hidden_params"].get("litellm_model_name")
         id = standard_logging_object.get("model_id")
-        if model_group is None or id is None:
+        if model_group is None or id is None or model is None:
             return
         elif isinstance(id, int):
             id = str(id)
@@ -244,7 +250,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             "%H-%M"
         )  # use the same timezone regardless of system clock

-        tpm_key = f"{id}:tpm:{current_minute}"
+        tpm_key = f"{id}:{model}:tpm:{current_minute}"
         # ------------
         # Update usage
         # ------------
@@ -276,6 +282,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
         if standard_logging_object is None:
             raise ValueError("standard_logging_object not passed in.")
         model_group = standard_logging_object.get("model_group")
+        model = standard_logging_object["hidden_params"]["litellm_model_name"]
         id = standard_logging_object.get("model_id")
         if model_group is None or id is None:
             return
@@ -290,7 +297,7 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             "%H-%M"
         )  # use the same timezone regardless of system clock

-        tpm_key = f"{id}:tpm:{current_minute}"
+        tpm_key = f"{id}:{model}:tpm:{current_minute}"
         # ------------
         # Update usage
         # ------------
@@ -458,8 +465,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             id = m.get("model_info", {}).get(
                 "id"
             )  # a deployment should always have an 'id'. this is set in router.py
-            tpm_key = "{}:tpm:{}".format(id, current_minute)
-            rpm_key = "{}:rpm:{}".format(id, current_minute)
+            deployment_name = m.get("litellm_params", {}).get("model")
+            tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+            rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)

             tpm_keys.append(tpm_key)
             rpm_keys.append(rpm_key)
@@ -576,8 +584,9 @@ class LowestTPMLoggingHandler_v2(CustomLogger):
             id = m.get("model_info", {}).get(
                 "id"
             )  # a deployment should always have an 'id'. this is set in router.py
-            tpm_key = "{}:tpm:{}".format(id, current_minute)
-            rpm_key = "{}:rpm:{}".format(id, current_minute)
+            deployment_name = m.get("litellm_params", {}).get("model")
+            tpm_key = "{}:{}:tpm:{}".format(id, deployment_name, current_minute)
+            rpm_key = "{}:{}:rpm:{}".format(id, deployment_name, current_minute)

             tpm_keys.append(tpm_key)
             rpm_keys.append(rpm_key)
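The rate-limit bookkeeping above now keys usage by both the deployment id and the deployment's model name. A small illustrative helper (not part of litellm) that mirrors the key format introduced here:

from datetime import datetime, timezone


def usage_keys(model_id: str, deployment_name: str) -> tuple:
    """Illustrative only: builds the per-minute TPM/RPM cache keys in the
    "{id}:{deployment_name}:tpm|rpm:{HH-MM}" format used above."""
    current_minute = datetime.now(timezone.utc).strftime("%H-%M")
    tpm_key = f"{model_id}:{deployment_name}:tpm:{current_minute}"
    rpm_key = f"{model_id}:{deployment_name}:rpm:{current_minute}"
    return tpm_key, rpm_key


# e.g. ("1234:azure/chatgpt-v-2:tpm:14-05", "1234:azure/chatgpt-v-2:rpm:14-05")
print(usage_keys("1234", "azure/chatgpt-v-2"))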
@@ -1625,13 +1625,16 @@ class StandardLoggingAdditionalHeaders(TypedDict, total=False):


 class StandardLoggingHiddenParams(TypedDict):
-    model_id: Optional[str]
+    model_id: Optional[
+        str
+    ]  # id of the model in the router, separates multiple models with the same name but different credentials
     cache_key: Optional[str]
     api_base: Optional[str]
     response_cost: Optional[str]
     litellm_overhead_time_ms: Optional[float]
     additional_headers: Optional[StandardLoggingAdditionalHeaders]
     batch_models: Optional[List[str]]
+    litellm_model_name: Optional[str]  # the model name sent to the provider by litellm


 class StandardLoggingModelInformation(TypedDict):
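A hedged sketch of how downstream code can read the new field; the payload shape follows the TypedDict above, while the helper itself is hypothetical:

from typing import Optional

from litellm.types.utils import StandardLoggingPayload


def deployment_model_name(payload: StandardLoggingPayload) -> Optional[str]:
    # "litellm_model_name" is the model string litellm actually sent to the provider
    # (e.g. "azure/chatgpt-v-2"); "model_id" identifies the router deployment.
    hidden_params = payload.get("hidden_params") or {}
    return hidden_params.get("litellm_model_name")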
@@ -20,7 +20,8 @@ from litellm.caching.redis_cache import RedisCache

 @pytest.mark.parametrize("namespace", [None, "test"])
 @pytest.mark.asyncio
-async def test_redis_cache_async_increment(namespace):
+async def test_redis_cache_async_increment(namespace, monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "https://my-test-host")
     redis_cache = RedisCache(namespace=namespace)
     # Create an AsyncMock for the Redis client
     mock_redis_instance = AsyncMock()
@@ -46,7 +47,8 @@ async def test_redis_cache_async_increment(namespace):


 @pytest.mark.asyncio
-async def test_redis_client_init_with_socket_timeout():
+async def test_redis_client_init_with_socket_timeout(monkeypatch):
+    monkeypatch.setenv("REDIS_HOST", "my-fake-host")
     redis_cache = RedisCache(socket_timeout=1.0)
     assert redis_cache.redis_kwargs["socket_timeout"] == 1.0
     client = redis_cache.init_async_client()
@@ -8,84 +8,24 @@ import pytest
 sys.path.insert(
     0, os.path.abspath("../../../../..")
 )  # Adds the parent directory to the system path
-import litellm
+from litellm.llms.azure_ai.chat.transformation import AzureAIStudioConfig


-@pytest.mark.parametrize("is_async", [True, False])
 @pytest.mark.asyncio
-async def test_azure_ai_request_format(is_async):
+async def test_get_openai_compatible_provider_info():
     """
     Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
     for both synchronous and asynchronous calls
     """
-    litellm._turn_on_debug()
+    config = AzureAIStudioConfig()

-    # Set up the test parameters
-    api_key = "00xxx"
-    api_base = "https://my-endpoint-europe-berri-992.openai.azure.com/openai/deployments/gpt-4o-mini/chat/completions?api-version=2024-08-01-preview"
-    model = "azure_ai/gpt-4o-mini"
-    messages = [
-        {"role": "user", "content": "hi"},
-        {"role": "assistant", "content": "Hello! How can I assist you today?"},
-        {"role": "user", "content": "hi"},
-    ]
+    api_base, dynamic_api_key, custom_llm_provider = (
+        config._get_openai_compatible_provider_info(
+            model="azure_ai/gpt-4o-mini",
+            api_base="https://my-base",
+            api_key="my-key",
+            custom_llm_provider="azure_ai",
+        )
+    )

-    if is_async:
-        # Mock AsyncHTTPHandler.post method for async test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.AsyncHTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = AsyncMock()
-
-            # Call the acompletion function
-            try:
-                await litellm.acompletion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-
-    else:
-        # Mock HTTPHandler.post method for sync test
-        with patch(
-            "litellm.llms.custom_httpx.llm_http_handler.HTTPHandler.post"
-        ) as mock_post:
-            # Set up mock response
-            mock_post.return_value = MagicMock()
-
-            # Call the completion function
-            try:
-                litellm.completion(
-                    custom_llm_provider="azure_ai",
-                    api_key=api_key,
-                    api_base=api_base,
-                    model=model,
-                    messages=messages,
-                )
-            except Exception as e:
-                # We expect an exception since we're mocking the response
-                pass
-
-    # Verify the request was made with the correct parameters
-    mock_post.assert_called_once()
-    call_args = mock_post.call_args
-    print("sync request call=", json.dumps(call_args.kwargs, indent=4, default=str))
-
-    # Check URL
-    assert call_args.kwargs["url"] == api_base
-
-    # Check headers
-    assert call_args.kwargs["headers"]["api-key"] == api_key
-
-    # Check request body
-    request_body = json.loads(call_args.kwargs["data"])
-    assert (
-        request_body["model"] == "gpt-4o-mini"
-    )  # Model name should be stripped of provider prefix
-    assert request_body["messages"] == messages
+    assert custom_llm_provider == "azure"
@@ -265,3 +265,32 @@ class TestAzureAIRerank(BaseLLMRerankTest):
             "api_base": os.getenv("AZURE_AI_COHERE_API_BASE"),
             "api_key": os.getenv("AZURE_AI_COHERE_API_KEY"),
         }
+
+
+@pytest.mark.asyncio
+async def test_azure_ai_request_format():
+    """
+    Test that Azure AI requests are formatted correctly with the proper endpoint and parameters
+    for both synchronous and asynchronous calls
+    """
+    from openai import AsyncAzureOpenAI, AzureOpenAI
+
+    litellm._turn_on_debug()
+
+    # Set up the test parameters
+    api_key = os.getenv("AZURE_API_KEY")
+    api_base = f"{os.getenv('AZURE_API_BASE')}/openai/deployments/gpt-4o/chat/completions?api-version=2024-08-01-preview"
+    model = "azure_ai/gpt-4o"
+    messages = [
+        {"role": "user", "content": "hi"},
+        {"role": "assistant", "content": "Hello! How can I assist you today?"},
+        {"role": "user", "content": "hi"},
+    ]
+
+    await litellm.acompletion(
+        custom_llm_provider="azure_ai",
+        api_key=api_key,
+        api_base=api_base,
+        model=model,
+        messages=messages,
+    )
@@ -2,7 +2,9 @@ import asyncio
 import json
 import logging
 import os
+import time
+from unittest.mock import patch, Mock
+import opentelemetry.exporter.otlp.proto.grpc.trace_exporter
 from litellm import Choices
 import pytest
 from dotenv import load_dotenv
@@ -81,3 +83,52 @@ def test_get_arize_config_with_endpoints(mock_env_vars, monkeypatch):
     config = ArizeLogger.get_arize_config()
     assert config.endpoint == "grpc://test.endpoint"
     assert config.protocol == "otlp_grpc"
+
+
+@pytest.mark.skip(
+    reason="Works locally but not in CI/CD. We'll need a better way to test Arize on CI/CD"
+)
+def test_arize_callback():
+    litellm.callbacks = ["arize"]
+    os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
+    os.environ["ARIZE_API_KEY"] = "test_api_key"
+    os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
+
+    # Set the batch span processor to quickly flush after a span has been added
+    # This is to ensure that the span is exported before the test ends
+    os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
+    os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
+    os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
+    os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
+
+    try:
+        with patch.object(
+            opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
+            "export",
+            new=Mock(),
+        ) as patched_export:
+            litellm.completion(
+                model="openai/test-model",
+                messages=[{"role": "user", "content": "arize test content"}],
+                stream=False,
+                mock_response="hello there!",
+            )
+
+            time.sleep(1)  # Wait for the batch span processor to flush
+            assert patched_export.called
+    finally:
+        # Clean up environment variables
+        for key in [
+            "ARIZE_SPACE_KEY",
+            "ARIZE_API_KEY",
+            "ARIZE_ENDPOINT",
+            "OTEL_BSP_MAX_QUEUE_SIZE",
+            "OTEL_BSP_MAX_EXPORT_BATCH_SIZE",
+            "OTEL_BSP_SCHEDULE_DELAY_MILLIS",
+            "OTEL_BSP_EXPORT_TIMEOUT_MILLIS",
+        ]:
+            if key in os.environ:
+                del os.environ[key]
+
+        # Reset callbacks
+        litellm.callbacks = []
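For context, the OTEL_BSP_* variables set above correspond to the OpenTelemetry BatchSpanProcessor batching knobs; a rough in-code equivalent is sketched below (illustrative only, the test configures them via the environment instead):

from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter

# Flush after every span with almost no batching delay, so the patched OTLP
# exporter is invoked before the test finishes.
provider = TracerProvider()
provider.add_span_processor(
    BatchSpanProcessor(
        ConsoleSpanExporter(),  # stand-in exporter for this sketch
        max_queue_size=1,
        max_export_batch_size=1,
        schedule_delay_millis=1,
        export_timeout_millis=5,
    )
)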
@@ -216,3 +216,21 @@ def test_bedrock_invoke_anthropic():
     )
     assert custom_llm_provider == "bedrock"
     assert model == "invoke/anthropic.claude-3-5-sonnet-20240620-v1:0"
+
+
+@pytest.mark.parametrize("model", ["xai/grok-2-vision-latest", "grok-2-vision-latest"])
+def test_xai_api_base(model):
+    args = {
+        "model": model,
+        "custom_llm_provider": "xai",
+        "api_base": None,
+        "api_key": "xai-my-specialkey",
+        "litellm_params": None,
+    }
+    model, custom_llm_provider, dynamic_api_key, api_base = litellm.get_llm_provider(
+        **args
+    )
+    assert custom_llm_provider == "xai"
+    assert model == "grok-2-vision-latest"
+    assert api_base == "https://api.x.ai/v1"
+    assert dynamic_api_key == "xai-my-specialkey"
@@ -20,7 +20,7 @@ sys.path.insert(
 from unittest.mock import AsyncMock, MagicMock, patch
 from litellm.types.utils import StandardLoggingPayload
 import pytest
+from litellm.types.router import DeploymentTypedDict
 import litellm
 from litellm import Router
 from litellm.caching.caching import DualCache
@@ -47,12 +47,14 @@ def test_tpm_rpm_updated():
     deployment_id = "1234"
     deployment = "azure/chatgpt-v-2"
     total_tokens = 50
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
+            "model": deployment,
             "metadata": {
                 "model_group": model_group,
                 "deployment": deployment,
@@ -62,10 +64,16 @@ def test_tpm_rpm_updated():
         "standard_logging_object": standard_logging_payload,
     }

+    litellm_deployment_dict: DeploymentTypedDict = {
+        "model_name": model_group,
+        "litellm_params": {"model": deployment},
+        "model_info": {"id": deployment_id},
+    }
+
     start_time = time.time()
     response_obj = {"usage": {"total_tokens": total_tokens}}
     end_time = time.time()
-    lowest_tpm_logger.pre_call_check(deployment=kwargs["litellm_params"])
+    lowest_tpm_logger.pre_call_check(deployment=litellm_deployment_dict)
     lowest_tpm_logger.log_success_event(
         response_obj=response_obj,
         kwargs=kwargs,
@@ -74,8 +82,8 @@ def test_tpm_rpm_updated():
     )
     dt = get_utc_datetime()
     current_minute = dt.strftime("%H-%M")
-    tpm_count_api_key = f"{deployment_id}:tpm:{current_minute}"
-    rpm_count_api_key = f"{deployment_id}:rpm:{current_minute}"
+    tpm_count_api_key = f"{deployment_id}:{deployment}:tpm:{current_minute}"
+    rpm_count_api_key = f"{deployment_id}:{deployment}:rpm:{current_minute}"

     print(f"tpm_count_api_key={tpm_count_api_key}")
     assert response_obj["usage"]["total_tokens"] == test_cache.get_cache(
@@ -113,6 +121,7 @@ def test_get_available_deployments():
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -135,10 +144,11 @@ def test_get_available_deployments():
     ## DEPLOYMENT 2 ##
     total_tokens = 20
     deployment_id = "5678"
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = model_group
     standard_logging_payload["model_id"] = deployment_id
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = deployment
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -209,11 +219,12 @@ def test_router_get_available_deployments():
     print(f"router id's: {router.get_model_ids()}")
     ## DEPLOYMENT 1 ##
     deployment_id = 1
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
     total_tokens = 50
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -237,6 +248,9 @@ def test_router_get_available_deployments():
     standard_logging_payload = create_standard_logging_payload()
     standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
+    standard_logging_payload["hidden_params"][
+        "litellm_model_name"
+    ] = "azure/gpt-35-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -293,10 +307,11 @@ def test_router_skip_rate_limited_deployments():
     ## DEPLOYMENT 1 ##
     deployment_id = 1
     total_tokens = 1439
-    standard_logging_payload = create_standard_logging_payload()
+    standard_logging_payload: StandardLoggingPayload = create_standard_logging_payload()
    standard_logging_payload["model_group"] = "azure-model"
     standard_logging_payload["model_id"] = str(deployment_id)
     standard_logging_payload["total_tokens"] = total_tokens
+    standard_logging_payload["hidden_params"]["litellm_model_name"] = "azure/gpt-turbo"
     kwargs = {
         "litellm_params": {
             "metadata": {
@@ -699,3 +714,54 @@ def test_return_potential_deployments():
     )

     assert len(potential_deployments) == 1
+
+
+@pytest.mark.asyncio
+async def test_tpm_rpm_routing_model_name_checks():
+    deployment = {
+        "model_name": "gpt-3.5-turbo",
+        "litellm_params": {
+            "model": "azure/chatgpt-v-2",
+            "api_key": os.getenv("AZURE_API_KEY"),
+            "api_base": os.getenv("AZURE_API_BASE"),
+            "mock_response": "Hey, how's it going?",
+        },
+    }
+    router = Router(model_list=[deployment], routing_strategy="usage-based-routing-v2")
+
+    async def side_effect_pre_call_check(*args, **kwargs):
+        return args[0]
+
+    with patch.object(
+        router.lowesttpm_logger_v2,
+        "async_pre_call_check",
+        side_effect=side_effect_pre_call_check,
+    ) as mock_object, patch.object(
+        router.lowesttpm_logger_v2, "async_log_success_event"
+    ) as mock_logging_event:
+        response = await router.acompletion(
+            model="gpt-3.5-turbo", messages=[{"role": "user", "content": "Hey!"}]
+        )
+
+        mock_object.assert_called()
+        print(f"mock_object.call_args: {mock_object.call_args[0][0]}")
+        assert (
+            mock_object.call_args[0][0]["litellm_params"]["model"]
+            == deployment["litellm_params"]["model"]
+        )
+
+        await asyncio.sleep(1)
+
+        mock_logging_event.assert_called()
+
+        print(f"mock_logging_event: {mock_logging_event.call_args.kwargs}")
+        standard_logging_payload: StandardLoggingPayload = (
+            mock_logging_event.call_args.kwargs.get("kwargs", {}).get(
+                "standard_logging_object"
+            )
+        )
+
+        assert (
+            standard_logging_payload["hidden_params"]["litellm_model_name"]
+            == "azure/chatgpt-v-2"
+        )
@@ -15,35 +15,6 @@ import litellm
 from litellm.types.utils import Choices


-def test_arize_callback():
-    litellm.callbacks = ["arize"]
-    os.environ["ARIZE_SPACE_KEY"] = "test_space_key"
-    os.environ["ARIZE_API_KEY"] = "test_api_key"
-    os.environ["ARIZE_ENDPOINT"] = "https://otlp.arize.com/v1"
-
-    # Set the batch span processor to quickly flush after a span has been added
-    # This is to ensure that the span is exported before the test ends
-    os.environ["OTEL_BSP_MAX_QUEUE_SIZE"] = "1"
-    os.environ["OTEL_BSP_MAX_EXPORT_BATCH_SIZE"] = "1"
-    os.environ["OTEL_BSP_SCHEDULE_DELAY_MILLIS"] = "1"
-    os.environ["OTEL_BSP_EXPORT_TIMEOUT_MILLIS"] = "5"
-
-    with patch.object(
-        opentelemetry.exporter.otlp.proto.grpc.trace_exporter.OTLPSpanExporter,
-        "export",
-        new=Mock(),
-    ) as patched_export:
-        completion(
-            model="openai/test-model",
-            messages=[{"role": "user", "content": "arize test content"}],
-            stream=False,
-            mock_response="hello there!",
-        )
-
-    time.sleep(1)  # Wait for the batch span processor to flush
-    assert patched_export.called
-
-
 def test_arize_set_attributes():
     """
     Test setting attributes for Arize