forked from phoenix/litellm-mirror
Add pyright to ci/cd + Fix remaining type-checking errors (#6082)
* fix: fix type-checking errors
* fix: fix additional type-checking errors
* fix: additional type-checking error fixes
* fix: fix additional type-checking errors
* fix: additional type-check fixes
* fix: fix all type-checking errors + add pyright to ci/cd
* fix: fix incorrect import
* ci(config.yml): use mypy on ci/cd
* fix: fix type-checking errors in utils.py
* fix: fix all type-checking errors on main.py
* fix: fix mypy linting errors
* fix(anthropic/cost_calculator.py): fix linting errors
* fix: fix mypy linting errors
* fix: fix linting errors
This commit is contained in:
parent f7ce1173f3
commit fac3b2ee42
65 changed files with 619 additions and 522 deletions
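Note: most of the Python changes below follow the single pattern the type checker enforces — values that can be None are declared Optional and explicitly narrowed before use. A minimal illustrative sketch of that pattern (hypothetical helper, not copied from any file in this diff):

from typing import Optional

def resolve_endpoint(endpoint: Optional[str] = None) -> str:
    # narrow Optional[str] to str so later uses type-check cleanly
    if endpoint is None:
        raise ValueError("endpoint not set")
    return endpoint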
@@ -315,11 +315,11 @@ jobs:
python -m pip install --upgrade pip
pip install ruff
pip install pylint
pip install pyright
pip install .
- run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
- run: ruff check ./litellm
build_and_test:
machine:
image: ubuntu-2204:2023.10.1

@@ -8,7 +8,7 @@ import requests
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
from typing import Literal, Union, Optional
import traceback

@@ -26,9 +26,9 @@ from litellm._logging import print_verbose, verbose_logger
class GenericAPILogger:
# Class variables or attributes
def __init__(self, endpoint=None, headers=None):
def __init__(self, endpoint: Optional[str] = None, headers: Optional[dict] = None):
try:
if endpoint == None:
if endpoint is None:
# check env for "GENERIC_LOGGER_ENDPOINT"
if os.getenv("GENERIC_LOGGER_ENDPOINT"):
# Do something with the endpoint

@@ -36,9 +36,15 @@ class GenericAPILogger:
else:
# Handle the case when the endpoint is not found in the environment variables
raise ValueError(
f"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
"endpoint not set for GenericAPILogger, GENERIC_LOGGER_ENDPOINT not found in environment variables"
)
headers = headers or litellm.generic_logger_headers
if endpoint is None:
raise ValueError("endpoint not set for GenericAPILogger")
if headers is None:
raise ValueError("headers not set for GenericAPILogger")
self.endpoint = endpoint
self.headers = headers

@@ -48,8 +48,6 @@ class AporiaGuardrail(CustomGuardrail):
)
self.aporia_api_key = api_key or os.environ["APORIO_API_KEY"]
self.aporia_api_base = api_base or os.environ["APORIO_API_BASE"]
self.event_hook: GuardrailEventHooks
super().__init__(**kwargs)
#### CALL HOOKS - proxy only ####

@@ -84,7 +84,7 @@ class _ENTERPRISE_BlockedUserList(CustomLogger):
)
cache_key = f"litellm:end_user_id:{user}"
end_user_cache_obj: LiteLLM_EndUserTable = cache.get_cache(
end_user_cache_obj: Optional[LiteLLM_EndUserTable] = cache.get_cache( # type: ignore
key=cache_key
)
if end_user_cache_obj is None and self.prisma_client is not None:

@@ -48,7 +48,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Class variables or attributes
def __init__(self):
try:
from google.cloud import language_v1
from google.cloud import language_v1 # type: ignore
except Exception:
raise Exception(
"Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"

@@ -57,8 +57,8 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
# Instantiates a client
self.client = language_v1.LanguageServiceClient()
self.moderate_text_request = language_v1.ModerateTextRequest
self.language_document = language_v1.types.Document
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT
self.language_document = language_v1.types.Document # type: ignore
self.document_type = language_v1.types.Document.Type.PLAIN_TEXT # type: ignore
default_confidence_threshold = (
litellm.google_moderation_confidence_threshold or 0.8

@@ -8,6 +8,7 @@
# Thank you users! We ❤️ you! - Krrish & Ishaan
import sys, os
from collections.abc import Iterable
sys.path.insert(
0, os.path.abspath("../..")

@@ -19,11 +20,12 @@ from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger
from fastapi import HTTPException
from litellm._logging import verbose_proxy_logger
from litellm.utils import (
from litellm.types.utils import (
ModelResponse,
EmbeddingResponse,
ImageResponse,
StreamingChoices,
Choices,
)
from datetime import datetime
import aiohttp, asyncio
@@ -34,7 +36,10 @@ litellm.set_verbose = True
class _ENTERPRISE_LlamaGuard(CustomLogger):
# Class variables or attributes
def __init__(self, model_name: Optional[str] = None):
self.model = model_name or litellm.llamaguard_model_name
_model = model_name or litellm.llamaguard_model_name
if _model is None:
raise ValueError("model_name not set for LlamaGuard")
self.model = _model
file_path = litellm.llamaguard_unsafe_content_categories
data = None

@@ -124,7 +129,13 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
hf_model_name="meta-llama/LlamaGuard-7b",
)
if "unsafe" in response.choices[0].message.content:
if (
isinstance(response, ModelResponse)
and isinstance(response.choices[0], Choices)
and response.choices[0].message.content is not None
and isinstance(response.choices[0].message.content, Iterable)
and "unsafe" in response.choices[0].message.content
):
raise HTTPException(
status_code=400, detail={"error": "Violated content safety policy"}
)

@@ -8,7 +8,11 @@
## This provides an LLM Guard Integration for content moderation on the proxy
from typing import Optional, Literal, Union
import litellm, traceback, sys, uuid, os
import litellm
import traceback
import sys
import uuid
import os
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.integrations.custom_logger import CustomLogger

@@ -21,8 +25,10 @@ from litellm.utils import (
StreamingChoices,
)
from datetime import datetime
import aiohttp, asyncio
import aiohttp
import asyncio
from litellm.utils import get_formatted_prompt
from litellm.secret_managers.main import get_secret_str
litellm.set_verbose = True

@@ -38,7 +44,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
self.llm_guard_mode = litellm.llm_guard_mode
if mock_testing == True: # for testing purposes only
return
self.llm_guard_api_base = litellm.get_secret("LLM_GUARD_API_BASE", None)
self.llm_guard_api_base = get_secret_str("LLM_GUARD_API_BASE", None)
if self.llm_guard_api_base is None:
raise Exception("Missing `LLM_GUARD_API_BASE` from environment")
elif not self.llm_guard_api_base.endswith("/"):

@@ -51,8 +51,8 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
"audio_transcription",
],
):
if "messages" in data and isinstance(data["messages"], list):
text = ""
if "messages" in data and isinstance(data["messages"], list):
for m in data["messages"]: # assume messages is a list
if "content" in m and isinstance(m["content"], str):
text += m["content"]

@@ -67,7 +67,7 @@ class _ENTERPRISE_OpenAI_Moderation(CustomLogger):
)
verbose_proxy_logger.debug("Moderation response: %s", moderation_response)
if moderation_response.results[0].flagged == True:
if moderation_response.results[0].flagged is True:
raise HTTPException(
status_code=403, detail={"error": "Violated content safety policy"}
)

@@ -6,7 +6,9 @@ import collections
from datetime import datetime
async def get_spend_by_tags(start_date=None, end_date=None, prisma_client=None):
async def get_spend_by_tags(
prisma_client: PrismaClient, start_date=None, end_date=None
):
response = await prisma_client.db.query_raw(
"""
SELECT

@@ -191,7 +191,7 @@ def init_redis_cluster(redis_kwargs) -> redis.RedisCluster:
new_startup_nodes.append(ClusterNode(**item))
redis_kwargs.pop("startup_nodes")
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs)
return redis.RedisCluster(startup_nodes=new_startup_nodes, **cluster_kwargs) # type: ignore
def _init_redis_sentinel(redis_kwargs) -> redis.Redis:

@@ -1,5 +1,8 @@
import litellm
from typing import Optional, Union
import litellm
from ..exceptions import UnsupportedParamsError
from ..types.llms.openai import *

@@ -10,6 +10,7 @@
import ast
import asyncio
import hashlib
import inspect
import io
import json
import logging
@@ -245,7 +246,8 @@ class RedisCache(BaseCache):
self.redis_flush_size = redis_flush_size
self.redis_version = "Unknown"
try:
self.redis_version = self.redis_client.info()["redis_version"]
if not inspect.iscoroutinefunction(self.redis_client):
self.redis_version = self.redis_client.info()["redis_version"] # type: ignore
except Exception:
pass

@@ -266,7 +268,8 @@ class RedisCache(BaseCache):
### SYNC HEALTH PING ###
try:
self.redis_client.ping()
if hasattr(self.redis_client, "ping"):
self.redis_client.ping() # type: ignore
except Exception as e:
verbose_logger.error(
"Error connecting to Sync Redis client", extra={"error": str(e)}

@@ -308,7 +311,7 @@ class RedisCache(BaseCache):
_redis_client = self.redis_client
start_time = time.time()
try:
result = _redis_client.incr(name=key, amount=value)
result: int = _redis_client.incr(name=key, amount=value) # type: ignore
if ttl is not None:
# check if key already has ttl, if not -> set ttl

@@ -561,7 +564,7 @@ class RedisCache(BaseCache):
f"Set ASYNC Redis Cache: key: {key}\nValue {value}\nttl={ttl}"
)
try:
await redis_client.sadd(key, *value)
await redis_client.sadd(key, *value) # type: ignore
if ttl is not None:
_td = timedelta(seconds=ttl)
await redis_client.expire(key, _td)

@@ -712,7 +715,7 @@ class RedisCache(BaseCache):
for cache_key in key_list:
cache_key = self.check_and_fix_namespace(key=cache_key)
_keys.append(cache_key)
results = self.redis_client.mget(keys=_keys)
results: List = self.redis_client.mget(keys=_keys) # type: ignore
# Associate the results back with their keys.
# 'results' is a list of values corresponding to the order of keys in 'key_list'.

@@ -842,7 +845,7 @@ class RedisCache(BaseCache):
print_verbose("Pinging Sync Redis Cache")
start_time = time.time()
try:
response = self.redis_client.ping()
response: bool = self.redis_client.ping() # type: ignore
print_verbose(f"Redis Cache PING: {response}")
## LOGGING ##
end_time = time.time()

@@ -911,8 +914,8 @@ class RedisCache(BaseCache):
async with _redis_client as redis_client:
await redis_client.delete(*keys)
def client_list(self):
client_list = self.redis_client.client_list()
def client_list(self) -> List:
client_list: List = self.redis_client.client_list() # type: ignore
return client_list
def info(self):

@@ -39,8 +39,8 @@ from litellm.llms.fireworks_ai.cost_calculator import (
)
from litellm.llms.OpenAI.cost_calculation import cost_per_token as openai_cost_per_token
from litellm.llms.together_ai.cost_calculator import get_model_params_and_category
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import PassthroughCallTypes, Usage
from litellm.utils import (

@@ -39,11 +39,11 @@ from litellm.proxy._types import (
VirtualKeyEvent,
WebhookEvent,
)
from litellm.types.integrations.slack_alerting import *
from litellm.types.router import LiteLLM_Params
from ..email_templates.templates import *
from .batching_handler import send_to_webhook, squash_payloads
from .types import *
from .utils import _add_langfuse_trace_id_to_alert, process_slack_alerting_variables

@@ -172,14 +172,14 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
"moderation",
"audio_transcription",
],
):
) -> Any:
pass
async def async_post_call_streaming_hook(
self,
user_api_key_dict: UserAPIKeyAuth,
response: str,
):
) -> Any:
pass
#### SINGLE-USE #### - https://docs.litellm.ai/docs/observability/custom_callback#using-your-custom-callback-function

@@ -42,8 +42,10 @@ async def get_all_team_member_emails(team_id: Optional[str] = None) -> list:
)
_team_member_user_ids: List[str] = []
for member in _team_members:
if member and isinstance(member, dict) and member.get("user_id") is not None:
_team_member_user_ids.append(member.get("user_id"))
if member and isinstance(member, dict):
_user_id = member.get("user_id")
if _user_id and isinstance(_user_id, str):
_team_member_user_ids.append(_user_id)
sql_query = """
SELECT user_email

@@ -149,7 +149,7 @@ class LunaryLogger:
else:
error_obj = None
self.lunary_client.track_event(
self.lunary_client.track_event( # type: ignore
type,
"start",
run_id,

@@ -164,7 +164,7 @@ class LunaryLogger:
params=extra,
)
self.lunary_client.track_event(
self.lunary_client.track_event( # type: ignore
type,
event,
run_id,
@@ -100,16 +100,14 @@ class OpenMeterLogger(CustomLogger):
}
try:
response = self.sync_http_handler.post(
self.sync_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
raise Exception(f"OpenMeter logging error: {e.response.text}")
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):

@@ -128,18 +126,12 @@ class OpenMeterLogger(CustomLogger):
}
try:
response = await self.async_http_handler.post(
await self.async_http_handler.post(
url=_url,
data=json.dumps(_data),
headers=_headers,
)
response.raise_for_status()
except httpx.HTTPStatusError as e:
verbose_logger.error(
"Failed OpenMeter logging - {}".format(e.response.text)
)
raise e
raise Exception(f"OpenMeter logging error: {e.response.text}")
except Exception as e:
verbose_logger.error("Failed OpenMeter logging - {}".format(str(e)))
raise e

@@ -146,14 +146,14 @@ class S3Logger:
import json
payload = json.dumps(payload)
payload_str = json.dumps(payload)
print_verbose(f"\ns3 Logger - Logging payload = {payload}")
print_verbose(f"\ns3 Logger - Logging payload = {payload_str}")
response = self.s3_client.put_object(
Bucket=self.bucket_name,
Key=s3_object_key,
Body=payload,
Body=payload_str,
ContentType="application/json",
ContentLanguage="en",
ContentDisposition=f'inline; filename="{s3_object_download_filename}"',

@@ -1,6 +1,7 @@
import traceback
from litellm._logging import verbose_logger
import litellm
from litellm._logging import verbose_logger
class TraceloopLogger:

@@ -11,14 +12,15 @@ class TraceloopLogger:
def __init__(self):
try:
from traceloop.sdk.tracing.tracing import TracerWrapper
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
from traceloop.sdk import Traceloop
from traceloop.sdk.instruments import Instruments
from opentelemetry.sdk.trace.export import ConsoleSpanExporter
from traceloop.sdk.tracing.tracing import TracerWrapper
except ModuleNotFoundError as e:
verbose_logger.error(
f"Traceloop not installed, try running 'pip install traceloop-sdk' to fix this error: {e}\n{traceback.format_exc()}"
)
raise e
Traceloop.init(
app_name="Litellm-Server",

@@ -38,8 +40,8 @@ class TraceloopLogger:
status_message=None,
):
from opentelemetry import trace
from opentelemetry.trace import SpanKind, Status, StatusCode
from opentelemetry.semconv.ai import SpanAttributes
from opentelemetry.trace import SpanKind, Status, StatusCode
try:
print_verbose(

@@ -94,7 +96,7 @@ class TraceloopLogger:
)
if "temperature" in optional_params:
span.set_attribute(
SpanAttributes.LLM_REQUEST_TEMPERATURE,
SpanAttributes.LLM_REQUEST_TEMPERATURE, # type: ignore
kwargs.get("temperature"),
)

@@ -32,8 +32,8 @@ from litellm.litellm_core_utils.redact_messages import (
redact_message_input_output_from_logging,
)
from litellm.proxy._types import CommonProxyErrors
from litellm.rerank_api.types import RerankResponse
from litellm.types.llms.openai import HttpxBinaryResponseContent
from litellm.types.rerank import RerankResponse
from litellm.types.router import SPECIAL_MODEL_INFO_PARAMS
from litellm.types.utils import (
CallTypes,

@@ -1,5 +1,5 @@
import uuid
from typing import Optional, Union
from typing import Any, Optional, Union
import httpx
from openai import AsyncAzureOpenAI, AzureOpenAI
@@ -24,6 +24,7 @@ class AzureAudioTranscription(AzureChatCompletion):
model: str,
audio_file: FileTypes,
optional_params: dict,
logging_obj: Any,
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,

@@ -32,9 +33,8 @@ class AzureAudioTranscription(AzureChatCompletion):
api_version: Optional[str] = None,
client=None,
azure_ad_token: Optional[str] = None,
logging_obj=None,
atranscription: bool = False,
):
) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params}
# init AzureOpenAI Client

@@ -59,7 +59,7 @@ class AzureAudioTranscription(AzureChatCompletion):
azure_client_params["max_retries"] = max_retries
if atranscription is True:
return self.async_audio_transcriptions(
return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file,
data=data,
model_response=model_response,

@@ -105,7 +105,7 @@ class AzureAudioTranscription(AzureChatCompletion):
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "azure"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(

@@ -114,12 +114,12 @@ class AzureAudioTranscription(AzureChatCompletion):
data: dict,
model_response: TranscriptionResponse,
timeout: float,
azure_client_params: dict,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
client=None,
azure_client_params=None,
max_retries=None,
logging_obj=None,
):
response = None
try:

@@ -1083,7 +1083,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
client=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
super().embedding()
if self._client_session is None:
self._client_session = self.create_client_session()

@@ -1128,7 +1128,7 @@ class AzureChatCompletion(BaseLLM):
)
if aembedding is True:
response = self.aembedding(
return self.aembedding( # type: ignore
data=data,
input=input,
logging_obj=logging_obj,

@@ -1138,7 +1138,6 @@ class AzureChatCompletion(BaseLLM):
timeout=timeout,
client=client,
)
return response
if client is None:
azure_client = AzureOpenAI(**azure_client_params) # type: ignore
else:

@@ -1418,7 +1417,7 @@ class AzureChatCompletion(BaseLLM):
logging_obj: LiteLLMLoggingObj,
client=None,
timeout=None,
):
) -> litellm.ImageResponse:
response: Optional[dict] = None
try:
# response = await azure_client.images.generate(**data, timeout=timeout)

@@ -1460,7 +1459,7 @@ class AzureChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data},
original_response=stringified_response,
)
return convert_to_model_response_object(
return convert_to_model_response_object( # type: ignore
response_object=stringified_response,
model_response_object=model_response,
response_type="image_generation",

@@ -1489,7 +1488,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token: Optional[str] = None,
client=None,
aimg_generation=None,
):
) -> litellm.ImageResponse:
try:
if model and len(model) > 0:
model = model

@@ -1531,8 +1530,7 @@ class AzureChatCompletion(BaseLLM):
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation is True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response
return self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
img_gen_api_base = self.create_azure_base_url(
azure_client_params=azure_client_params, model=data.get("model", "")

@@ -1742,9 +1740,9 @@ class AzureChatCompletion(BaseLLM):
async def ahealth_check(
self,
model: Optional[str],
api_key: str,
api_key: Optional[str],
api_base: str,
api_version: str,
api_version: Optional[str],
timeout: float,
mode: str,
messages: Optional[list] = None,

@@ -77,15 +77,15 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
model_response: TranscriptionResponse,
timeout: float,
max_retries: int,
logging_obj: LiteLLMLoggingObj,
api_key: Optional[str],
api_base: Optional[str],
client=None,
logging_obj=None,
atranscription: bool = False,
):
) -> TranscriptionResponse:
data = {"model": model, "file": audio_file, **optional_params}
if atranscription is True:
return self.async_audio_transcriptions(
return self.async_audio_transcriptions( # type: ignore
audio_file=audio_file,
data=data,
model_response=model_response,

@@ -97,7 +97,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
logging_obj=logging_obj,
)
openai_client = self._get_openai_client(
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,
api_key=api_key,
api_base=api_base,

@@ -123,7 +123,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
original_response=stringified_response,
)
hidden_params = {"model": "whisper-1", "custom_llm_provider": "openai"}
final_response = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
final_response: TranscriptionResponse = convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, hidden_params=hidden_params, response_type="audio_transcription") # type: ignore
return final_response
async def async_audio_transcriptions(

@@ -139,7 +139,7 @@ class OpenAIAudioTranscription(OpenAIChatCompletion):
max_retries=None,
):
try:
openai_aclient = self._get_openai_client(
openai_aclient: AsyncOpenAI = self._get_openai_client( # type: ignore
is_async=True,
api_key=api_key,
api_base=api_base,
@@ -1167,7 +1167,7 @@ class OpenAIChatCompletion(BaseLLM):
api_base: Optional[str] = None,
client=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
super().embedding()
try:
model = model

@@ -1183,7 +1183,7 @@ class OpenAIChatCompletion(BaseLLM):
)
if aembedding is True:
async_response = self.aembedding(
return self.aembedding( # type: ignore
data=data,
input=input,
logging_obj=logging_obj,

@@ -1194,7 +1194,6 @@ class OpenAIChatCompletion(BaseLLM):
client=client,
max_retries=max_retries,
)
return async_response
openai_client: OpenAI = self._get_openai_client( # type: ignore
is_async=False,

@@ -1294,7 +1293,7 @@ class OpenAIChatCompletion(BaseLLM):
model_response: Optional[litellm.utils.ImageResponse] = None,
client=None,
aimg_generation=None,
):
) -> litellm.ImageResponse:
data = {}
try:
model = model

@@ -1304,8 +1303,7 @@ class OpenAIChatCompletion(BaseLLM):
raise OpenAIError(status_code=422, message="max retries must be an int")
if aimg_generation is True:
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
return self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
openai_client = self._get_openai_client(
is_async=False,

@@ -1449,7 +1447,7 @@ class OpenAIChatCompletion(BaseLLM):
async def ahealth_check(
self,
model: Optional[str],
api_key: str,
api_key: Optional[str],
timeout: float,
mode: str,
messages: Optional[list] = None,

@@ -282,7 +282,6 @@ class AnthropicChatCompletion(BaseLLM):
prompt_tokens = completion_response["usage"]["input_tokens"]
completion_tokens = completion_response["usage"]["output_tokens"]
_usage = completion_response["usage"]
total_tokens = prompt_tokens + completion_tokens
cache_creation_input_tokens: int = 0
cache_read_input_tokens: int = 0

@@ -290,12 +289,15 @@ class AnthropicChatCompletion(BaseLLM):
model_response.model = model
if "cache_creation_input_tokens" in _usage:
cache_creation_input_tokens = _usage["cache_creation_input_tokens"]
prompt_tokens += cache_creation_input_tokens
if "cache_read_input_tokens" in _usage:
cache_read_input_tokens = _usage["cache_read_input_tokens"]
prompt_tokens += cache_read_input_tokens
prompt_tokens_details = PromptTokensDetails(
cached_tokens=cache_read_input_tokens
)
total_tokens = prompt_tokens + completion_tokens
usage = Usage(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
@@ -24,16 +24,33 @@ def cost_per_token(model: str, usage: Usage) -> Tuple[float, float]:
model_info = get_model_info(model=model, custom_llm_provider="anthropic")
## CALCULATE INPUT COST
### Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)
prompt_cost = 0.0
### PROCESSING COST
non_cache_hit_tokens = usage.prompt_tokens
cache_hit_tokens = 0
if usage.prompt_tokens_details and usage.prompt_tokens_details.cached_tokens:
cache_hit_tokens = usage.prompt_tokens_details.cached_tokens
non_cache_hit_tokens = non_cache_hit_tokens - cache_hit_tokens
prompt_cost: float = usage["prompt_tokens"] * model_info["input_cost_per_token"]
if model_info.get("cache_creation_input_token_cost") is not None:
prompt_cost = float(non_cache_hit_tokens) * model_info["input_cost_per_token"]
_cache_read_input_token_cost = model_info.get("cache_read_input_token_cost")
if (
_cache_read_input_token_cost is not None
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
prompt_cost += (
usage._cache_creation_input_tokens # type: ignore
* model_info["cache_creation_input_token_cost"]
float(usage.prompt_tokens_details.cached_tokens)
* _cache_read_input_token_cost
)
if model_info.get("cache_read_input_token_cost") is not None:
### CACHE WRITING COST
_cache_creation_input_token_cost = model_info.get("cache_creation_input_token_cost")
if _cache_creation_input_token_cost is not None:
prompt_cost += (
usage._cache_read_input_tokens * model_info["cache_read_input_token_cost"] # type: ignore
float(usage._cache_creation_input_tokens) * _cache_creation_input_token_cost
)
## CALCULATE OUTPUT COST
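For reference, the rewritten input-cost logic above works out to: (non-cache-hit prompt tokens x input cost) + (cached tokens x cache-read cost) + (cache-creation tokens x cache-write cost). A rough arithmetic sketch with made-up token counts and per-token prices (not real model_info values):

# hypothetical prices and counts, for illustration only
input_cost_per_token = 3e-06
cache_read_input_token_cost = 3e-07
cache_creation_input_token_cost = 3.75e-06
prompt_tokens = 1000          # total prompt tokens reported in usage
cached_tokens = 600           # usage.prompt_tokens_details.cached_tokens
cache_creation_tokens = 200   # usage._cache_creation_input_tokens

non_cache_hit_tokens = prompt_tokens - cached_tokens
prompt_cost = (
    non_cache_hit_tokens * input_cost_per_token
    + cached_tokens * cache_read_input_token_cost
    + cache_creation_tokens * cache_creation_input_token_cost
)
print(prompt_cost)  # 0.00213 with these numbers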
@@ -216,7 +216,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
api_base: Optional[str] = None,
client=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
"""
- Separate image url from text
-> route image url call to `/image/embeddings`

@@ -225,7 +225,7 @@ class AzureAIEmbedding(OpenAIChatCompletion):
assemble result in-order, and return
"""
if aembedding is True:
return self.async_embedding(
return self.async_embedding( # type: ignore
model,
input,
timeout,

@@ -4,7 +4,7 @@ import httpx
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.llms.cohere.rerank import CohereRerank
from litellm.rerank_api.types import RerankResponse
from litellm.types.rerank import RerankResponse
class AzureAIRerank(CohereRerank):

@@ -5,7 +5,7 @@ Handles image gen calls to Bedrock's `/invoke` endpoint
import copy
import json
import os
from typing import List
from typing import Any, List
from openai.types.image import Image

@@ -20,8 +20,8 @@ def image_generation(
prompt: str,
model_response: ImageResponse,
optional_params: dict,
logging_obj: Any,
timeout=None,
logging_obj=None,
aimg_generation=False,
):
"""

@@ -13,7 +13,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
from litellm.types.rerank import RerankRequest, RerankResponse
class CohereRerank(BaseLLM):

@@ -40,7 +40,7 @@ class AzureOpenAIFineTuningAPI(BaseLLM):
organization: Optional[str] = None,
client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = None,
api_version: Optional[str] = None,
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]:
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
openai_client: Optional[Union[AzureOpenAI, AsyncAzureOpenAI]] = (
get_azure_openai_client(
api_key=api_key,

@@ -68,7 +68,7 @@ class OpenAIFineTuningAPI(BaseLLM):
max_retries: Optional[int],
organization: Optional[str],
client: Optional[Union[OpenAI, AsyncOpenAI]] = None,
) -> Union[FineTuningJob, Union[Coroutine[Any, Any, FineTuningJob]]]:
) -> Union[FineTuningJob, Coroutine[Any, Any, FineTuningJob]]:
openai_client: Optional[Union[OpenAI, AsyncOpenAI]] = self.get_openai_client(
api_key=api_key,
api_base=api_base,

@@ -554,7 +554,7 @@ class Huggingface(BaseLLM):
else:
prompt = prompt_factory(model=model, messages=messages)
data = {
"inputs": prompt,
"inputs": prompt, # type: ignore
"parameters": optional_params,
"stream": ( # type: ignore
True

@@ -589,7 +589,7 @@ class Huggingface(BaseLLM):
inference_params.pop("details")
inference_params.pop("return_full_text")
data = {
"inputs": prompt,
"inputs": prompt, # type: ignore
}
if task == "text-generation-inference":
data["parameters"] = inference_params

@@ -4,7 +4,7 @@ import time
import traceback
import types
from enum import Enum
from typing import Callable, List, Optional
from typing import Any, Callable, List, Optional
import requests # type: ignore

@@ -185,8 +185,8 @@ def completion(
def embedding(
model: str,
input: list,
api_key: Optional[str] = None,
logging_obj=None,
api_key: Optional[str],
logging_obj: Any,
model_response=None,
encoding=None,
):

@@ -2,7 +2,7 @@ import json
import os
import time
from enum import Enum
from typing import Callable, Optional
from typing import Any, Callable, Optional
import requests # type: ignore

@@ -124,9 +124,9 @@ def embedding(
model: str,
input: list,
model_response: EmbeddingResponse,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
api_key: Optional[str],
api_base: Optional[str],
logging_obj: Any,
optional_params=None,
encoding=None,
):

@@ -2714,7 +2714,7 @@ def custom_prompt(
final_prompt_value: str = "",
bos_token: str = "",
eos_token: str = "",
):
) -> str:
prompt = bos_token + initial_prompt_value
bos_open = True
## a bos token is at the start of a system / human message

@@ -15,7 +15,7 @@ from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
)
from litellm.rerank_api.types import RerankRequest, RerankResponse
from litellm.types.rerank import RerankRequest, RerankResponse
class TogetherAIRerank(BaseLLM):

@@ -47,7 +47,7 @@ class TritonChatCompletion(BaseLLM):
data: dict,
model_response: litellm.utils.EmbeddingResponse,
api_base: str,
logging_obj=None,
logging_obj: Any,
api_key: Optional[str] = None,
) -> EmbeddingResponse:
async_handler = AsyncHTTPHandler(

@@ -93,9 +93,9 @@ class TritonChatCompletion(BaseLLM):
timeout: float,
api_base: str,
model_response: litellm.utils.EmbeddingResponse,
logging_obj: Any,
optional_params: dict,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
aembedding: bool = False,
) -> EmbeddingResponse:

@@ -122,7 +122,7 @@ class TritonChatCompletion(BaseLLM):
)
if aembedding:
response = await self.aembedding(
response = await self.aembedding( # type: ignore
data=data_for_triton,
model_response=model_response,
logging_obj=logging_obj,

@@ -141,10 +141,10 @@ class TritonChatCompletion(BaseLLM):
messages: List[dict],
timeout: float,
api_base: str,
logging_obj: Any,
optional_params: dict,
model_response: ModelResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
client=None,
stream: Optional[bool] = False,
acompletion: bool = False,

@@ -239,11 +239,13 @@ class TritonChatCompletion(BaseLLM):
else:
handler = HTTPHandler()
if stream:
return self._handle_stream(
return self._handle_stream( # type: ignore
handler, api_base, json_data_for_triton, model, logging_obj
)
else:
response = handler.post(url=api_base, data=json_data_for_triton, headers=headers)
response = handler.post(
url=api_base, data=json_data_for_triton, headers=headers
)
return self._handle_response(
response, model_response, logging_obj, type_of_model=type_of_model
)

@@ -261,7 +263,7 @@ class TritonChatCompletion(BaseLLM):
) -> ModelResponse:
handler = AsyncHTTPHandler()
if stream:
return self._ahandle_stream(
return self._ahandle_stream( # type: ignore
handler, api_base, data_for_triton, model, logging_obj
)
else:

@@ -2,7 +2,7 @@ from typing import List, Literal, Tuple
import httpx
from litellm import supports_system_messages, supports_response_schema, verbose_logger
from litellm import supports_response_schema, supports_system_messages, verbose_logger
from litellm.types.llms.vertex_ai import PartType
@@ -67,6 +67,8 @@ def _get_vertex_url(
vertex_location: Optional[str],
vertex_api_version: Literal["v1", "v1beta1"],
) -> Tuple[str, str]:
url: Optional[str] = None
endpoint: Optional[str] = None
if mode == "chat":
### SET RUNTIME ENDPOINT ###
endpoint = "generateContent"

@@ -88,6 +90,8 @@ def _get_vertex_url(
endpoint = "predict"
url = f"https://{vertex_location}-aiplatform.googleapis.com/v1/projects/{vertex_project}/locations/{vertex_location}/publishers/google/models/{model}:{endpoint}"
if not url or not endpoint:
raise ValueError(f"Unable to get vertex url/endpoinit for mode: {mode}")
return url, endpoint

@@ -3,7 +3,7 @@ Google AI Studio /batchEmbedContents Embeddings Endpoint
"""
import json
from typing import List, Literal, Optional, Union
from typing import Any, List, Literal, Optional, Union
import httpx

@@ -31,9 +31,9 @@ class GoogleBatchEmbeddings(VertexLLM):
model_response: EmbeddingResponse,
custom_llm_provider: Literal["gemini", "vertex_ai"],
optional_params: dict,
logging_obj: Any,
api_key: Optional[str] = None,
api_base: Optional[str] = None,
logging_obj=None,
encoding=None,
vertex_project=None,
vertex_location=None,

@@ -43,17 +43,17 @@ class VertexImageGeneration(VertexLLM):
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
logging_obj: Any,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[Any] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
aimg_generation=False,
):
) -> litellm.ImageResponse:
if aimg_generation is True:
return self.aimage_generation(
return self.aimage_generation( # type: ignore
prompt=prompt,
vertex_project=vertex_project,
vertex_location=vertex_location,

@@ -138,13 +138,13 @@ class VertexImageGeneration(VertexLLM):
vertex_location: Optional[str],
vertex_credentials: Optional[str],
model_response: litellm.ImageResponse,
logging_obj: Any,
model: Optional[
str
] = "imagegeneration", # vertex ai uses imagegeneration as the default model
client: Optional[AsyncHTTPHandler] = None,
optional_params: Optional[dict] = None,
timeout: Optional[int] = None,
logging_obj=None,
):
response = None
if client is None:

@@ -48,7 +48,7 @@ class VertexMultimodalEmbedding(VertexLLM):
aembedding=False,
timeout=300,
client=None,
):
) -> litellm.EmbeddingResponse:
_auth_header, vertex_project = self._ensure_access_token(
credentials=vertex_credentials,

@@ -121,7 +121,7 @@ class VertexMultimodalEmbedding(VertexLLM):
)
if aembedding is True:
return self.async_multimodal_embedding(
return self.async_multimodal_embedding( # type: ignore
model=model,
api_base=url,
data=request_data,

@@ -62,7 +62,7 @@ class VertexTextToSpeechAPI(VertexLLM):
_is_async: Optional[bool] = False,
optional_params: Optional[dict] = None,
kwargs: Optional[dict] = None,
):
) -> HttpxBinaryResponseContent:
import base64
####### Authenticate with Vertex AI ########

@@ -145,7 +145,7 @@ class VertexTextToSpeechAPI(VertexLLM):
########## End of logging ############
####### Send the request ###################
if _is_async is True:
return self.async_audio_speech(
return self.async_audio_speech( # type:ignore
logging_obj=logging_obj, url=url, headers=headers, request=request
)
sync_handler = _get_httpx_client()

@@ -52,9 +52,9 @@ class VertexEmbedding(VertexBase):
vertex_credentials: Optional[str] = None,
gemini_api_key: Optional[str] = None,
extra_headers: Optional[dict] = None,
):
) -> litellm.EmbeddingResponse:
if aembedding is True:
return self.async_embedding(
return self.async_embedding( # type: ignore
model=model,
input=input,
logging_obj=logging_obj,

@@ -25,13 +25,8 @@ import requests # type: ignore
import litellm
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.utils import (
EmbeddingResponse,
ModelResponse,
Usage,
get_secret,
map_finish_reason,
)
from litellm.secret_managers.main import get_secret_str
from litellm.utils import EmbeddingResponse, ModelResponse, Usage, map_finish_reason
from .base import BaseLLM
from .prompt_templates import factory as ptf

@@ -184,7 +179,7 @@ class IBMWatsonXAIConfig:
]
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict) -> str:
# handle anthropic prompts and amazon titan prompts
if model in custom_prompt_dict:
# check if the model has a registered custom prompt

@@ -200,14 +195,10 @@ def convert_messages_to_prompt(model, messages, provider, custom_prompt_dict):
eos_token=model_prompt_dict.get("eos_token", ""),
)
return prompt
elif provider == "ibm":
prompt = ptf.prompt_factory(
model=model, messages=messages, custom_llm_provider="watsonx"
)
elif provider == "ibm-mistralai":
prompt = ptf.mistral_instruct_pt(messages=messages)
else:
prompt = ptf.prompt_factory(
prompt: str = ptf.prompt_factory( # type: ignore
model=model, messages=messages, custom_llm_provider="watsonx"
)
return prompt
@@ -327,37 +318,37 @@ class IBMWatsonXAI(BaseLLM):
# Load auth variables from environment variables
if url is None:
url = (
get_secret("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret("WATSONX_URL")
or get_secret("WX_URL")
or get_secret("WML_URL")
get_secret_str("WATSONX_API_BASE") # consistent with 'AZURE_API_BASE'
or get_secret_str("WATSONX_URL")
or get_secret_str("WX_URL")
or get_secret_str("WML_URL")
)
if api_key is None:
api_key = (
get_secret("WATSONX_APIKEY")
or get_secret("WATSONX_API_KEY")
or get_secret("WX_API_KEY")
get_secret_str("WATSONX_APIKEY")
or get_secret_str("WATSONX_API_KEY")
or get_secret_str("WX_API_KEY")
)
if token is None:
token = get_secret("WATSONX_TOKEN") or get_secret("WX_TOKEN")
token = get_secret_str("WATSONX_TOKEN") or get_secret_str("WX_TOKEN")
if project_id is None:
project_id = (
get_secret("WATSONX_PROJECT_ID")
or get_secret("WX_PROJECT_ID")
or get_secret("PROJECT_ID")
get_secret_str("WATSONX_PROJECT_ID")
or get_secret_str("WX_PROJECT_ID")
or get_secret_str("PROJECT_ID")
)
if region_name is None:
region_name = (
get_secret("WATSONX_REGION")
or get_secret("WX_REGION")
or get_secret("REGION")
get_secret_str("WATSONX_REGION")
or get_secret_str("WX_REGION")
or get_secret_str("REGION")
)
if space_id is None:
space_id = (
get_secret("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret("WATSONX_SPACE_ID")
or get_secret("WX_SPACE_ID")
or get_secret("SPACE_ID")
get_secret_str("WATSONX_DEPLOYMENT_SPACE_ID")
or get_secret_str("WATSONX_SPACE_ID")
or get_secret_str("WX_SPACE_ID")
or get_secret_str("SPACE_ID")
)
# credentials parsing

@@ -446,8 +437,8 @@ class IBMWatsonXAI(BaseLLM):
model_response: ModelResponse,
print_verbose: Callable,
encoding,
logging_obj,
optional_params=None,
logging_obj: Any,
optional_params: dict,
acompletion=None,
litellm_params=None,
logger_fn=None,

@@ -592,13 +583,13 @@ class IBMWatsonXAI(BaseLLM):
model: str,
input: Union[list, str],
model_response: litellm.EmbeddingResponse,
api_key: Optional[str] = None,
logging_obj=None,
optional_params=None,
api_key: Optional[str],
logging_obj: Any,
optional_params: dict,
encoding=None,
print_verbose=None,
aembedding=None,
):
) -> litellm.EmbeddingResponse:
"""
Send a text embedding request to the IBM Watsonx.ai API.
"""

@@ -657,7 +648,7 @@ class IBMWatsonXAI(BaseLLM):
try:
if aembedding is True:
return handle_aembedding(req_params)
return handle_aembedding(req_params) # type: ignore
else:
return handle_embedding(req_params)
except WatsonXAIError as e:

@@ -669,7 +660,7 @@ class IBMWatsonXAI(BaseLLM):
headers = {}
headers["Content-Type"] = "application/x-www-form-urlencoded"
if api_key is None:
api_key = get_secret("WX_API_KEY") or get_secret("WATSONX_API_KEY")
api_key = get_secret_str("WX_API_KEY") or get_secret_str("WATSONX_API_KEY")
if api_key is None:
raise ValueError("API key is required")
headers["Accept"] = "application/json"

@@ -812,22 +803,29 @@ class RequestManager:
request_params["data"] = json.dumps(request_params.pop("json", {}))
method = request_params.pop("method")
retries = 0
resp: Optional[httpx.Response] = None
while retries < 3:
if method.upper() == "POST":
resp = await self.async_handler.post(**request_params)
else:
resp = await self.async_handler.get(**request_params)
if resp.status_code in [429, 503, 504, 520]:
if resp is not None and resp.status_code in [429, 503, 504, 520]:
# to handle rate limiting and service unavailable errors
# see: ibm_watsonx_ai.foundation_models.inference.base_model_inference.BaseModelInference._send_inference_payload
await asyncio.sleep(2**retries)
retries += 1
else:
break
if resp is None:
raise WatsonXAIError(
status_code=500,
message="No response from the server",
)
if resp.is_error:
error_reason = getattr(resp, "reason", "")
raise WatsonXAIError(
status_code=resp.status_code,
message=f"Error {resp.status_code} ({resp.reason}): {resp.text}",
message=f"Error {resp.status_code} ({error_reason}): {resp.text}",
)
yield resp
# await async_handler.close()
218 litellm/main.py

@@ -19,7 +19,8 @@ import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from concurrent import futures
from concurrent.futures import FIRST_COMPLETED, ThreadPoolExecutor, wait
from copy import deepcopy
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Mapping, Optional, Type, Union

@@ -647,7 +648,7 @@ def mock_completion(
@client
def completion(
def completion( # type: ignore
model: str,
# Optional OpenAI params: see https://platform.openai.com/docs/api-reference/chat/create
messages: List = [],

@@ -2940,16 +2941,16 @@ def completion_with_retries(*args, **kwargs):
num_retries = kwargs.pop("num_retries", 3)
retry_strategy: Literal["exponential_backoff_retry", "constant_retry"] = kwargs.pop("retry_strategy", "constant_retry") # type: ignore
original_function = kwargs.pop("original_function", completion)
if retry_strategy == "constant_retry":
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
elif retry_strategy == "exponential_backoff_retry":
if retry_strategy == "exponential_backoff_retry":
retryer = tenacity.Retrying(
wait=tenacity.wait_exponential(multiplier=1, max=10),
stop=tenacity.stop_after_attempt(num_retries),
reraise=True,
)
else:
retryer = tenacity.Retrying(
stop=tenacity.stop_after_attempt(num_retries), reraise=True
)
return retryer(original_function, *args, **kwargs)
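The retry refactor above keeps the behavior for the two supported strategies while making sure retryer is always assigned (anything other than exponential backoff falls back to constant retries). A hedged usage sketch, assuming completion_with_retries is importable from the top-level litellm package and a model/API key are configured via environment variables:

import litellm

# retry the call up to 3 times, backing off exponentially between attempts
response = litellm.completion_with_retries(
    model="gpt-3.5-turbo",  # assumed model name, for illustration only
    messages=[{"role": "user", "content": "hello"}],
    num_retries=3,
    retry_strategy="exponential_backoff_retry",
)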
@ -2968,16 +2969,16 @@ async def acompletion_with_retries(*args, **kwargs):
|
|||
num_retries = kwargs.pop("num_retries", 3)
|
||||
retry_strategy = kwargs.pop("retry_strategy", "constant_retry")
|
||||
original_function = kwargs.pop("original_function", completion)
|
||||
if retry_strategy == "constant_retry":
|
||||
retryer = tenacity.Retrying(
|
||||
stop=tenacity.stop_after_attempt(num_retries), reraise=True
|
||||
)
|
||||
elif retry_strategy == "exponential_backoff_retry":
|
||||
if retry_strategy == "exponential_backoff_retry":
|
||||
retryer = tenacity.Retrying(
|
||||
wait=tenacity.wait_exponential(multiplier=1, max=10),
|
||||
stop=tenacity.stop_after_attempt(num_retries),
|
||||
reraise=True,
|
||||
)
|
||||
else:
|
||||
retryer = tenacity.Retrying(
|
||||
stop=tenacity.stop_after_attempt(num_retries), reraise=True
|
||||
)
|
||||
return await retryer(original_function, *args, **kwargs)
|
||||
|
||||
|
||||
|
@ -3045,7 +3046,7 @@ def batch_completion(
|
|||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
n=n,
|
||||
stream=stream,
|
||||
stream=stream or False,
|
||||
stop=stop,
|
||||
max_tokens=max_tokens,
|
||||
presence_penalty=presence_penalty,
|
||||
|
@ -3124,7 +3125,7 @@ def batch_completion_models(*args, **kwargs):
|
|||
models = kwargs["models"]
|
||||
kwargs.pop("models")
|
||||
futures = {}
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=len(models)) as executor:
|
||||
with ThreadPoolExecutor(max_workers=len(models)) as executor:
|
||||
for model in models:
|
||||
futures[model] = executor.submit(
|
||||
completion, *args, model=model, **kwargs
|
||||
|
@ -3141,9 +3142,7 @@ def batch_completion_models(*args, **kwargs):
|
|||
kwargs.pop("model_list")
|
||||
nested_kwargs = kwargs.pop("kwargs", {})
|
||||
futures = {}
|
||||
with concurrent.futures.ThreadPoolExecutor(
|
||||
max_workers=len(deployments)
|
||||
) as executor:
|
||||
with ThreadPoolExecutor(max_workers=len(deployments)) as executor:
|
||||
for deployment in deployments:
|
||||
for key in kwargs.keys():
|
||||
if (
|
||||
|
@ -3156,9 +3155,7 @@ def batch_completion_models(*args, **kwargs):
|
|||
while futures:
|
||||
# wait for the first returned future
|
||||
print_verbose("\n\n waiting for next result\n\n")
|
||||
done, _ = concurrent.futures.wait(
|
||||
futures.values(), return_when=concurrent.futures.FIRST_COMPLETED
|
||||
)
|
||||
done, _ = wait(futures.values(), return_when=FIRST_COMPLETED)
|
||||
print_verbose(f"done list\n{done}")
|
||||
for future in done:
|
||||
try:
|
||||
|
@@ -3214,6 +3211,8 @@ def batch_completion_models_all_responses(*args, **kwargs):
    if "models" in kwargs:
        models = kwargs["models"]
        kwargs.pop("models")
    else:
        raise Exception("'models' param not in kwargs")

    responses = []
@@ -3256,6 +3255,7 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
            model=model, api_base=kwargs.get("api_base", None)
        )

        response: Optional[EmbeddingResponse] = None
        if (
            custom_llm_provider == "openai"
            or custom_llm_provider == "azure"
@@ -3294,12 +3294,21 @@ async def aembedding(*args, **kwargs) -> EmbeddingResponse:
        elif isinstance(init_response, EmbeddingResponse): ## CACHING SCENARIO
            response = init_response
        elif asyncio.iscoroutine(init_response):
            response = await init_response
            response = await init_response # type: ignore
        else:
            # Call the synchronous function using run_in_executor
            response = await loop.run_in_executor(None, func_with_context)
        if response is not None and hasattr(response, "_hidden_params"):
        if (
            response is not None
            and isinstance(response, EmbeddingResponse)
            and hasattr(response, "_hidden_params")
        ):
            response._hidden_params["custom_llm_provider"] = custom_llm_provider

        if response is None:
            raise ValueError(
                "Unable to get Embedding Response. Please pass a valid llm_provider."
            )
        return response
    except Exception as e:
        custom_llm_provider = custom_llm_provider or "openai"
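The widened isinstance check above is what lets the type checker accept the three-way dispatch in aembedding: a cached EmbeddingResponse is returned as-is, a coroutine is awaited, and anything else runs in an executor. A rough sketch of that shape, with made-up helper names for illustration:

import asyncio
from typing import Any, Callable, Optional

async def dispatch(init_response: Optional[Any], sync_fallback: Callable[[], Any]) -> Any:
    if asyncio.iscoroutine(init_response):
        return await init_response          # provider returned a coroutine
    if init_response is not None:
        return init_response                # e.g. a cached response object
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, sync_fallback)  # synchronous provider path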
@@ -3329,7 +3338,6 @@ def embedding(
    user: Optional[str] = None,
    custom_llm_provider=None,
    litellm_call_id=None,
    litellm_logging_obj=None,
    logger_fn=None,
    **kwargs,
) -> EmbeddingResponse:
@@ -3362,6 +3370,7 @@ def embedding(
    client = kwargs.pop("client", None)
    rpm = kwargs.pop("rpm", None)
    tpm = kwargs.pop("tpm", None)
    litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
    cooldown_time = kwargs.get("cooldown_time", None)
    max_parallel_requests = kwargs.pop("max_parallel_requests", None)
    model_info = kwargs.get("model_info", None)
@ -3491,7 +3500,7 @@ def embedding(
|
|||
}
|
||||
)
|
||||
try:
|
||||
response = None
|
||||
response: Optional[EmbeddingResponse] = None
|
||||
logging: Logging = litellm_logging_obj # type: ignore
|
||||
logging.update_environment_variables(
|
||||
model=model,
|
||||
|
@ -3691,7 +3700,7 @@ def embedding(
|
|||
raise ValueError(
|
||||
"api_base is required for triton. Please pass `api_base`"
|
||||
)
|
||||
response = triton_chat_completions.embedding(
|
||||
response = triton_chat_completions.embedding( # type: ignore
|
||||
model=model,
|
||||
input=input,
|
||||
api_base=api_base,
|
||||
|
@ -3783,6 +3792,7 @@ def embedding(
|
|||
timeout=timeout,
|
||||
aembedding=aembedding,
|
||||
print_verbose=print_verbose,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif custom_llm_provider == "oobabooga":
|
||||
response = oobabooga.embedding(
|
||||
|
@ -3793,14 +3803,16 @@ def embedding(
|
|||
logging_obj=logging,
|
||||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
api_key=api_key,
|
||||
)
|
||||
elif custom_llm_provider == "ollama":
|
||||
api_base = (
|
||||
litellm.api_base
|
||||
or api_base
|
||||
or get_secret("OLLAMA_API_BASE")
|
||||
or get_secret_str("OLLAMA_API_BASE")
|
||||
or "http://localhost:11434"
|
||||
) # type: ignore
|
||||
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if not all(isinstance(item, str) for item in input):
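Most of the changes in this file replace get_secret with get_secret_str so that fallback chains stay typed as Optional[str]. A simplified stand-in (not litellm's real implementation) showing why that helps:

import os
from typing import Optional

def get_secret_str(name: str) -> Optional[str]:
    # Simplified: the real helper also understands secret managers.
    value = os.getenv(name)
    return value if isinstance(value, str) else None

# The whole expression is Optional[str] instead of Any, so pyright/mypy can
# check every downstream use of api_base.
api_base: Optional[str] = get_secret_str("OLLAMA_API_BASE") or "http://localhost:11434"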
|
||||
|
@ -3881,13 +3893,13 @@ def embedding(
|
|||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or get_secret("XINFERENCE_API_KEY")
|
||||
or get_secret_str("XINFERENCE_API_KEY")
|
||||
or "stub-xinference-key"
|
||||
) # xinference does not need an api key, pass a stub key if user did not set one
|
||||
api_base = (
|
||||
api_base
|
||||
or litellm.api_base
|
||||
or get_secret("XINFERENCE_API_BASE")
|
||||
or get_secret_str("XINFERENCE_API_BASE")
|
||||
or "http://127.0.0.1:9997/v1"
|
||||
)
|
||||
response = openai_chat_completions.embedding(
|
||||
|
@ -3911,19 +3923,20 @@ def embedding(
|
|||
optional_params=optional_params,
|
||||
model_response=EmbeddingResponse(),
|
||||
aembedding=aembedding,
|
||||
api_key=api_key,
|
||||
)
|
||||
elif custom_llm_provider == "azure_ai":
|
||||
api_base = (
|
||||
api_base # for deepinfra/perplexity/anyscale/groq/friendliai we check in get_llm_provider and pass in the api base from there
|
||||
or litellm.api_base
|
||||
or get_secret("AZURE_AI_API_BASE")
|
||||
or get_secret_str("AZURE_AI_API_BASE")
|
||||
)
|
||||
# set API KEY
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key # for deepinfra/perplexity/anyscale/friendliai we check in get_llm_provider and pass in the api key from there
|
||||
or litellm.openai_key
|
||||
or get_secret("AZURE_AI_API_KEY")
|
||||
or get_secret_str("AZURE_AI_API_KEY")
|
||||
)
|
||||
|
||||
## EMBEDDING CALL
|
||||
|
@ -3944,10 +3957,14 @@ def embedding(
|
|||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
if response is not None and hasattr(response, "_hidden_params"):
|
||||
response._hidden_params["custom_llm_provider"] = custom_llm_provider
|
||||
|
||||
if response is None:
|
||||
args = locals()
|
||||
raise ValueError(f"No valid embedding model args passed in - {args}")
|
||||
return response
|
||||
except Exception as e:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
litellm_logging_obj.post_call(
|
||||
input=input,
|
||||
api_key=api_key,
|
||||
original_response=str(e),
|
||||
|
@ -4018,7 +4035,11 @@ async def atext_completion(
|
|||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
if kwargs.get("stream", False) is True: # return an async generator
|
||||
if (
|
||||
kwargs.get("stream", False) is True
|
||||
or isinstance(response, TextCompletionStreamWrapper)
|
||||
or isinstance(response, CustomStreamWrapper)
|
||||
): # return an async generator
|
||||
return TextCompletionStreamWrapper(
|
||||
completion_stream=_async_streaming(
|
||||
response=response,
|
||||
|
@ -4153,9 +4174,10 @@ def text_completion(
|
|||
Your example of how to use this function goes here.
|
||||
"""
|
||||
if "engine" in kwargs:
|
||||
if model is None:
|
||||
_engine = kwargs["engine"]
|
||||
if model is None and isinstance(_engine, str):
|
||||
# only use engine when model not passed
|
||||
model = kwargs["engine"]
|
||||
model = _engine
|
||||
kwargs.pop("engine")
|
||||
|
||||
text_completion_response = TextCompletionResponse()
|
||||
|
@ -4223,7 +4245,7 @@ def text_completion(
|
|||
def process_prompt(i, individual_prompt):
|
||||
decoded_prompt = tokenizer.decode(individual_prompt)
|
||||
all_params = {**kwargs, **optional_params}
|
||||
response = text_completion(
|
||||
response: TextCompletionResponse = text_completion( # type: ignore
|
||||
model=model,
|
||||
prompt=decoded_prompt,
|
||||
num_retries=3, # ensure this does not fail for the batch
|
||||
|
@ -4292,6 +4314,8 @@ def text_completion(
|
|||
model = "text-completion-openai/" + _model
|
||||
optional_params.pop("custom_llm_provider", None)
|
||||
|
||||
if model is None:
|
||||
raise ValueError("model is not set. Set either via 'model' or 'engine' param.")
|
||||
kwargs["text_completion"] = True
|
||||
response = completion(
|
||||
model=model,
|
||||
|
@ -4302,7 +4326,11 @@ def text_completion(
|
|||
)
|
||||
if kwargs.get("acompletion", False) is True:
|
||||
return response
|
||||
if stream is True or kwargs.get("stream", False) is True:
|
||||
if (
|
||||
stream is True
|
||||
or kwargs.get("stream", False) is True
|
||||
or isinstance(response, CustomStreamWrapper)
|
||||
):
|
||||
response = TextCompletionStreamWrapper(
|
||||
completion_stream=response,
|
||||
model=model,
|
||||
|
@ -4310,6 +4338,8 @@ def text_completion(
|
|||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
return response
|
||||
elif isinstance(response, TextCompletionStreamWrapper):
|
||||
return response
|
||||
transformed_logprobs = None
|
||||
# only supported for TGI models
|
||||
try:
|
||||
|
@ -4424,7 +4454,10 @@ def moderation(
|
|||
):
|
||||
# only supports open ai for now
|
||||
api_key = (
|
||||
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
openai_client = kwargs.get("client", None)
|
||||
|
@ -4433,7 +4466,10 @@ def moderation(
|
|||
api_key=api_key,
|
||||
)
|
||||
|
||||
if model is not None:
|
||||
response = openai_client.moderations.create(input=input, model=model)
|
||||
else:
|
||||
response = openai_client.moderations.create(input=input)
|
||||
return response
|
||||
|
||||
|
||||
|
@ -4441,20 +4477,30 @@ def moderation(
|
|||
async def amoderation(
|
||||
input: str, model: Optional[str] = None, api_key: Optional[str] = None, **kwargs
|
||||
):
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
# only supports open ai for now
|
||||
api_key = (
|
||||
api_key or litellm.api_key or litellm.openai_key or get_secret("OPENAI_API_KEY")
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.openai_key
|
||||
or get_secret_str("OPENAI_API_KEY")
|
||||
)
|
||||
openai_client = kwargs.get("client", None)
|
||||
if openai_client is None:
|
||||
if openai_client is None or not isinstance(openai_client, AsyncOpenAI):
|
||||
|
||||
# call helper to get OpenAI client
|
||||
# _get_openai_client maintains in-memory caching logic for OpenAI clients
|
||||
openai_client = openai_chat_completions._get_openai_client(
|
||||
_openai_client: AsyncOpenAI = openai_chat_completions._get_openai_client( # type: ignore
|
||||
is_async=True,
|
||||
api_key=api_key,
|
||||
)
|
||||
else:
|
||||
_openai_client = openai_client
|
||||
if model is not None:
|
||||
response = await openai_client.moderations.create(input=input, model=model)
|
||||
else:
|
||||
response = await openai_client.moderations.create(input=input)
|
||||
return response
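The rewritten branch above narrows the injected client before use: anything that is not an AsyncOpenAI instance is replaced by a freshly built (and cached, via _get_openai_client) client, so _openai_client always has a concrete type. A minimal sketch of the same guard, with resolve_client as a hypothetical helper:

from typing import Optional
from openai import AsyncOpenAI

def resolve_client(client: Optional[object], api_key: Optional[str]) -> AsyncOpenAI:
    if client is None or not isinstance(client, AsyncOpenAI):
        # Build a correctly typed client when none (or the wrong kind) was injected.
        return AsyncOpenAI(api_key=api_key)
    return client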
|
||||
|
||||
|
||||
|
@ -4497,7 +4543,7 @@ async def aimage_generation(*args, **kwargs) -> ImageResponse:
|
|||
init_response = ImageResponse(**init_response)
|
||||
response = init_response
|
||||
elif asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
response = await init_response # type: ignore
|
||||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -4527,7 +4573,6 @@ def image_generation(
|
|||
api_key: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
litellm_logging_obj=None,
|
||||
custom_llm_provider=None,
|
||||
**kwargs,
|
||||
) -> ImageResponse:
|
||||
|
@ -4543,9 +4588,10 @@ def image_generation(
|
|||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", {})
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
client = kwargs.get("client", None)
|
||||
|
||||
model_response = litellm.utils.ImageResponse()
|
||||
model_response: ImageResponse = litellm.utils.ImageResponse()
|
||||
if model is not None or custom_llm_provider is not None:
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
|
||||
else:
|
||||
|
@ -4651,25 +4697,27 @@ def image_generation(
|
|||
|
||||
if custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_type = get_secret("AZURE_API_TYPE") or "azure"
|
||||
api_type = get_secret_str("AZURE_API_TYPE") or "azure"
|
||||
|
||||
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
|
||||
|
||||
api_version = (
|
||||
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
|
||||
api_version
|
||||
or litellm.api_version
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
)
|
||||
|
||||
api_key = (
|
||||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
azure_ad_token = optional_params.pop("azure_ad_token", None) or get_secret(
|
||||
"AZURE_AD_TOKEN"
|
||||
)
|
||||
azure_ad_token = optional_params.pop(
|
||||
"azure_ad_token", None
|
||||
) or get_secret_str("AZURE_AD_TOKEN")
|
||||
|
||||
model_response = azure_chat_completions.image_generation(
|
||||
model=model,
|
||||
|
@ -4714,18 +4762,18 @@ def image_generation(
|
|||
optional_params.pop("vertex_project", None)
|
||||
or optional_params.pop("vertex_ai_project", None)
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.pop("vertex_location", None)
|
||||
or optional_params.pop("vertex_ai_location", None)
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = (
|
||||
optional_params.pop("vertex_credentials", None)
|
||||
or optional_params.pop("vertex_ai_credentials", None)
|
||||
or get_secret("VERTEXAI_CREDENTIALS")
|
||||
or get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
model_response = vertex_image_generation.image_generation(
|
||||
model=model,
|
||||
|
@ -4786,7 +4834,7 @@ async def atranscription(*args, **kwargs) -> TranscriptionResponse:
|
|||
elif isinstance(init_response, TranscriptionResponse): ## CACHING SCENARIO
|
||||
response = init_response
|
||||
elif asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
response = await init_response # type: ignore
|
||||
else:
|
||||
# Call the synchronous function using run_in_executor
|
||||
response = await loop.run_in_executor(None, func_with_context)
|
||||
|
@ -4820,7 +4868,6 @@ def transcription(
|
|||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
max_retries: Optional[int] = None,
|
||||
litellm_logging_obj: Optional[LiteLLMLoggingObj] = None,
|
||||
custom_llm_provider=None,
|
||||
**kwargs,
|
||||
) -> TranscriptionResponse:
|
||||
|
@ -4830,6 +4877,7 @@ def transcription(
|
|||
Allows router to load balance between them
|
||||
"""
|
||||
atranscription = kwargs.get("atranscription", False)
|
||||
litellm_logging_obj: LiteLLMLoggingObj = kwargs.get("litellm_logging_obj") # type: ignore
|
||||
kwargs.get("litellm_call_id", None)
|
||||
kwargs.get("logger_fn", None)
|
||||
kwargs.get("proxy_server_request", None)
|
||||
|
@ -4869,22 +4917,17 @@ def transcription(
|
|||
custom_llm_provider=custom_llm_provider,
|
||||
drop_params=drop_params,
|
||||
)
|
||||
# optional_params = {
|
||||
# "language": language,
|
||||
# "prompt": prompt,
|
||||
# "response_format": response_format,
|
||||
# "temperature": None, # openai defaults this to 0
|
||||
# }
|
||||
|
||||
response: Optional[TranscriptionResponse] = None
|
||||
if custom_llm_provider == "azure":
|
||||
# azure configs
|
||||
api_base = api_base or litellm.api_base or get_secret("AZURE_API_BASE")
|
||||
api_base = api_base or litellm.api_base or get_secret_str("AZURE_API_BASE")
|
||||
|
||||
api_version = (
|
||||
api_version or litellm.api_version or get_secret("AZURE_API_VERSION")
|
||||
api_version or litellm.api_version or get_secret_str("AZURE_API_VERSION")
|
||||
)
|
||||
|
||||
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret(
|
||||
azure_ad_token = kwargs.pop("azure_ad_token", None) or get_secret_str(
|
||||
"AZURE_AD_TOKEN"
|
||||
)
|
||||
|
||||
|
@ -4892,8 +4935,8 @@ def transcription(
|
|||
api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
)
|
||||
|
||||
response = azure_audio_transcriptions.audio_transcriptions(
|
||||
model=model,
|
||||
|
@ -4942,6 +4985,9 @@ def transcription(
|
|||
api_base=api_base,
|
||||
api_key=api_key,
|
||||
)
|
||||
|
||||
if response is None:
|
||||
raise ValueError("Unmapped provider passed in. Unable to get the response.")
|
||||
return response
|
||||
|
||||
|
||||
|
@ -5149,15 +5195,16 @@ def speech(
|
|||
vertex_ai_project = (
|
||||
generic_optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
generic_optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = generic_optional_params.vertex_credentials or get_secret(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
vertex_credentials = (
|
||||
generic_optional_params.vertex_credentials
|
||||
or get_secret_str("VERTEXAI_CREDENTIALS")
|
||||
)
|
||||
|
||||
if voice is not None and not isinstance(voice, dict):
|
||||
|
@ -5234,20 +5281,25 @@ async def ahealth_check(
|
|||
if custom_llm_provider == "azure":
|
||||
api_key = (
|
||||
model_params.get("api_key")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
)
|
||||
|
||||
api_base = (
|
||||
api_base: Optional[str] = (
|
||||
model_params.get("api_base")
|
||||
or get_secret("AZURE_API_BASE")
|
||||
or get_secret("AZURE_OPENAI_API_BASE")
|
||||
or get_secret_str("AZURE_API_BASE")
|
||||
or get_secret_str("AZURE_OPENAI_API_BASE")
|
||||
)
|
||||
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
"Azure API Base cannot be None. Set via 'AZURE_API_BASE' in env var or `.completion(..., api_base=..)`"
|
||||
)
|
||||
|
||||
api_version = (
|
||||
model_params.get("api_version")
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret("AZURE_OPENAI_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_OPENAI_API_VERSION")
|
||||
)
|
||||
|
||||
timeout = (
|
||||
|
@ -5273,7 +5325,7 @@ async def ahealth_check(
|
|||
custom_llm_provider == "openai"
|
||||
or custom_llm_provider == "text-completion-openai"
|
||||
):
|
||||
api_key = model_params.get("api_key") or get_secret("OPENAI_API_KEY")
|
||||
api_key = model_params.get("api_key") or get_secret_str("OPENAI_API_KEY")
|
||||
organization = model_params.get("organization")
|
||||
|
||||
timeout = (
|
||||
|
@ -5282,7 +5334,7 @@ async def ahealth_check(
|
|||
or default_timeout
|
||||
)
|
||||
|
||||
api_base = model_params.get("api_base") or get_secret("OPENAI_API_BASE")
|
||||
api_base = model_params.get("api_base") or get_secret_str("OPENAI_API_BASE")
|
||||
|
||||
if custom_llm_provider == "text-completion-openai":
|
||||
mode = "completion"
|
||||
|
@ -5377,7 +5429,9 @@ def config_completion(**kwargs):
|
|||
)
|
||||
|
||||
|
||||
def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List] = None):
|
||||
def stream_chunk_builder_text_completion(
|
||||
chunks: list, messages: Optional[List] = None
|
||||
) -> TextCompletionResponse:
|
||||
id = chunks[0]["id"]
|
||||
object = chunks[0]["object"]
|
||||
created = chunks[0]["created"]
|
||||
|
@ -5446,7 +5500,7 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
|
|||
response["usage"]["total_tokens"] = (
|
||||
response["usage"]["prompt_tokens"] + response["usage"]["completion_tokens"]
|
||||
)
|
||||
return response
|
||||
return TextCompletionResponse(**response)
|
||||
|
||||
|
||||
def stream_chunk_builder(
|
||||
|
|
|
@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Union
|
|||
from pydantic import BaseModel, ConfigDict, Extra, Field, Json, model_validator
|
||||
from typing_extensions import Annotated, TypedDict
|
||||
|
||||
from litellm.integrations.SlackAlerting.types import AlertType
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
from litellm.types.router import RouterErrors, UpdateRouterConfig
|
||||
from litellm.types.utils import ProviderField, StandardCallbackDynamicParams
|
||||
|
||||
|
|
|
@ -1,12 +1,13 @@
|
|||
from typing import Optional
|
||||
from fastapi import Depends, Request, APIRouter
|
||||
from fastapi import HTTPException
|
||||
import copy
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Request
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.caching import RedisCache
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/cache",
|
||||
tags=["caching"],
|
||||
|
@ -21,9 +22,9 @@ async def cache_ping():
|
|||
"""
|
||||
Endpoint for checking if cache can be pinged
|
||||
"""
|
||||
try:
|
||||
litellm_cache_params = {}
|
||||
specific_cache_params = {}
|
||||
try:
|
||||
|
||||
if litellm.cache is None:
|
||||
raise HTTPException(
|
||||
|
@ -135,7 +136,9 @@ async def cache_redis_info():
|
|||
raise HTTPException(
|
||||
status_code=503, detail="Cache not initialized. litellm.cache is None"
|
||||
)
|
||||
if litellm.cache.type == "redis":
|
||||
if litellm.cache.type == "redis" and isinstance(
|
||||
litellm.cache.cache, RedisCache
|
||||
):
|
||||
client_list = litellm.cache.cache.client_list()
|
||||
redis_info = litellm.cache.cache.info()
|
||||
num_clients = len(client_list)
|
||||
|
@ -177,7 +180,9 @@ async def cache_flushall():
|
|||
raise HTTPException(
|
||||
status_code=503, detail="Cache not initialized. litellm.cache is None"
|
||||
)
|
||||
if litellm.cache.type == "redis":
|
||||
if litellm.cache.type == "redis" and isinstance(
|
||||
litellm.cache.cache, RedisCache
|
||||
):
|
||||
litellm.cache.cache.flushall()
|
||||
return {
|
||||
"status": "success",
|
||||
|
|
|
@ -118,13 +118,13 @@ async def create_fine_tuning_job(
|
|||
version,
|
||||
)
|
||||
|
||||
data = fine_tuning_request.model_dump(exclude_none=True)
|
||||
try:
|
||||
if premium_user is not True:
|
||||
raise ValueError(
|
||||
f"Only premium users can use this endpoint + {CommonProxyErrors.not_premium_user.value}"
|
||||
)
|
||||
# Convert Pydantic model to dict
|
||||
data = fine_tuning_request.model_dump(exclude_none=True)
|
||||
|
||||
verbose_proxy_logger.debug(
|
||||
"Request received by LiteLLM:\n{}".format(json.dumps(data, indent=4)),
|
||||
|
@ -146,6 +146,7 @@ async def create_fine_tuning_job(
|
|||
)
|
||||
|
||||
# add llm_provider_config to data
|
||||
if llm_provider_config is not None:
|
||||
data.update(llm_provider_config)
|
||||
|
||||
response = await litellm.acreate_fine_tuning_job(**data)
|
||||
|
@ -262,6 +263,7 @@ async def list_fine_tuning_jobs(
|
|||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
if llm_provider_config is not None:
|
||||
data.update(llm_provider_config)
|
||||
|
||||
response = await litellm.alist_fine_tuning_jobs(
|
||||
|
@ -378,6 +380,7 @@ async def retrieve_fine_tuning_job(
|
|||
custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
|
||||
if llm_provider_config is not None:
|
||||
data.update(llm_provider_config)
|
||||
|
||||
response = await litellm.acancel_fine_tuning_job(
|
||||
|
|
|
@ -227,8 +227,8 @@ async def send_management_endpoint_alert(
|
|||
- An internal user is created, updated, or deleted
|
||||
- A team is created, updated, or deleted
|
||||
"""
|
||||
from litellm.integrations.SlackAlerting.types import AlertType
|
||||
from litellm.proxy.proxy_server import premium_user, proxy_logging_obj
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
|
||||
if premium_user is not True:
|
||||
return
|
||||
|
|
|
@ -114,10 +114,7 @@ from litellm import (
|
|||
from litellm._logging import verbose_proxy_logger, verbose_router_logger
|
||||
from litellm.caching import DualCache, RedisCache
|
||||
from litellm.exceptions import RejectedRequestError
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import (
|
||||
SlackAlerting,
|
||||
SlackAlertingArgs,
|
||||
)
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
get_litellm_metadata_from_kwargs,
|
||||
|
@ -249,6 +246,7 @@ from litellm.secret_managers.aws_secret_manager import (
|
|||
)
|
||||
from litellm.secret_managers.google_kms import load_google_kms
|
||||
from litellm.secret_managers.main import get_secret, get_secret_str, str_to_bool
|
||||
from litellm.types.integrations.slack_alerting import SlackAlertingArgs
|
||||
from litellm.types.llms.anthropic import (
|
||||
AnthropicMessagesRequest,
|
||||
AnthropicResponse,
|
||||
|
|
|
@ -54,7 +54,6 @@ from litellm.exceptions import RejectedRequestError
|
|||
from litellm.integrations.custom_guardrail import CustomGuardrail
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.integrations.SlackAlerting.slack_alerting import SlackAlerting
|
||||
from litellm.integrations.SlackAlerting.types import DEFAULT_ALERT_TYPES
|
||||
from litellm.integrations.SlackAlerting.utils import _add_langfuse_trace_id_to_alert
|
||||
from litellm.litellm_core_utils.core_helpers import (
|
||||
_get_parent_otel_span_from_kwargs,
|
||||
|
@ -85,6 +84,7 @@ from litellm.proxy.hooks.parallel_request_limiter import (
|
|||
_PROXY_MaxParallelRequestsHandler,
|
||||
)
|
||||
from litellm.secret_managers.main import str_to_bool
|
||||
from litellm.types.integrations.slack_alerting import DEFAULT_ALERT_TYPES
|
||||
from litellm.types.utils import CallTypes, LoggedLiteLLMParams
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
|
|
@ -208,7 +208,7 @@ async def vertex_proxy_route(
|
|||
request,
|
||||
fastapi_response,
|
||||
user_api_key_dict,
|
||||
stream=is_streaming_request,
|
||||
stream=is_streaming_request, # type: ignore
|
||||
)
|
||||
|
||||
return received_value
|
||||
|
|
|
@ -10,11 +10,10 @@ from litellm.llms.azure_ai.rerank import AzureAIRerank
|
|||
from litellm.llms.cohere.rerank import CohereRerank
|
||||
from litellm.llms.together_ai.rerank import TogetherAIRerank
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.types.rerank import RerankRequest, RerankResponse
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import client, exception_type, supports_httpx_timeout
|
||||
|
||||
from .types import RerankRequest, RerankResponse
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
# Initialize any necessary instances or variables here
|
||||
cohere_rerank = CohereRerank()
|
||||
|
|
|
@ -7,6 +7,7 @@ import httpx
|
|||
import openai
|
||||
|
||||
import litellm
|
||||
from litellm import get_secret, get_secret_str
|
||||
from litellm._logging import verbose_router_logger
|
||||
from litellm.llms.AzureOpenAI.azure import get_azure_ad_token_from_oidc
|
||||
from litellm.secret_managers.get_azure_ad_token_provider import (
|
||||
|
@ -111,17 +112,17 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
api_key = litellm_params.get("api_key") or default_api_key
|
||||
if api_key and isinstance(api_key, str) and api_key.startswith("os.environ/"):
|
||||
api_key_env_name = api_key.replace("os.environ/", "")
|
||||
api_key = litellm.get_secret(api_key_env_name)
|
||||
api_key = get_secret_str(api_key_env_name)
|
||||
litellm_params["api_key"] = api_key
|
||||
|
||||
api_base = litellm_params.get("api_base")
|
||||
base_url = litellm_params.get("base_url")
|
||||
base_url: Optional[str] = litellm_params.get("base_url")
|
||||
api_base = (
|
||||
api_base or base_url or default_api_base
|
||||
) # allow users to pass in `api_base` or `base_url` for azure
|
||||
if api_base and api_base.startswith("os.environ/"):
|
||||
api_base_env_name = api_base.replace("os.environ/", "")
|
||||
api_base = litellm.get_secret(api_base_env_name)
|
||||
api_base = get_secret_str(api_base_env_name)
|
||||
litellm_params["api_base"] = api_base
|
||||
|
||||
## AZURE AI STUDIO MISTRAL CHECK ##
|
||||
|
@ -147,33 +148,37 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
api_version = litellm_params.get("api_version")
|
||||
if api_version and api_version.startswith("os.environ/"):
|
||||
api_version_env_name = api_version.replace("os.environ/", "")
|
||||
api_version = litellm.get_secret(api_version_env_name)
|
||||
api_version = get_secret_str(api_version_env_name)
|
||||
litellm_params["api_version"] = api_version
|
||||
|
||||
timeout = litellm_params.pop("timeout", None) or litellm.request_timeout
|
||||
timeout: Optional[float] = (
|
||||
litellm_params.pop("timeout", None) or litellm.request_timeout
|
||||
)
|
||||
if isinstance(timeout, str) and timeout.startswith("os.environ/"):
|
||||
timeout_env_name = timeout.replace("os.environ/", "")
|
||||
timeout = litellm.get_secret(timeout_env_name)
|
||||
timeout = get_secret(timeout_env_name) # type: ignore
|
||||
litellm_params["timeout"] = timeout
|
||||
|
||||
stream_timeout = litellm_params.pop(
|
||||
stream_timeout: Optional[float] = litellm_params.pop(
|
||||
"stream_timeout", timeout
|
||||
) # if no stream_timeout is set, default to timeout
|
||||
if isinstance(stream_timeout, str) and stream_timeout.startswith("os.environ/"):
|
||||
stream_timeout_env_name = stream_timeout.replace("os.environ/", "")
|
||||
stream_timeout = litellm.get_secret(stream_timeout_env_name)
|
||||
stream_timeout = get_secret(stream_timeout_env_name) # type: ignore
|
||||
litellm_params["stream_timeout"] = stream_timeout
|
||||
|
||||
max_retries = litellm_params.pop("max_retries", 0) # router handles retry logic
|
||||
max_retries: Optional[int] = litellm_params.pop(
|
||||
"max_retries", 0
|
||||
) # router handles retry logic
|
||||
if isinstance(max_retries, str) and max_retries.startswith("os.environ/"):
|
||||
max_retries_env_name = max_retries.replace("os.environ/", "")
|
||||
max_retries = litellm.get_secret(max_retries_env_name)
|
||||
max_retries = get_secret(max_retries_env_name) # type: ignore
|
||||
litellm_params["max_retries"] = max_retries
|
||||
|
||||
organization = litellm_params.get("organization", None)
|
||||
if isinstance(organization, str) and organization.startswith("os.environ/"):
|
||||
organization_env_name = organization.replace("os.environ/", "")
|
||||
organization = litellm.get_secret(organization_env_name)
|
||||
organization = get_secret_str(organization_env_name)
|
||||
litellm_params["organization"] = organization
|
||||
azure_ad_token_provider: Optional[Callable[[], str]] = None
|
||||
if litellm_params.get("tenant_id"):
|
||||
|
@ -227,8 +232,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -253,8 +258,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -276,8 +281,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -302,8 +307,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
azure_ad_token_provider=azure_ad_token_provider,
|
||||
base_url=api_base,
|
||||
api_version=api_version,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -350,8 +355,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
cache_key = f"{model_id}_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -371,8 +376,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
cache_key = f"{model_id}_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -391,8 +396,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
cache_key = f"{model_id}_stream_async_client"
|
||||
_client = openai.AsyncAzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -413,8 +418,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
cache_key = f"{model_id}_stream_client"
|
||||
_client = openai.AzureOpenAI( # type: ignore
|
||||
**azure_client_params,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
max_connections=1000, max_keepalive_connections=100
|
||||
|
@ -441,8 +446,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AsyncOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
|
@ -465,8 +470,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.OpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
|
@ -487,8 +492,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.AsyncOpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.AsyncClient(
|
||||
limits=httpx.Limits(
|
||||
|
@ -512,8 +517,8 @@ def set_client(litellm_router_instance: LitellmRouter, model: dict):
|
|||
_client = openai.OpenAI( # type: ignore
|
||||
api_key=api_key,
|
||||
base_url=api_base,
|
||||
timeout=stream_timeout,
|
||||
max_retries=max_retries,
|
||||
timeout=stream_timeout, # type: ignore
|
||||
max_retries=max_retries, # type: ignore
|
||||
organization=organization,
|
||||
http_client=httpx.Client(
|
||||
limits=httpx.Limits(
|
||||
|
@ -542,20 +547,29 @@ def get_azure_ad_token_from_entrata_id(
|
|||
verbose_router_logger.debug("Getting Azure AD Token from Entrata ID")
|
||||
|
||||
if tenant_id.startswith("os.environ/"):
|
||||
tenant_id = litellm.get_secret(tenant_id)
|
||||
_tenant_id = get_secret_str(tenant_id)
|
||||
else:
|
||||
_tenant_id = tenant_id
|
||||
|
||||
if client_id.startswith("os.environ/"):
|
||||
client_id = litellm.get_secret(client_id)
|
||||
_client_id = get_secret_str(client_id)
|
||||
else:
|
||||
_client_id = client_id
|
||||
|
||||
if client_secret.startswith("os.environ/"):
|
||||
client_secret = litellm.get_secret(client_secret)
|
||||
_client_secret = get_secret_str(client_secret)
|
||||
else:
|
||||
_client_secret = client_secret
|
||||
|
||||
verbose_router_logger.debug(
|
||||
"tenant_id %s, client_id %s, client_secret %s",
|
||||
tenant_id,
|
||||
client_id,
|
||||
client_secret,
|
||||
_tenant_id,
|
||||
_client_id,
|
||||
_client_secret,
|
||||
)
|
||||
credential = ClientSecretCredential(tenant_id, client_id, client_secret)
|
||||
if _tenant_id is None or _client_id is None or _client_secret is None:
|
||||
raise ValueError("tenant_id, client_id, and client_secret must be provided")
|
||||
credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)
|
||||
|
||||
verbose_router_logger.debug("credential %s", credential)
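The rewrite above resolves each "os.environ/..." reference into its own _-prefixed variable and validates all three before building the credential, which is what lets the type checker prove ClientSecretCredential never receives None. A condensed sketch of that flow (resolve_setting and the environment variable names are illustrative):

import os
from typing import Optional

from azure.identity import ClientSecretCredential

def resolve_setting(value: str) -> Optional[str]:
    # "os.environ/NAME" means: read NAME from the environment.
    if value.startswith("os.environ/"):
        return os.getenv(value.replace("os.environ/", ""))
    return value

_tenant_id = resolve_setting("os.environ/AZURE_TENANT_ID")
_client_id = resolve_setting("os.environ/AZURE_CLIENT_ID")
_client_secret = resolve_setting("os.environ/AZURE_CLIENT_SECRET")
if _tenant_id is None or _client_id is None or _client_secret is None:
    raise ValueError("tenant_id, client_id, and client_secret must be provided")
credential = ClientSecretCredential(_tenant_id, _client_id, _client_secret)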
|
||||
|
||||
|
|
|
@ -2,6 +2,8 @@ import asyncio
|
|||
import traceback
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from litellm.types.integrations.slack_alerting import AlertType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm.router import Router as _Router
|
||||
|
||||
|
@ -49,5 +51,6 @@ async def send_llm_exception_alert(
|
|||
await litellm_router_instance.slack_alerting_logger.send_alert(
|
||||
message=f"LLM API call failed: `{exception_str}`",
|
||||
level="High",
|
||||
alert_type="llm_exceptions",
|
||||
alert_type=AlertType.llm_exceptions,
|
||||
alerting_metadata={},
|
||||
)
|
||||
|
|
|
@ -792,7 +792,7 @@ class EmbeddingResponse(OpenAIObject):
|
|||
model: Optional[str] = None
|
||||
"""The model used for embedding."""
|
||||
|
||||
data: Optional[List] = None
|
||||
data: List
|
||||
"""The actual embedding value"""
|
||||
|
||||
object: Literal["list"]
|
||||
|
@ -803,6 +803,7 @@ class EmbeddingResponse(OpenAIObject):
|
|||
|
||||
_hidden_params: dict = {}
|
||||
_response_headers: Optional[Dict] = None
|
||||
_response_ms: Optional[float] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -822,7 +823,7 @@ class EmbeddingResponse(OpenAIObject):
|
|||
if data:
|
||||
data = data
|
||||
else:
|
||||
data = None
|
||||
data = []
|
||||
|
||||
if usage:
|
||||
usage = usage
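Making data a required List (and defaulting it to an empty list in __init__) removes a whole class of Optional checks for callers. As a general illustration of the pattern, not the actual class definition:

from typing import List
from pydantic import BaseModel

class ExampleResponse(BaseModel):
    data: List = []  # always a list, so len(response.data) needs no None-check

resp = ExampleResponse()
assert resp.data == []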
|
||||
|
|
272 litellm/utils.py
|
@ -322,6 +322,9 @@ def function_setup(
|
|||
original_function: str, rules_obj, start_time, *args, **kwargs
|
||||
): # just run once to check if user wants to send their data anywhere - PostHog/Sentry/Slack/etc.
|
||||
### NOTICES ###
|
||||
from litellm import Logging as LiteLLMLogging
|
||||
from litellm.litellm_core_utils.litellm_logging import set_callbacks
|
||||
|
||||
if litellm.set_verbose is True:
|
||||
verbose_logger.warning(
|
||||
"`litellm.set_verbose` is deprecated. Please set `os.environ['LITELLM_LOG'] = 'DEBUG'` for debug logs."
|
||||
|
@ -333,7 +336,7 @@ def function_setup(
|
|||
custom_llm_setup()
|
||||
|
||||
## LOGGING SETUP
|
||||
function_id = kwargs["id"] if "id" in kwargs else None
|
||||
function_id: Optional[str] = kwargs["id"] if "id" in kwargs else None
|
||||
|
||||
if len(litellm.callbacks) > 0:
|
||||
for callback in litellm.callbacks:
|
||||
|
@ -375,9 +378,7 @@ def function_setup(
|
|||
+ litellm.failure_callback
|
||||
)
|
||||
)
|
||||
litellm.litellm_core_utils.litellm_logging.set_callbacks(
|
||||
callback_list=callback_list, function_id=function_id
|
||||
)
|
||||
set_callbacks(callback_list=callback_list, function_id=function_id)
|
||||
## ASYNC CALLBACKS
|
||||
if len(litellm.input_callback) > 0:
|
||||
removed_async_items = []
|
||||
|
@ -560,12 +561,12 @@ def function_setup(
|
|||
else:
|
||||
messages = "default-message-value"
|
||||
stream = True if "stream" in kwargs and kwargs["stream"] is True else False
|
||||
logging_obj = litellm.litellm_core_utils.litellm_logging.Logging(
|
||||
logging_obj = LiteLLMLogging(
|
||||
model=model,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
litellm_call_id=kwargs["litellm_call_id"],
|
||||
function_id=function_id,
|
||||
function_id=function_id or "",
|
||||
call_type=call_type,
|
||||
start_time=start_time,
|
||||
dynamic_success_callbacks=dynamic_success_callbacks,
|
||||
|
@ -655,10 +656,8 @@ def client(original_function):
|
|||
json_response_format = optional_params[
|
||||
"response_format"
|
||||
]
|
||||
elif (
|
||||
_parsing._completions.is_basemodel_type(
|
||||
optional_params["response_format"]
|
||||
)
|
||||
elif _parsing._completions.is_basemodel_type(
|
||||
optional_params["response_format"] # type: ignore
|
||||
):
|
||||
json_response_format = (
|
||||
type_to_response_format_param(
|
||||
|
@ -827,6 +826,7 @@ def client(original_function):
|
|||
print_verbose("INSIDE CHECKING CACHE")
|
||||
if (
|
||||
litellm.cache is not None
|
||||
and litellm.cache.supported_call_types is not None
|
||||
and str(original_function.__name__)
|
||||
in litellm.cache.supported_call_types
|
||||
):
|
||||
|
@ -879,7 +879,7 @@ def client(original_function):
|
|||
dynamic_api_key,
|
||||
api_base,
|
||||
) = litellm.get_llm_provider(
|
||||
model=model,
|
||||
model=model or "",
|
||||
custom_llm_provider=kwargs.get(
|
||||
"custom_llm_provider", None
|
||||
),
|
||||
|
@ -949,6 +949,8 @@ def client(original_function):
|
|||
base_model=base_model,
|
||||
messages=messages,
|
||||
user_max_tokens=user_max_tokens,
|
||||
buffer_num=None,
|
||||
buffer_perc=None,
|
||||
)
|
||||
kwargs["max_tokens"] = modified_max_tokens
|
||||
except Exception as e:
|
||||
|
@ -990,6 +992,7 @@ def client(original_function):
|
|||
# [OPTIONAL] ADD TO CACHE
|
||||
if (
|
||||
litellm.cache is not None
|
||||
and litellm.cache.supported_call_types is not None
|
||||
and str(original_function.__name__)
|
||||
in litellm.cache.supported_call_types
|
||||
) and (kwargs.get("cache", {}).get("no-store", False) is not True):
|
||||
|
@ -1006,7 +1009,7 @@ def client(original_function):
|
|||
"id", None
|
||||
)
|
||||
result._hidden_params["api_base"] = get_api_base(
|
||||
model=model,
|
||||
model=model or "",
|
||||
optional_params=getattr(logging_obj, "optional_params", {}),
|
||||
)
|
||||
result._hidden_params["response_cost"] = (
|
||||
|
@ -1053,7 +1056,7 @@ def client(original_function):
|
|||
and not _is_litellm_router_call
|
||||
):
|
||||
if len(args) > 0:
|
||||
args[0] = context_window_fallback_dict[model]
|
||||
args[0] = context_window_fallback_dict[model] # type: ignore
|
||||
else:
|
||||
kwargs["model"] = context_window_fallback_dict[model]
|
||||
return original_function(*args, **kwargs)
|
||||
|
@ -1065,12 +1068,6 @@ def client(original_function):
|
|||
logging_obj.failure_handler(
|
||||
e, traceback_exception, start_time, end_time
|
||||
) # DO NOT MAKE THREADED - router retry fallback relies on this!
|
||||
if hasattr(e, "message"):
|
||||
if (
|
||||
liteDebuggerClient
|
||||
and liteDebuggerClient.dashboard_url is not None
|
||||
): # make it easy to get to the debugger logs if you've initialized it
|
||||
e.message += f"\n Check the log in your dashboard - {liteDebuggerClient.dashboard_url}"
|
||||
raise e
|
||||
|
||||
@wraps(original_function)
|
||||
|
@ -1126,6 +1123,7 @@ def client(original_function):
|
|||
print_verbose("INSIDE CHECKING CACHE")
|
||||
if (
|
||||
litellm.cache is not None
|
||||
and litellm.cache.supported_call_types is not None
|
||||
and str(original_function.__name__)
|
||||
in litellm.cache.supported_call_types
|
||||
):
|
||||
|
@ -1287,7 +1285,11 @@ def client(original_function):
|
|||
args=(cached_result, start_time, end_time, cache_hit),
|
||||
).start()
|
||||
cache_key = kwargs.get("preset_cache_key", None)
|
||||
cached_result._hidden_params["cache_key"] = cache_key
|
||||
if (
|
||||
isinstance(cached_result, BaseModel)
|
||||
or isinstance(cached_result, CustomStreamWrapper)
|
||||
) and hasattr(cached_result, "_hidden_params"):
|
||||
cached_result._hidden_params["cache_key"] = cache_key # type: ignore
|
||||
return cached_result
|
||||
elif (
|
||||
call_type == CallTypes.aembedding.value
|
||||
|
@ -1447,6 +1449,7 @@ def client(original_function):
|
|||
# [OPTIONAL] ADD TO CACHE
|
||||
if (
|
||||
(litellm.cache is not None)
|
||||
and litellm.cache.supported_call_types is not None
|
||||
and (
|
||||
str(original_function.__name__)
|
||||
in litellm.cache.supported_call_types
|
||||
|
@ -1504,11 +1507,12 @@ def client(original_function):
|
|||
if (
|
||||
isinstance(result, EmbeddingResponse)
|
||||
and final_embedding_cached_response is not None
|
||||
and final_embedding_cached_response.data is not None
|
||||
):
|
||||
idx = 0
|
||||
final_data_list = []
|
||||
for item in final_embedding_cached_response.data:
|
||||
if item is None:
|
||||
if item is None and result.data is not None:
|
||||
final_data_list.append(result.data[idx])
|
||||
idx += 1
|
||||
else:
|
||||
|
@ -1575,7 +1579,7 @@ def client(original_function):
|
|||
and model in context_window_fallback_dict
|
||||
):
|
||||
if len(args) > 0:
|
||||
args[0] = context_window_fallback_dict[model]
|
||||
args[0] = context_window_fallback_dict[model] # type: ignore
|
||||
else:
|
||||
kwargs["model"] = context_window_fallback_dict[model]
|
||||
return await original_function(*args, **kwargs)
|
||||
|
@ -2945,13 +2949,19 @@ def get_optional_params(
|
|||
response_format=non_default_params["response_format"]
|
||||
)
|
||||
# # clean out 'additionalProperties = False'. Causes vertexai/gemini OpenAI API Schema errors - https://github.com/langchain-ai/langchainjs/issues/5240
|
||||
if non_default_params["response_format"].get("json_schema", {}).get(
|
||||
"schema"
|
||||
) is not None and custom_llm_provider in [
|
||||
if (
|
||||
non_default_params["response_format"] is not None
|
||||
and non_default_params["response_format"]
|
||||
.get("json_schema", {})
|
||||
.get("schema")
|
||||
is not None
|
||||
and custom_llm_provider
|
||||
in [
|
||||
"gemini",
|
||||
"vertex_ai",
|
||||
"vertex_ai_beta",
|
||||
]:
|
||||
]
|
||||
):
|
||||
old_schema = copy.deepcopy(
|
||||
non_default_params["response_format"]
|
||||
.get("json_schema", {})
|
||||
|
@ -3754,7 +3764,11 @@ def get_optional_params(
|
|||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
drop_params=(
|
||||
drop_params
|
||||
if drop_params is not None and isinstance(drop_params, bool)
|
||||
else False
|
||||
),
|
||||
)
|
||||
elif custom_llm_provider == "openrouter":
|
||||
supported_params = get_supported_openai_params(
|
||||
|
@ -3863,7 +3877,11 @@ def get_optional_params(
|
|||
non_default_params=non_default_params,
|
||||
optional_params=optional_params,
|
||||
model=model,
|
||||
drop_params=drop_params,
|
||||
drop_params=(
|
||||
drop_params
|
||||
if drop_params is not None and isinstance(drop_params, bool)
|
||||
else False
|
||||
),
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
supported_params = get_supported_openai_params(
|
||||
|
@ -4889,7 +4907,7 @@ def get_model_info(model: str, custom_llm_provider: Optional[str] = None) -> Mod
|
|||
try:
|
||||
split_model, custom_llm_provider, _, _ = get_llm_provider(model=model)
|
||||
except Exception:
|
||||
pass
|
||||
split_model = model
|
||||
combined_model_name = model
|
||||
stripped_model_name = _strip_model_name(model=model)
|
||||
combined_stripped_model_name = stripped_model_name
|
||||
|
@ -5865,6 +5883,8 @@ def convert_to_model_response_object(
|
|||
for idx, choice in enumerate(response_object["choices"]):
|
||||
## HANDLE JSON MODE - anthropic returns single function call]
|
||||
tool_calls = choice["message"].get("tool_calls", None)
|
||||
message: Optional[Message] = None
|
||||
finish_reason: Optional[str] = None
|
||||
if (
|
||||
convert_tool_call_to_json_mode
|
||||
and tool_calls is not None
|
||||
|
@ -5877,7 +5897,7 @@ def convert_to_model_response_object(
|
|||
if json_mode_content_str is not None:
|
||||
message = litellm.Message(content=json_mode_content_str)
|
||||
finish_reason = "stop"
|
||||
else:
|
||||
if message is None:
|
||||
message = Message(
|
||||
content=choice["message"].get("content", None),
|
||||
role=choice["message"]["role"] or "assistant",
|
||||
|
@ -6066,7 +6086,7 @@ def valid_model(model):
|
|||
model in litellm.open_ai_chat_completion_models
|
||||
or model in litellm.open_ai_text_completion_models
|
||||
):
|
||||
openai.Model.retrieve(model)
|
||||
openai.models.retrieve(model)
|
||||
else:
|
||||
messages = [{"role": "user", "content": "Hello World"}]
|
||||
litellm.completion(model=model, messages=messages)
|
||||
|
@ -6386,8 +6406,8 @@ class CustomStreamWrapper:
|
|||
self,
|
||||
completion_stream,
|
||||
model,
|
||||
custom_llm_provider=None,
|
||||
logging_obj=None,
|
||||
logging_obj: Any,
|
||||
custom_llm_provider: Optional[str] = None,
|
||||
stream_options=None,
|
||||
make_call: Optional[Callable] = None,
|
||||
_response_headers: Optional[dict] = None,
|
||||
|
@ -6633,36 +6653,6 @@ class CustomStreamWrapper:
|
|||
"completion_tokens": completion_tokens,
|
||||
}
|
||||
|
||||
def handle_together_ai_chunk(self, chunk):
|
||||
chunk = chunk.decode("utf-8")
|
||||
text = ""
|
||||
is_finished = False
|
||||
finish_reason = None
|
||||
if "text" in chunk:
|
||||
text_index = chunk.find('"text":"') # this checks if text: exists
|
||||
text_start = text_index + len('"text":"')
|
||||
text_end = chunk.find('"}', text_start)
|
||||
if text_index != -1 and text_end != -1:
|
||||
extracted_text = chunk[text_start:text_end]
|
||||
text = extracted_text
|
||||
return {
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
elif "[DONE]" in chunk:
|
||||
return {"text": text, "is_finished": True, "finish_reason": "stop"}
|
||||
elif "error" in chunk:
|
||||
raise litellm.together_ai.TogetherAIError(
|
||||
status_code=422, message=f"{str(chunk)}"
|
||||
)
|
||||
else:
|
||||
return {
|
||||
"text": text,
|
||||
"is_finished": is_finished,
|
||||
"finish_reason": finish_reason,
|
||||
}
|
||||
|
||||
def handle_predibase_chunk(self, chunk):
|
||||
try:
|
||||
if not isinstance(chunk, str):
|
||||
|
@ -7264,11 +7254,16 @@ class CustomStreamWrapper:
|
|||
try:
|
||||
if isinstance(chunk, dict):
|
||||
parsed_response = chunk
|
||||
if isinstance(chunk, (str, bytes)):
|
||||
elif isinstance(chunk, (str, bytes)):
|
||||
if isinstance(chunk, bytes):
|
||||
parsed_response = chunk.decode("utf-8")
|
||||
else:
|
||||
parsed_response = chunk
|
||||
else:
|
||||
raise ValueError("Unable to parse streaming chunk")
|
||||
if isinstance(parsed_response, dict):
|
||||
data_json = parsed_response
|
||||
else:
|
||||
data_json = json.loads(parsed_response)
|
||||
text = (
|
||||
data_json.get("outputs", "")[0]
|
||||
|
@ -7331,8 +7326,7 @@ class CustomStreamWrapper:
|
|||
|
||||
if (
|
||||
len(model_response.choices) > 0
|
||||
and hasattr(model_response.choices[0], "delta")
|
||||
and model_response.choices[0].delta is not None
|
||||
and getattr(model_response.choices[0], "delta") is not None
|
||||
):
|
||||
# do nothing, if object instantiated
|
||||
pass
|
||||
|
@ -7350,7 +7344,7 @@ class CustomStreamWrapper:
|
|||
is_empty = False
|
||||
return is_empty
|
||||
|
||||
def chunk_creator(self, chunk):
|
||||
def chunk_creator(self, chunk): # type: ignore
|
||||
model_response = self.model_response_creator()
|
||||
response_obj = {}
|
||||
try:
|
||||
|
@ -7422,11 +7416,6 @@ class CustomStreamWrapper:
|
|||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider and self.custom_llm_provider == "together_ai":
|
||||
response_obj = self.handle_together_ai_chunk(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider and self.custom_llm_provider == "huggingface":
|
||||
response_obj = self.handle_huggingface_chunk(chunk)
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
|
@ -7475,51 +7464,6 @@ class CustomStreamWrapper:
|
|||
if self.sent_first_chunk is False:
|
||||
raise Exception("An unknown error occurred with the stream")
|
||||
self.received_finish_reason = "stop"
|
||||
elif self.custom_llm_provider == "gemini":
|
||||
if hasattr(chunk, "parts") is True:
|
||||
try:
|
||||
if len(chunk.parts) > 0:
|
||||
completion_obj["content"] = chunk.parts[0].text
|
||||
if len(chunk.parts) > 0 and hasattr(
|
||||
chunk.parts[0], "finish_reason"
|
||||
):
|
||||
self.received_finish_reason = chunk.parts[
|
||||
0
|
||||
].finish_reason.name
|
||||
except Exception:
|
||||
if chunk.parts[0].finish_reason.name == "SAFETY":
|
||||
raise Exception(
|
||||
f"The response was blocked by VertexAI. {str(chunk)}"
|
||||
)
|
||||
else:
|
||||
completion_obj["content"] = str(chunk)
|
||||
elif self.custom_llm_provider and (
|
||||
self.custom_llm_provider == "vertex_ai_beta"
|
||||
):
|
||||
from litellm.types.utils import (
|
||||
GenericStreamingChunk as UtilsStreamingChunk,
|
||||
)
|
||||
|
||||
if self.received_finish_reason is not None:
|
||||
raise StopIteration
|
||||
response_obj: UtilsStreamingChunk = chunk
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
|
||||
if (
|
||||
self.stream_options
|
||||
and self.stream_options.get("include_usage", False) is True
|
||||
and response_obj["usage"] is not None
|
||||
):
|
||||
model_response.usage = litellm.Usage(
|
||||
prompt_tokens=response_obj["usage"]["prompt_tokens"],
|
||||
completion_tokens=response_obj["usage"]["completion_tokens"],
|
||||
total_tokens=response_obj["usage"]["total_tokens"],
|
||||
)
|
||||
|
||||
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
|
||||
completion_obj["tool_calls"] = [response_obj["tool_use"]]
|
||||
elif self.custom_llm_provider and (self.custom_llm_provider == "vertex_ai"):
|
||||
import proto # type: ignore
|
||||
|
||||
|
@ -7624,53 +7568,7 @@ class CustomStreamWrapper:
|
|||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
elif self.custom_llm_provider == "bedrock":
|
||||
from litellm.types.llms.bedrock import GenericStreamingChunk
|
||||
|
||||
if self.received_finish_reason is not None:
|
||||
raise StopIteration
|
||||
response_obj: GenericStreamingChunk = chunk
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
|
||||
if (
|
||||
self.stream_options
|
||||
and self.stream_options.get("include_usage", False) is True
|
||||
and response_obj["usage"] is not None
|
||||
):
|
||||
model_response.usage = litellm.Usage(
|
||||
prompt_tokens=response_obj["usage"]["inputTokens"],
|
||||
completion_tokens=response_obj["usage"]["outputTokens"],
|
||||
total_tokens=response_obj["usage"]["totalTokens"],
|
||||
)
|
||||
|
||||
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
|
||||
completion_obj["tool_calls"] = [response_obj["tool_use"]]
|
||||
|
||||
elif self.custom_llm_provider == "sagemaker":
|
||||
from litellm.types.llms.bedrock import GenericStreamingChunk
|
||||
|
||||
if self.received_finish_reason is not None:
|
||||
raise StopIteration
|
||||
response_obj: GenericStreamingChunk = chunk
|
||||
completion_obj["content"] = response_obj["text"]
|
||||
if response_obj["is_finished"]:
|
||||
self.received_finish_reason = response_obj["finish_reason"]
|
||||
|
||||
if (
|
||||
self.stream_options
|
||||
and self.stream_options.get("include_usage", False) is True
|
||||
and response_obj["usage"] is not None
|
||||
):
|
||||
model_response.usage = litellm.Usage(
|
||||
prompt_tokens=response_obj["usage"]["inputTokens"],
|
||||
completion_tokens=response_obj["usage"]["outputTokens"],
|
||||
total_tokens=response_obj["usage"]["totalTokens"],
|
||||
)
|
||||
|
||||
if "tool_use" in response_obj and response_obj["tool_use"] is not None:
|
||||
completion_obj["tool_calls"] = [response_obj["tool_use"]]
|
||||
elif self.custom_llm_provider == "petals":
|
||||
if len(self.completion_stream) == 0:
|
||||
if self.received_finish_reason is not None:
|
||||
|
@@ -8181,9 +8079,11 @@ class CustomStreamWrapper:
                    target=self.run_success_logging_in_thread,
                    args=(response, cache_hit),
                ).start()  # log response
                self.response_uptil_now += (
                    response.choices[0].delta.get("content", "") or ""
                )
                choice = response.choices[0]
                if isinstance(choice, StreamingChoices):
                    self.response_uptil_now += choice.delta.get("content", "") or ""
                else:
                    self.response_uptil_now += ""
                self.rules.post_call_rules(
                    input=self.response_uptil_now, model=self.model
                )
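The repeated pattern in these hunks replaces direct `response.choices[0].delta` access with an `isinstance(choice, StreamingChoices)` narrowing. A minimal sketch of why this satisfies the type checker, assuming `choices` is typed as a union of streaming and non-streaming choice objects:

```python
from litellm.types.utils import ModelResponse, StreamingChoices

def get_delta_content(response: ModelResponse) -> str:
    # choices may hold Choices (non-streaming) or StreamingChoices; only the
    # latter has a .delta, so narrow the type before touching it.
    choice = response.choices[0]
    if isinstance(choice, StreamingChoices):
        return choice.delta.get("content", "") or ""
    return ""
```
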
@@ -8223,8 +8123,11 @@ class CustomStreamWrapper:
            )
            response = self.model_response_creator()
            if complete_streaming_response is not None:
                response.usage = complete_streaming_response.usage
                response._hidden_params["usage"] = complete_streaming_response.usage  # type: ignore
                setattr(
                    response,
                    "usage",
                    getattr(complete_streaming_response, "usage"),
                )
            ## LOGGING
            threading.Thread(
                target=self.logging_obj.success_handler,
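Several hunks swap a plain attribute assignment for a `setattr`/`getattr` pair. A minimal sketch of the apparent reasoning, assuming the attribute is not declared on the response class so a direct assignment trips the static checker:

```python
from typing import Any, Optional

class CompleteResponse:
    usage: Optional[Any] = None

class PartialResponse:
    # 'usage' is not declared here, so `response.usage = ...` would be flagged
    # by a strict type checker as assignment to an unknown attribute.
    pass

complete = CompleteResponse()
response = PartialResponse()

# setattr/getattr keeps runtime behaviour identical while sidestepping
# the static attribute check.
setattr(response, "usage", getattr(complete, "usage"))
```
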
@@ -8349,9 +8252,11 @@ class CustomStreamWrapper:
                            processed_chunk, cache_hit=cache_hit
                        )
                    )
                    self.response_uptil_now += (
                        processed_chunk.choices[0].delta.get("content", "") or ""
                    )
                    choice = processed_chunk.choices[0]
                    if isinstance(choice, StreamingChoices):
                        self.response_uptil_now += choice.delta.get("content", "") or ""
                    else:
                        self.response_uptil_now += ""
                    self.rules.post_call_rules(
                        input=self.response_uptil_now, model=self.model
                    )
@@ -8401,9 +8306,13 @@ class CustomStreamWrapper:
                        )
                    )

                    choice = processed_chunk.choices[0]
                    if isinstance(choice, StreamingChoices):
                        self.response_uptil_now += (
                            processed_chunk.choices[0].delta.get("content", "") or ""
                            choice.delta.get("content", "") or ""
                        )
                    else:
                        self.response_uptil_now += ""
                    self.rules.post_call_rules(
                        input=self.response_uptil_now, model=self.model
                    )
@@ -8423,7 +8332,11 @@ class CustomStreamWrapper:
                )
                response = self.model_response_creator()
                if complete_streaming_response is not None:
                    setattr(response, "usage", complete_streaming_response.usage)
                    setattr(
                        response,
                        "usage",
                        getattr(complete_streaming_response, "usage"),
                    )
                ## LOGGING
                threading.Thread(
                    target=self.logging_obj.success_handler,
@@ -8464,7 +8377,11 @@ class CustomStreamWrapper:
                )
                response = self.model_response_creator()
                if complete_streaming_response is not None:
                    response.usage = complete_streaming_response.usage
                    setattr(
                        response,
                        "usage",
                        getattr(complete_streaming_response, "usage"),
                    )
                ## LOGGING
                threading.Thread(
                    target=self.logging_obj.success_handler,
@@ -8898,7 +8815,7 @@ def trim_messages(
        if len(tool_messages):
            messages = messages[: -len(tool_messages)]

        current_tokens = token_counter(model=model, messages=messages)
        current_tokens = token_counter(model=model or "", messages=messages)
        print_verbose(f"Current tokens: {current_tokens}, max tokens: {max_tokens}")

        # Do nothing if current tokens under messages
@@ -8909,6 +8826,7 @@ def trim_messages(
        print_verbose(
            f"Need to trim input messages: {messages}, current_tokens{current_tokens}, max_tokens: {max_tokens}"
        )
        system_message_event: Optional[dict] = None
        if system_message:
            system_message_event, max_tokens = process_system_message(
                system_message=system_message, max_tokens=max_tokens, model=model
@@ -8926,7 +8844,7 @@ def trim_messages(
            )

        # Add system message to the beginning of the final messages
        if system_message:
        if system_message_event:
            final_messages = [system_message_event] + final_messages

        if len(tool_messages) > 0:
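These trim_messages changes only tighten typing (a fallback model string for token_counter and an explicitly Optional system_message_event); the behaviour is unchanged. A minimal usage sketch, assuming the public signature of messages plus optional model and max_tokens:

```python
from litellm.utils import trim_messages

messages = [
    {"role": "system", "content": "You are a terse assistant."},
    {"role": "user", "content": "Summarize the plot of a very long novel..."},
]

# Trim the conversation to fit under an assumed token budget; the system
# message is preserved and re-attached at the front of the trimmed list.
trimmed = trim_messages(messages, model="gpt-3.5-turbo", max_tokens=2000)
print(len(trimmed))
```
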
@@ -9214,6 +9132,8 @@ def is_cached_message(message: AllMessageValues) -> bool:

    Follows the anthropic format {"cache_control": {"type": "ephemeral"}}
    """
    if "content" not in message:
        return False
    if message["content"] is None or isinstance(message["content"], str):
        return False
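For context, a message that is_cached_message would accept needs list-style content carrying Anthropic's cache_control marker, as the docstring describes. A minimal sketch of such a message (the text values are illustrative):

```python
cached_message = {
    "role": "user",
    "content": [
        {
            "type": "text",
            "text": "A long, reusable system prompt or document goes here...",
            "cache_control": {"type": "ephemeral"},  # anthropic prompt-caching marker
        }
    ],
}

plain_message = {"role": "user", "content": "hi"}  # string content, so not cached
```
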

@@ -1,6 +1,7 @@
{
    "ignore": [],
    "exclude": ["**/node_modules", "**/__pycache__", "litellm/tests", "litellm/main.py", "litellm/utils.py", "litellm/types/utils.py"],
    "reportMissingImports": false
    "exclude": ["**/node_modules", "**/__pycache__", "litellm/types/utils.py"],
    "reportMissingImports": false,
    "reportPrivateImportUsage": false
}

@@ -14,7 +14,7 @@ from typing import Optional

import httpx

from litellm.integrations.SlackAlerting.types import AlertType
from litellm.types.integrations.slack_alerting import AlertType

# import logging
# logging.basicConfig(level=logging.DEBUG)

@@ -1188,13 +1188,36 @@ def test_completion_cost_anthropic_prompt_caching():
        system_fingerprint=None,
        usage=Usage(
            completion_tokens=10,
            prompt_tokens=14,
            total_tokens=24,
            prompt_tokens=114,
            total_tokens=124,
            prompt_tokens_details=PromptTokensDetails(cached_tokens=0),
            cache_creation_input_tokens=100,
            cache_read_input_tokens=0,
        ),
    )

    cost_1 = completion_cost(model=model, completion_response=response_1)

    _model_info = litellm.get_model_info(
        model="claude-3-5-sonnet-20240620", custom_llm_provider="anthropic"
    )
    expected_cost = (
        (
            response_1.usage.prompt_tokens
            - response_1.usage.prompt_tokens_details.cached_tokens
        )
        * _model_info["input_cost_per_token"]
        + response_1.usage.prompt_tokens_details.cached_tokens
        * _model_info["cache_read_input_token_cost"]
        + response_1.usage.cache_creation_input_tokens
        * _model_info["cache_creation_input_token_cost"]
        + response_1.usage.completion_tokens * _model_info["output_cost_per_token"]
    )  # Cost of processing (non-cache hit + cache hit) + Cost of cache-writing (cache writing)

    assert round(expected_cost, 5) == round(cost_1, 5)

    print(f"expected_cost: {expected_cost}, cost_1: {cost_1}")

    ## READ FROM CACHE ## (LESS EXPENSIVE)
    response_2 = ModelResponse(
        id="chatcmpl-3f427194-0840-4d08-b571-56bfe38a5424",
@@ -1216,14 +1239,14 @@ def test_completion_cost_anthropic_prompt_caching():
        system_fingerprint=None,
        usage=Usage(
            completion_tokens=10,
            prompt_tokens=14,
            total_tokens=24,
            prompt_tokens=114,
            total_tokens=134,
            prompt_tokens_details=PromptTokensDetails(cached_tokens=100),
            cache_creation_input_tokens=0,
            cache_read_input_tokens=100,
        ),
    )

    cost_1 = completion_cost(model=model, completion_response=response_1)
    cost_2 = completion_cost(model=model, completion_response=response_2)

    assert cost_1 > cost_2

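To make the assertion concrete, here is a worked version of the same cost formula with illustrative per-token prices; the test itself reads the real values from litellm.get_model_info, so these numbers are assumptions:

```python
# Assumed illustrative prices (USD per token), not the authoritative model_info values.
input_cost = 3e-06
output_cost = 15e-06
cache_creation_cost = 3.75e-06
cache_read_cost = 0.3e-06

def prompt_caching_cost(prompt, cached, cache_created, completion):
    # non-cached input + cached input read + cache writes + output
    return (
        (prompt - cached) * input_cost
        + cached * cache_read_cost
        + cache_created * cache_creation_cost
        + completion * output_cost
    )

# response_1: cache write (100 tokens written, none read)
cost_1 = prompt_caching_cost(prompt=114, cached=0, cache_created=100, completion=10)
# response_2: cache read (100 tokens read, none written)
cost_2 = prompt_caching_cost(prompt=114, cached=100, cache_created=0, completion=10)

assert cost_1 > cost_2  # reading from cache is cheaper than writing to it
```
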
@@ -10,12 +10,40 @@ import litellm
import pytest


def _usage_format_tests(usage: litellm.Usage):
    """
    OpenAI prompt caching
    - prompt_tokens = sum of non-cache hit tokens + cache-hit tokens
    - total_tokens = prompt_tokens + completion_tokens

    Example
    ```
    "usage": {
        "prompt_tokens": 2006,
        "completion_tokens": 300,
        "total_tokens": 2306,
        "prompt_tokens_details": {
            "cached_tokens": 1920
        },
        "completion_tokens_details": {
            "reasoning_tokens": 0
        }
        # ANTHROPIC_ONLY #
        "cache_creation_input_tokens": 0
    }
    ```
    """
    assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens

    assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens


@pytest.mark.parametrize(
    "model",
    [
        "anthropic/claude-3-5-sonnet-20240620",
        "openai/gpt-4o",
        "deepseek/deepseek-chat",
        # "openai/gpt-4o",
        # "deepseek/deepseek-chat",
    ],
)
def test_prompt_caching_model(model):
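A minimal sketch of what _usage_format_tests checks, run against a hand-built Usage object shaped like the docstring example; the PromptTokensDetails import location is assumed from its use in the cost test above, and the token counts are illustrative:

```python
import litellm
from litellm.types.utils import PromptTokensDetails  # assumed import path

usage = litellm.Usage(
    prompt_tokens=2006,
    completion_tokens=300,
    total_tokens=2306,
    prompt_tokens_details=PromptTokensDetails(cached_tokens=1920),
)

# total = prompt + completion, and cached tokens are a strict subset of prompt tokens
assert usage.total_tokens == usage.prompt_tokens + usage.completion_tokens
assert usage.prompt_tokens > usage.prompt_tokens_details.cached_tokens
```
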
@@ -66,9 +94,13 @@ def test_prompt_caching_model(model):
        max_tokens=10,
    )

    _usage_format_tests(response.usage)

    print("response=", response)
    print("response.usage=", response.usage)

    _usage_format_tests(response.usage)

    assert "prompt_tokens_details" in response.usage
    assert response.usage.prompt_tokens_details.cached_tokens > 0