Merge branch 'main' into litellm_gemini_refactoring

Krish Dholakia 2024-06-17 19:50:56 -07:00 committed by GitHub
commit 63216f42b8
27 changed files with 335 additions and 182 deletions

View file

@@ -2,6 +2,15 @@ import Image from '@theme/IdealImage';
# Athina
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Athina](https://athina.ai/) is an evaluation framework and production monitoring platform for your LLM-powered app. Athina is designed to enhance the performance and reliability of AI applications through real-time monitoring, granular analytics, and plug-and-play evaluations.
<Image img={require('../../img/athina_dashboard.png')} />

View file

@@ -1,5 +1,14 @@
# Greenscale - Track LLM Spend and Responsible Usage
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Greenscale](https://greenscale.ai/) is a production monitoring platform for your LLM-powered app that gives you granular insights into your GenAI spend and responsible usage. Greenscale captures only metadata, minimizing the risk of exposing personally identifiable information (PII).
## Getting Started

View file

@@ -1,4 +1,13 @@
# Helicone Tutorial
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Helicone](https://helicone.ai/) is an open-source observability platform that proxies your OpenAI traffic and provides key insights into your spend, latency, and usage.
## Use Helicone to log requests across all LLM Providers (OpenAI, Azure, Anthropic, Cohere, Replicate, PaLM)

View file

@@ -1,6 +1,6 @@
import Image from '@theme/IdealImage';
# Langfuse - Logging LLM Input/Output
# 🔥 Langfuse - Logging LLM Input/Output
LangFuse is open-source observability and analytics for LLM apps, offering detailed production traces and a granular view of quality, cost, and latency.

View file

@@ -1,6 +1,16 @@
import Image from '@theme/IdealImage';
# Langsmith - Logging LLM Input/Output
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
An all-in-one developer platform for every step of the application lifecycle
https://smith.langchain.com/

View file

@@ -1,6 +1,6 @@
import Image from '@theme/IdealImage';
# Logfire - Logging LLM Input/Output
# 🔥 Logfire - Logging LLM Input/Output
Logfire is open-source observability and analytics for LLM apps, offering detailed production traces and a granular view of quality, cost, and latency.
@@ -14,10 +14,14 @@ join our [discord](https://discord.gg/wuPM9dRgDw)
## Pre-Requisites
Ensure you have run `pip install logfire` for this integration
Ensure you have installed the following packages to use this integration
```shell
pip install logfire litellm
pip install litellm
pip install opentelemetry-api==1.25.0
pip install opentelemetry-sdk==1.25.0
pip install opentelemetry-exporter-otlp==1.25.0
```
## Quick Start
@@ -25,8 +29,7 @@ pip install logfire litellm
Get your Logfire token from [Logfire](https://logfire.pydantic.dev/)
```python
litellm.success_callback = ["logfire"]
litellm.failure_callback = ["logfire"] # logs errors to logfire
litellm.callbacks = ["logfire"]
```
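For context, here is a minimal end-to-end sketch of the new single-callback style. It assumes a `LOGFIRE_TOKEN` and your provider API key are already set in the environment, and the model name is only illustrative:
```python
import os
import litellm

# LiteLLM resolves the "logfire" string to its OpenTelemetry integration and
# raises if LOGFIRE_TOKEN is missing (see _init_custom_logger_compatible_class below).
assert "LOGFIRE_TOKEN" in os.environ

litellm.callbacks = ["logfire"]  # replaces the older success_callback/failure_callback pair

response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model
    messages=[{"role": "user", "content": "what llm are u"}],
)
print(response)
```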
```python

View file

@@ -1,5 +1,13 @@
# Lunary - Logging and tracing LLM input/output
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Lunary](https://lunary.ai/) is an open-source AI developer platform providing observability, prompt management, and evaluation tools for AI developers.
<video controls width='900' >

View file

@@ -2,6 +2,15 @@ import Image from '@theme/IdealImage';
# Promptlayer Tutorial
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
Promptlayer is a platform for prompt engineers. Log OpenAI requests. Search usage history. Track performance. Visually manage prompt templates.
<Image img={require('../../img/promptlayer.png')} />

View file

@@ -1,5 +1,14 @@
import Image from '@theme/IdealImage';
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
# Sentry - Log LLM Exceptions
[Sentry](https://sentry.io/) provides error monitoring for production. LiteLLM can add breadcrumbs and send exceptions to Sentry with this integration

View file

@@ -1,4 +1,12 @@
# Supabase Tutorial
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
[Supabase](https://supabase.com/) is an open source Firebase alternative.
Start your project with a Postgres database, Authentication, instant APIs, Edge Functions, Realtime subscriptions, Storage, and Vector embeddings.

View file

@@ -1,6 +1,16 @@
import Image from '@theme/IdealImage';
# Weights & Biases - Logging LLM Input/Output
:::tip
This is community maintained. Please open an issue if you run into a bug:
https://github.com/BerriAI/litellm
:::
Weights & Biases helps AI developers build better models faster: https://wandb.ai
<Image img={require('../../img/wandb.png')} />

View file

@@ -172,10 +172,8 @@ const sidebars = {
"proxy/custom_pricing",
"routing",
"scheduler",
"rules",
"set_keys",
"budget_manager",
"contributing",
"secret",
"completion/token_usage",
"load_test",
@@ -183,11 +181,11 @@ const sidebars = {
type: "category",
label: "Logging & Observability",
items: [
"observability/langfuse_integration",
"observability/logfire_integration",
"debugging/local_debugging",
"observability/raw_request_response",
"observability/callbacks",
"observability/custom_callback",
"observability/langfuse_integration",
"observability/sentry",
"observability/lago",
"observability/openmeter",
@@ -233,6 +231,8 @@ const sidebars = {
label: "Extras",
items: [
"extras/contributing",
"contributing",
"rules",
"proxy_server",
{
type: "category",

View file

@@ -37,7 +37,7 @@ input_callback: List[Union[str, Callable]] = []
success_callback: List[Union[str, Callable]] = []
failure_callback: List[Union[str, Callable]] = []
service_callback: List[Union[str, Callable]] = []
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter"]
_custom_logger_compatible_callbacks_literal = Literal["lago", "openmeter", "logfire"]
callbacks: List[Union[Callable, _custom_logger_compatible_callbacks_literal]] = []
_langfuse_default_tags: Optional[
List[

View file

@@ -23,8 +23,12 @@ class JsonFormatter(Formatter):
super(JsonFormatter, self).__init__()
def format(self, record):
json_record = {}
json_record["message"] = record.getMessage()
json_record = {
"message": record.getMessage(),
"level": record.levelname,
"timestamp": self.formatTime(record, self.datefmt),
}
return json.dumps(json_record)
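To make the change concrete, a small illustrative sketch of what the enriched formatter emits (it assumes the `JsonFormatter` class above is in scope; the logger name and message are made up):
```python
import logging

handler = logging.StreamHandler()
handler.setFormatter(JsonFormatter())  # the formatter defined above

demo_logger = logging.getLogger("litellm_demo")  # hypothetical logger name
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)

demo_logger.info("proxy started")
# Example output (timestamp will differ):
# {"message": "proxy started", "level": "INFO", "timestamp": "2024-06-17 19:50:56,123"}
```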

View file

@@ -1,20 +1,21 @@
import os
from dataclasses import dataclass
from datetime import datetime
import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm._logging import verbose_logger
from litellm.types.services import ServiceLoggerPayload
from functools import wraps
from typing import Union, Optional, TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Optional, Union
import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.services import ServiceLoggerPayload
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
from litellm.proxy._types import (
ManagementEndpointLoggingPayload as _ManagementEndpointLoggingPayload,
)
from litellm.proxy.proxy_server import UserAPIKeyAuth as _UserAPIKeyAuth
Span = _Span
UserAPIKeyAuth = _UserAPIKeyAuth
@@ -107,8 +108,9 @@ class OpenTelemetry(CustomLogger):
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
):
from opentelemetry import trace
from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
_start_time_ns = start_time
@@ -145,8 +147,9 @@ class OpenTelemetry(CustomLogger):
start_time: Optional[datetime] = None,
end_time: Optional[datetime] = None,
):
from opentelemetry import trace
from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
_start_time_ns = start_time
@@ -179,8 +182,8 @@ class OpenTelemetry(CustomLogger):
async def async_post_call_failure_hook(
self, original_exception: Exception, user_api_key_dict: UserAPIKeyAuth
):
from opentelemetry.trace import Status, StatusCode
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
parent_otel_span = user_api_key_dict.parent_otel_span
if parent_otel_span is not None:
@@ -202,8 +205,8 @@ class OpenTelemetry(CustomLogger):
parent_otel_span.end(end_time=self._to_ns(datetime.now()))
def _handle_sucess(self, kwargs, response_obj, start_time, end_time):
from opentelemetry.trace import Status, StatusCode
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
verbose_logger.debug(
"OpenTelemetry Logger: Logging kwargs: %s, OTEL config settings=%s",
@@ -253,9 +256,10 @@ class OpenTelemetry(CustomLogger):
span.end(end_time=self._to_ns(end_time))
def set_tools_attributes(self, span: Span, tools):
from litellm.proxy._types import SpanAttributes
import json
from litellm.proxy._types import SpanAttributes
if not tools:
return
@@ -320,7 +324,7 @@
)
span.set_attribute(
SpanAttributes.LLM_IS_STREAMING, optional_params.get("stream", False)
SpanAttributes.LLM_IS_STREAMING, str(optional_params.get("stream", False))
)
if optional_params.get("tools"):
@@ -439,7 +443,7 @@ class OpenTelemetry(CustomLogger):
#############################################
########## LLM Response Attributes ##########
#############################################
if _raw_response:
if _raw_response and isinstance(_raw_response, str):
# cast sr -> dict
import json
@@ -478,10 +482,10 @@ class OpenTelemetry(CustomLogger):
return _parent_context
def _get_span_context(self, kwargs):
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import (
TraceContextTextMapPropagator,
)
from opentelemetry import trace
litellm_params = kwargs.get("litellm_params", {}) or {}
proxy_server_request = litellm_params.get("proxy_server_request", {}) or {}
@@ -505,17 +509,17 @@ class OpenTelemetry(CustomLogger):
return TraceContextTextMapPropagator().extract(carrier=carrier), None
def _get_span_processor(self):
from opentelemetry.sdk.trace.export import (
SpanExporter,
SimpleSpanProcessor,
BatchSpanProcessor,
ConsoleSpanExporter,
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterGRPC,
)
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterHTTP,
)
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import (
OTLPSpanExporter as OTLPSpanExporterGRPC,
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
ConsoleSpanExporter,
SimpleSpanProcessor,
SpanExporter,
)
verbose_logger.debug(
@@ -574,8 +578,9 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from opentelemetry import trace
from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
_start_time_ns = logging_payload.start_time
@@ -619,8 +624,9 @@ class OpenTelemetry(CustomLogger):
logging_payload: ManagementEndpointLoggingPayload,
parent_otel_span: Optional[Span] = None,
):
from opentelemetry import trace
from datetime import datetime
from opentelemetry import trace
from opentelemetry.trace import Status, StatusCode
_start_time_ns = logging_payload.start_time

View file

@@ -10,7 +10,7 @@ import sys
import time
import traceback
import uuid
from typing import Callable, Optional
from typing import Any, Callable, Dict, List, Optional
import litellm
from litellm import (
@@ -72,6 +72,8 @@ from ..integrations.supabase import Supabase
from ..integrations.traceloop import TraceloopLogger
from ..integrations.weights_biases import WeightsBiasesLogger
_in_memory_loggers: List[Any] = []
class Logging:
global supabaseClient, liteDebuggerClient, promptLayerLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, capture_exception, add_breadcrumb, lunaryLogger, logfireLogger, prometheusLogger, slack_app
@@ -1612,6 +1614,7 @@
level=LogfireLevel.ERROR.value,
print_verbose=print_verbose,
)
except Exception as e:
print_verbose(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while failure logging with integrations {str(e)}"
@@ -1786,6 +1789,37 @@ def _init_custom_logger_compatible_class(
logging_integration: litellm._custom_logger_compatible_callbacks_literal,
) -> Callable:
if logging_integration == "lago":
return LagoLogger() # type: ignore
for callback in _in_memory_loggers:
if isinstance(callback, LagoLogger):
return callback # type: ignore
lago_logger = LagoLogger()
_in_memory_loggers.append(lago_logger)
return lago_logger # type: ignore
elif logging_integration == "openmeter":
return OpenMeterLogger() # type: ignore
for callback in _in_memory_loggers:
if isinstance(callback, OpenMeterLogger):
return callback # type: ignore
_openmeter_logger = OpenMeterLogger()
_in_memory_loggers.append(_openmeter_logger)
return _openmeter_logger # type: ignore
elif logging_integration == "logfire":
if "LOGFIRE_TOKEN" not in os.environ:
raise ValueError("LOGFIRE_TOKEN not found in environment variables")
from litellm.integrations.opentelemetry import (
OpenTelemetry,
OpenTelemetryConfig,
)
otel_config = OpenTelemetryConfig(
exporter="otlp_http",
endpoint="https://logfire-api.pydantic.dev/v1/traces",
headers=f"Authorization={os.getenv('LOGFIRE_TOKEN')}",
)
for callback in _in_memory_loggers:
if isinstance(callback, OpenTelemetry):
return callback # type: ignore
_otel_logger = OpenTelemetry(config=otel_config)
_in_memory_loggers.append(_otel_logger)
return _otel_logger # type: ignore
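The substantive change in this hunk is that string callbacks now resolve to cached singleton logger instances instead of a fresh object on every call. A simplified sketch of the pattern (not the LiteLLM code itself, names are illustrative):
```python
from typing import Any, List

_in_memory_loggers: List[Any] = []


class OpenMeterLogger:  # stand-in for the real integration class
    pass


def get_or_create(logger_cls):
    """Return the cached instance of logger_cls, creating it once if needed."""
    for callback in _in_memory_loggers:
        if isinstance(callback, logger_cls):
            return callback
    instance = logger_cls()
    _in_memory_loggers.append(instance)
    return instance


# Repeated lookups return the same object, so exporters aren't re-initialized per request.
assert get_or_create(OpenMeterLogger) is get_or_create(OpenMeterLogger)
```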

View file

@@ -107,19 +107,17 @@ from .llms.databricks import DatabricksChatCompletion
from .llms.huggingface_restapi import Huggingface
from .llms.openai import OpenAIChatCompletion, OpenAITextCompletion
from .llms.predibase import PredibaseChatCompletion
from .llms.bedrock_httpx import BedrockLLM, BedrockConverseLLM
from .llms.vertex_httpx import VertexLLM
from .llms.triton import TritonChatCompletion
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.prompt_templates.factory import (
custom_prompt,
function_call_prompt,
map_system_message_pt,
prompt_factory,
)
from .llms.text_completion_codestral import CodestralTextCompletion
from .llms.triton import TritonChatCompletion
from .llms.vertex_httpx import VertexLLM
from .types.llms.openai import HttpxBinaryResponseContent
from .types.utils import ChatCompletionMessageToolCall
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
@@ -431,6 +429,7 @@
messages: List,
stream: Optional[bool] = False,
mock_response: Union[str, Exception] = "This is a mock request",
mock_tool_calls: Optional[List] = None,
logging=None,
custom_llm_provider=None,
**kwargs,
@@ -499,6 +498,12 @@
model_response["created"] = int(time.time())
model_response["model"] = model
if mock_tool_calls:
model_response["choices"][0]["message"]["tool_calls"] = [
ChatCompletionMessageToolCall(**tool_call)
for tool_call in mock_tool_calls
]
setattr(
model_response,
"usage",
@@ -612,6 +617,7 @@
args = locals()
api_base = kwargs.get("api_base", None)
mock_response = kwargs.get("mock_response", None)
mock_tool_calls = kwargs.get("mock_tool_calls", None)
force_timeout = kwargs.get("force_timeout", 600) ## deprecated
logger_fn = kwargs.get("logger_fn", None)
verbose = kwargs.get("verbose", False)
@@ -930,12 +936,13 @@
litellm_params=litellm_params,
custom_llm_provider=custom_llm_provider,
)
if mock_response:
if mock_response or mock_tool_calls:
return mock_completion(
model,
messages,
stream=stream,
mock_response=mock_response,
mock_tool_calls=mock_tool_calls,
logging=logging,
acompletion=acompletion,
mock_delay=kwargs.get("mock_delay", None),
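A hedged usage sketch of the new `mock_tool_calls` parameter: the tool-call dict shape mirrors OpenAI-style tool calls and is an assumption based on `ChatCompletionMessageToolCall(**tool_call)` above; no provider is called when mocking.
```python
import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative; the request is mocked, not sent
    messages=[{"role": "user", "content": "What's the weather in SF?"}],
    mock_response="Let me check the weather for you.",
    mock_tool_calls=[
        {
            "id": "call_abc123",  # hypothetical values
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "San Francisco"}'},
        }
    ],
)
print(response.choices[0].message.tool_calls)
```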

View file

@@ -1,7 +1,8 @@
import json
import logging
from logging import Formatter
import os
from logging import Formatter
from litellm import json_logs
# Set default log level to INFO
@@ -14,8 +15,11 @@ class JsonFormatter(Formatter):
super(JsonFormatter, self).__init__()
def format(self, record):
json_record = {}
json_record["message"] = record.getMessage()
json_record = {
"message": record.getMessage(),
"level": record.levelname,
"timestamp": self.formatTime(record, self.datefmt),
}
return json.dumps(json_record)

View file

@@ -79,6 +79,7 @@ litellm_settings:
success_callback: ["langfuse"]
failure_callback: ["langfuse"]
cache: true
json_logs: true
general_settings:
alerting: ["slack"]

View file

@@ -1,13 +1,17 @@
from pydantic import BaseModel, Extra, Field, model_validator, Json, ConfigDict
from dataclasses import fields
import enum
from typing import Optional, List, Union, Dict, Literal, Any, TypedDict, TYPE_CHECKING
import json
import os
import sys
import uuid
from dataclasses import fields
from datetime import datetime
import uuid, json, sys, os
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, TypedDict, Union
from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
from typing_extensions import Annotated
from litellm.types.router import UpdateRouterConfig
from litellm.types.utils import ProviderField
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
@@ -283,12 +287,16 @@ class LiteLLMRoutes(enum.Enum):
"/metrics",
]
internal_user_routes: List = [
"/key/generate",
"/key/update",
"/key/delete",
"/key/info",
] + spend_tracking_routes
internal_user_routes: List = (
[
"/key/generate",
"/key/update",
"/key/delete",
"/key/info",
]
+ spend_tracking_routes
+ sso_only_routes
)
# class LiteLLMAllowedRoutes(LiteLLMBase):

View file

@@ -7,59 +7,56 @@ Returns a UserAPIKeyAuth object if the API key is valid
"""
import asyncio
import json
import secrets
import traceback
from datetime import datetime, timedelta, timezone
from typing import Optional
import secrets
from uuid import uuid4
import fastapi
from fastapi import Request
from pydantic import BaseModel
import litellm
import traceback
import asyncio
from fastapi import (
FastAPI,
Request,
HTTPException,
status,
Path,
Depends,
Header,
Response,
Form,
UploadFile,
FastAPI,
File,
Form,
Header,
HTTPException,
Path,
Request,
Response,
UploadFile,
status,
)
from fastapi.responses import (
StreamingResponse,
FileResponse,
ORJSONResponse,
JSONResponse,
)
from fastapi.openapi.utils import get_openapi
from fastapi.responses import RedirectResponse
from fastapi.middleware.cors import CORSMiddleware
from fastapi.staticfiles import StaticFiles
from fastapi.openapi.utils import get_openapi
from fastapi.responses import (
FileResponse,
JSONResponse,
ORJSONResponse,
RedirectResponse,
StreamingResponse,
)
from fastapi.security.api_key import APIKeyHeader
from litellm.proxy._types import *
from litellm._logging import verbose_logger, verbose_proxy_logger
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.proxy._types import *
from litellm.proxy.auth.auth_checks import (
allowed_routes_check,
common_checks,
get_actual_routes,
get_end_user_object,
get_org_object,
get_team_object,
get_user_object,
allowed_routes_check,
get_actual_routes,
log_to_opentelemetry,
)
from litellm.proxy.utils import _to_ns
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
from litellm.proxy.utils import _to_ns
api_key_header = APIKeyHeader(
name="Authorization", auto_error=False, description="Bearer token"
@@ -88,20 +85,20 @@ async def user_api_key_auth(
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
litellm_proxy_admin_name,
allowed_routes_check,
common_checks,
master_key,
prisma_client,
llm_model_list,
user_custom_auth,
custom_db_client,
general_settings,
proxy_logging_obj,
open_telemetry_logger,
user_api_key_cache,
jwt_handler,
allowed_routes_check,
get_actual_routes,
jwt_handler,
litellm_proxy_admin_name,
llm_model_list,
master_key,
open_telemetry_logger,
prisma_client,
proxy_logging_obj,
user_api_key_cache,
user_custom_auth,
)
try:
@@ -1004,7 +1001,7 @@
):
pass
elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY:
elif _user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value:
if route in LiteLLMRoutes.openai_routes.value:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
@@ -1031,7 +1028,7 @@
detail=f"user not allowed to access this route, role= {_user_role}. Trying to access: {route}",
)
elif (
_user_role == LitellmUserRoles.INTERNAL_USER
_user_role == LitellmUserRoles.INTERNAL_USER.value
and route in LiteLLMRoutes.internal_user_routes.value
):
pass
@@ -1059,6 +1056,7 @@
# this token is only used for managing the ui
allowed_routes = [
"/sso",
"/sso/get/logout_url",
"/login",
"/key/generate",
"/key/update",
@@ -1144,8 +1142,8 @@
raise Exception()
except Exception as e:
verbose_proxy_logger.error(
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}".format(
str(e)
"litellm.proxy.proxy_server.user_api_key_auth(): Exception occured - {}\n{}".format(
str(e), traceback.format_exc()
)
)
@@ -1156,7 +1154,6 @@
user_api_key_dict=UserAPIKeyAuth(parent_otel_span=parent_otel_span),
)
verbose_proxy_logger.debug(traceback.format_exc())
if isinstance(e, litellm.BudgetExceededError):
raise ProxyException(
message=e.message, type="auth_error", param=None, code=400

View file

@@ -24,9 +24,9 @@ general_settings:
litellm_settings:
success_callback: ["prometheus"]
callbacks: ["otel"]
failure_callback: ["prometheus"]
store_audit_logs: true
turn_off_message_logging: true
redact_messages_in_exceptions: True
enforced_params:
- user

View file

@@ -12,6 +12,7 @@ import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import os
from unittest.mock import MagicMock, patch

View file

@@ -0,0 +1,54 @@
from litellm.types.files import (
FILE_EXTENSIONS,
FILE_MIME_TYPES,
FileType,
get_file_extension_from_mime_type,
get_file_type_from_extension,
get_file_extension_for_file_type,
get_file_mime_type_for_file_type,
get_file_mime_type_from_extension,
)
import pytest
class TestFileConsts:
def test_all_file_types_have_extensions(self):
for file_type in FileType:
assert file_type in FILE_EXTENSIONS.keys()
def test_all_file_types_have_mime_types(self):
for file_type in FileType:
assert file_type in FILE_MIME_TYPES.keys()
def test_get_file_extension_from_mime_type(self):
assert get_file_extension_from_mime_type("audio/aac") == "aac"
assert get_file_extension_from_mime_type("application/pdf") == "pdf"
with pytest.raises(ValueError):
get_file_extension_from_mime_type("application/unknown")
def test_get_file_type_from_extension(self):
assert get_file_type_from_extension("aac") == FileType.AAC
assert get_file_type_from_extension("pdf") == FileType.PDF
with pytest.raises(ValueError):
get_file_type_from_extension("unknown")
def test_get_file_extension_for_file_type(self):
assert get_file_extension_for_file_type(FileType.AAC) == "aac"
assert get_file_extension_for_file_type(FileType.PDF) == "pdf"
def test_get_file_mime_type_for_file_type(self):
assert get_file_mime_type_for_file_type(FileType.AAC) == "audio/aac"
assert get_file_mime_type_for_file_type(FileType.PDF) == "application/pdf"
def test_get_file_mime_type_from_extension(self):
assert get_file_mime_type_from_extension("aac") == "audio/aac"
assert get_file_mime_type_from_extension("pdf") == "application/pdf"
def test_uppercase_extensions(self):
# Test that uppercase extensions return the correct file type
assert get_file_type_from_extension("AAC") == FileType.AAC
assert get_file_type_from_extension("PDF") == FileType.PDF
# Test that uppercase extensions return the correct MIME type
assert get_file_mime_type_from_extension("AAC") == "audio/aac"
assert get_file_mime_type_from_extension("PDF") == "application/pdf"

View file

@@ -1,12 +1,16 @@
import sys
import os
import asyncio
import json
import logging
import os
import sys
import time
import logfire
import litellm
import pytest
from logfire.testing import TestExporter, SimpleSpanProcessor
import litellm
from litellm._logging import verbose_logger, verbose_proxy_logger
verbose_logger.setLevel(logging.DEBUG)
sys.path.insert(0, os.path.abspath("../.."))
@@ -17,19 +21,13 @@ sys.path.insert(0, os.path.abspath("../.."))
# 4. Test logfire logging for completion while streaming is enabled
@pytest.mark.skip(reason="Breaks on ci/cd")
@pytest.mark.skip(reason="Breaks on ci/cd but works locally")
@pytest.mark.parametrize("stream", [False, True])
def test_completion_logfire_logging(stream):
litellm.success_callback = ["logfire"]
litellm.set_verbose = True
from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
exporter = TestExporter()
logfire.configure(
send_to_logfire=False,
console=False,
processors=[SimpleSpanProcessor(exporter)],
collect_system_metrics=False,
)
litellm.callbacks = ["logfire"]
litellm.set_verbose = True
messages = [{"role": "user", "content": "what llm are u"}]
temperature = 0.3
max_tokens = 10
@@ -47,41 +45,16 @@ def test_completion_logfire_logging(stream):
print(chunk)
time.sleep(5)
exported_spans = exporter.exported_spans_as_dict()
assert len(exported_spans) == 1
assert (
exported_spans[0]["attributes"]["logfire.msg"]
== "Chat Completion with 'gpt-3.5-turbo'"
)
request_data = json.loads(exported_spans[0]["attributes"]["request_data"])
assert request_data["model"] == "gpt-3.5-turbo"
assert request_data["messages"] == messages
assert "completion_tokens" in request_data["usage"]
assert "prompt_tokens" in request_data["usage"]
assert "total_tokens" in request_data["usage"]
assert request_data["response"]["choices"][0]["message"]["content"]
assert request_data["modelParameters"]["max_tokens"] == max_tokens
assert request_data["modelParameters"]["temperature"] == temperature
@pytest.mark.skip(reason="Breaks on ci/cd")
@pytest.mark.skip(reason="Breaks on ci/cd but works locally")
@pytest.mark.asyncio
@pytest.mark.parametrize("stream", [False, True])
async def test_acompletion_logfire_logging(stream):
litellm.success_callback = ["logfire"]
litellm.set_verbose = True
from litellm.integrations.opentelemetry import OpenTelemetry, OpenTelemetryConfig
exporter = TestExporter()
logfire.configure(
send_to_logfire=False,
console=False,
processors=[SimpleSpanProcessor(exporter)],
collect_system_metrics=False,
)
litellm.callbacks = ["logfire"]
litellm.set_verbose = True
messages = [{"role": "user", "content": "what llm are u"}]
temperature = 0.3
max_tokens = 10
@@ -90,30 +63,11 @@ async def test_acompletion_logfire_logging(stream):
messages=messages,
max_tokens=max_tokens,
temperature=temperature,
stream=stream,
)
print(response)
if stream:
for chunk in response:
async for chunk in response:
print(chunk)
time.sleep(5)
exported_spans = exporter.exported_spans_as_dict()
print("exported_spans", exported_spans)
assert len(exported_spans) == 1
assert (
exported_spans[0]["attributes"]["logfire.msg"]
== "Chat Completion with 'gpt-3.5-turbo'"
)
request_data = json.loads(exported_spans[0]["attributes"]["request_data"])
assert request_data["model"] == "gpt-3.5-turbo"
assert request_data["messages"] == messages
assert "completion_tokens" in request_data["usage"]
assert "prompt_tokens" in request_data["usage"]
assert "total_tokens" in request_data["usage"]
assert request_data["response"]["choices"][0]["message"]["content"]
assert request_data["modelParameters"]["max_tokens"] == max_tokens
assert request_data["modelParameters"]["temperature"] == temperature
await asyncio.sleep(5)

View file

@@ -151,23 +151,16 @@ Util Functions
"""
def get_file_mime_type_from_extension(extension: str) -> str:
for file_type, extensions in FILE_EXTENSIONS.items():
if extension in extensions:
return FILE_MIME_TYPES[file_type]
raise ValueError(f"Unknown mime type for extension: {extension}")
def get_file_extension_from_mime_type(mime_type: str) -> str:
for file_type, mime in FILE_MIME_TYPES.items():
if mime == mime_type:
if mime.lower() == mime_type.lower():
return FILE_EXTENSIONS[file_type][0]
raise ValueError(f"Unknown extension for mime type: {mime_type}")
def get_file_type_from_extension(extension: str) -> FileType:
for file_type, extensions in FILE_EXTENSIONS.items():
if extension in extensions:
if extension.lower() in extensions:
return file_type
raise ValueError(f"Unknown file type for extension: {extension}")
@@ -181,6 +174,11 @@ def get_file_mime_type_for_file_type(file_type: FileType) -> str:
return FILE_MIME_TYPES[file_type]
def get_file_mime_type_from_extension(extension: str) -> str:
file_type = get_file_type_from_extension(extension)
return get_file_mime_type_for_file_type(file_type)
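Mirroring the new tests added earlier in this commit, a quick sanity check of the case-insensitive lookups and the reimplemented `get_file_mime_type_from_extension`:
```python
from litellm.types.files import (
    FileType,
    get_file_mime_type_from_extension,
    get_file_type_from_extension,
)

# Uppercase extensions now resolve the same as lowercase ones.
assert get_file_type_from_extension("PDF") == FileType.PDF
assert get_file_mime_type_from_extension("AAC") == "audio/aac"
```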
"""
FileType Type Groupings (Videos, Images, etc)
"""

View file

@@ -340,6 +340,7 @@ def function_setup(
try:
global callback_list, add_breadcrumb, user_logger_fn, Logging
function_id = kwargs["id"] if "id" in kwargs else None
if len(litellm.callbacks) > 0:
for callback in litellm.callbacks:
# check if callback is a string - e.g. "lago", "openmeter"