LiteLLM ruff linting enforcement (#5992)

* ci(config.yml): add a 'check_code_quality' step

Addresses https://github.com/BerriAI/litellm/issues/5991

* ci(config.yml): check why CircleCI doesn't pick up this test

* ci(config.yml): fix to run 'check_code_quality' tests

* fix(__init__.py): fix unprotected import
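  An "unprotected" import here means one executed unconditionally at module import time, so a missing optional dependency breaks `from litellm import *` (the exact check the new CI step runs). A minimal sketch of the guard pattern, using an illustrative package name rather than the actual import touched in litellm/__init__.py:

      try:
          import some_optional_dependency  # illustrative name, not the real dependency
      except ImportError:
          # Degrade gracefully so `from litellm import *` still succeeds
          # when the optional package isn't installed.
          some_optional_dependency = None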

* fix(__init__.py): don't remove unused imports

* build(ruff.toml): update ruff.toml to ignore unused imports

* fix: ruff + pyright - fix linting + type-checking errors
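  Several of the pyright complaints in this diff are resolved by declaring attribute types up front instead of letting them be inferred as `None`. A small illustrative sketch of that pattern (the class name is made up; the attribute mirrors the `self.headers` change in galileo.py further down):

      from typing import Dict, Optional

      class ExampleLogger:
          def __init__(self) -> None:
              # Declaring Optional[...] up front lets pyright accept a later
              # `self.headers = {...}` assignment without a type error.
              self.headers: Optional[Dict[str, str]] = None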

* fix: fix linting errors

* fix(lago.py): fix module init error

* fix: fix linting errors

* ci(config.yml): cd into correct dir for checks

* fix(proxy_server.py): fix linting error

* fix(utils.py): fix bare except

causes ruff linting errors
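  Ruff flags bare `except:` as E722. The fix applied throughout this diff is to catch `Exception` explicitly, which also avoids silently swallowing `SystemExit` and `KeyboardInterrupt`. A small before/after sketch (illustrative helper, not a function copied from utils.py):

      def to_str(value) -> str:
          try:
              return str(value)
          except Exception:  # was a bare `except:` before; flagged as E722
              # non-blocking if the value can't be cast to a str
              return ""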

* fix: ruff - fix remaining linting errors

* fix(clickhouse.py): use standard logging object
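  The clickhouse logger now reads the pre-built `standard_logging_object` from `kwargs` instead of rebuilding the payload with `get_logging_payload()`. A rough sketch of the new pattern, simplified from the hunk below (the helper name is illustrative, not the exact code in clickhouse.py):

      from typing import Optional

      from litellm.types.utils import StandardLoggingPayload

      def _get_standard_payload(kwargs: dict) -> Optional[StandardLoggingPayload]:
          # The payload is attached to kwargs upstream; if it's missing,
          # the logger simply skips the event.
          payload: Optional[StandardLoggingPayload] = kwargs.get("standard_logging_object")
          return payload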

* fix(__init__.py): fix unprotected import

* fix: ruff - fix linting errors

* fix: fix linting errors

* ci(config.yml): cleanup code qa step (formatting handled in local_testing)

* fix(_health_endpoints.py): fix ruff linting errors

* ci(config.yml): just use ruff in check_code_quality pipeline for now

* build(custom_guardrail.py): include missing file

* style(embedding_handler.py): fix ruff check
Krish Dholakia 2024-10-01 16:44:20 -07:00 committed by GitHub
parent 3fc4ae0d65
commit d57be47b0f
263 changed files with 1687 additions and 3320 deletions


@ -299,6 +299,27 @@ jobs:
ls
python -m pytest -vv tests/local_testing/test_python_38.py
check_code_quality:
docker:
- image: cimg/python:3.11
auth:
username: ${DOCKERHUB_USERNAME}
password: ${DOCKERHUB_PASSWORD}
working_directory: ~/project/litellm
steps:
- checkout
- run:
name: Install Dependencies
command: |
python -m pip install --upgrade pip
pip install ruff
pip install pylint
pip install .
- run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
- run: ruff check ./litellm
build_and_test:
machine:
image: ubuntu-2204:2023.10.1
@ -806,6 +827,12 @@ workflows:
only:
- main
- /litellm_.*/
- check_code_quality:
filters:
branches:
only:
- main
- /litellm_.*/
- ui_endpoint_testing:
filters:
branches:
@ -867,6 +894,7 @@ workflows:
- installing_litellm_on_python
- proxy_logging_guardrails_model_info_tests
- proxy_pass_through_endpoint_tests
- check_code_quality
filters:
branches:
only:


@ -630,18 +630,6 @@ general_settings:
"database_url": "string",
"database_connection_pool_limit": 0, # default 100
"database_connection_timeout": 0, # default 60s
"database_type": "dynamo_db",
"database_args": {
"billing_mode": "PROVISIONED_THROUGHPUT",
"read_capacity_units": 0,
"write_capacity_units": 0,
"ssl_verify": true,
"region_name": "string",
"user_table_name": "LiteLLM_UserTable",
"key_table_name": "LiteLLM_VerificationToken",
"config_table_name": "LiteLLM_Config",
"spend_table_name": "LiteLLM_SpendLogs"
},
"otel": true,
"custom_auth": "string",
"max_parallel_requests": 0, # the max parallel requests allowed per deployment


@ -97,7 +97,7 @@ class GenericAPILogger:
for key, value in payload.items():
try:
payload[key] = str(value)
except:
except Exception:
# non blocking if it can't cast to a str
pass


@ -49,7 +49,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
def __init__(self):
try:
from google.cloud import language_v1
except:
except Exception:
raise Exception(
"Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
)
@ -90,7 +90,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
except Exception:
pass
async def async_moderation_hook(


@ -58,7 +58,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
except Exception:
pass
def set_custom_prompt_template(self, messages: list):


@ -49,7 +49,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
verbose_proxy_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
except Exception:
pass
async def moderation_check(self, text: str):


@ -3,7 +3,8 @@ import warnings
warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
### INIT VARIABLES ###
import threading, requests, os
import threading
import os
from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
from litellm.caching import Cache
@ -308,13 +309,13 @@ def get_model_cost_map(url: str):
return content
try:
with requests.get(
response = httpx.get(
url, timeout=5
) as response: # set a 5 second timeout for the get request
) # set a 5 second timeout for the get request
response.raise_for_status() # Raise an exception if the request is unsuccessful
content = response.json()
return content
except Exception as e:
except Exception:
import importlib.resources
import json
@ -839,7 +840,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
from .timeout import timeout
from .cost_calculator import completion_cost
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.litellm_core_utils.litellm_logging import Logging, modify_integration
from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
from litellm.litellm_core_utils.core_helpers import remove_index_from_tool_calls
from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
@ -848,7 +849,6 @@ from .utils import (
exception_type,
get_optional_params,
get_response_string,
modify_integration,
token_counter,
create_pretrained_tokenizer,
create_tokenizer,


@ -98,5 +98,5 @@ def print_verbose(print_statement):
try:
if set_verbose:
print(print_statement) # noqa
except:
except Exception:
pass


@ -2,5 +2,5 @@ import importlib_metadata
try:
version = importlib_metadata.version("litellm")
except:
except Exception:
pass


@ -12,7 +12,6 @@ from openai.types.beta.assistant import Assistant
from openai.types.beta.assistant_deleted import AssistantDeleted
import litellm
from litellm import client
from litellm.llms.AzureOpenAI import assistants
from litellm.types.router import GenericLiteLLMParams
from litellm.utils import (
@ -96,7 +95,7 @@ def get_assistants(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -280,7 +279,7 @@ def create_assistants(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -464,7 +463,7 @@ def delete_assistant(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -649,7 +648,7 @@ def create_thread(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -805,7 +804,7 @@ def get_thread(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -991,7 +990,7 @@ def add_message(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -1149,7 +1148,7 @@ def get_messages(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -1347,7 +1346,7 @@ def run_thread(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout


@ -22,7 +22,7 @@ import litellm
from litellm import client
from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.llms.openai import (
Batch,
CancelBatchRequest,
@ -131,7 +131,7 @@ def create_batch(
extra_headers=extra_headers,
extra_body=extra_body,
)
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@ -165,27 +165,30 @@ def create_batch(
_is_async=_is_async,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.create_batch(
_is_async=_is_async,
@ -293,7 +296,7 @@ def retrieve_batch(
)
_is_async = kwargs.pop("aretrieve_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@ -327,27 +330,30 @@ def retrieve_batch(
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.retrieve_batch(
_is_async=_is_async,
@ -384,7 +390,7 @@ async def alist_batches(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Batch:
):
"""
Async: List your organization's batches.
"""
@ -482,27 +488,26 @@ def list_batches(
max_retries=optional_params.max_retries,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
) # type: ignore
or get_secret_str("AZURE_API_VERSION")
)
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
) # type: ignore
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
)
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.list_batches(
_is_async=_is_async,


@ -7,11 +7,16 @@
#
# Thank you users! We ❤️ you! - Krrish & Ishaan
import os, json, time
import json
import os
import threading
import time
from typing import Literal, Optional, Union
import requests # type: ignore
import litellm
from litellm.utils import ModelResponse
import requests, threading # type: ignore
from typing import Optional, Union, Literal
class BudgetManager:
@ -35,7 +40,7 @@ class BudgetManager:
import logging
logging.info(print_statement)
except:
except Exception:
pass
def load_data(self):
@ -52,7 +57,6 @@ class BudgetManager:
elif self.client_type == "hosted":
# Load the user_dict from hosted db
url = self.api_base + "/get_budget"
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name}
response = requests.post(url, headers=self.headers, json=data)
response = response.json()
@ -210,7 +214,6 @@ class BudgetManager:
return {"status": "success"}
elif self.client_type == "hosted":
url = self.api_base + "/set_budget"
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name, "user_dict": self.user_dict}
response = requests.post(url, headers=self.headers, json=data)
response = response.json()


@ -33,7 +33,7 @@ def print_verbose(print_statement):
verbose_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
except Exception:
pass
@ -96,15 +96,13 @@ class InMemoryCache(BaseCache):
"""
for key in list(self.ttl_dict.keys()):
if time.time() > self.ttl_dict[key]:
removed_item = self.cache_dict.pop(key, None)
removed_ttl_item = self.ttl_dict.pop(key, None)
self.cache_dict.pop(key, None)
self.ttl_dict.pop(key, None)
# de-reference the removed item
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
# This can occur when an object is referenced by another object, but the reference is never removed.
removed_item = None
removed_ttl_item = None
def set_cache(self, key, value, **kwargs):
print_verbose(
@ -150,7 +148,7 @@ class InMemoryCache(BaseCache):
original_cached_response = self.cache_dict[key]
try:
cached_response = json.loads(original_cached_response)
except:
except Exception:
cached_response = original_cached_response
return cached_response
return None
@ -251,7 +249,7 @@ class RedisCache(BaseCache):
self.redis_version = "Unknown"
try:
self.redis_version = self.redis_client.info()["redis_version"]
except Exception as e:
except Exception:
pass
### ASYNC HEALTH PING ###
@ -688,7 +686,7 @@ class RedisCache(BaseCache):
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
@ -844,7 +842,7 @@ class RedisCache(BaseCache):
"""
Tests if the sync redis client is correctly setup.
"""
print_verbose(f"Pinging Sync Redis Cache")
print_verbose("Pinging Sync Redis Cache")
start_time = time.time()
try:
response = self.redis_client.ping()
@ -878,7 +876,7 @@ class RedisCache(BaseCache):
_redis_client = self.init_async_client()
start_time = time.time()
async with _redis_client as redis_client:
print_verbose(f"Pinging Async Redis Cache")
print_verbose("Pinging Async Redis Cache")
try:
response = await redis_client.ping()
## LOGGING ##
@ -973,7 +971,6 @@ class RedisSemanticCache(BaseCache):
},
"fields": {
"text": [{"name": "response"}],
"text": [{"name": "prompt"}],
"vector": [
{
"name": "litellm_embedding",
@ -999,14 +996,14 @@ class RedisSemanticCache(BaseCache):
redis_url = "redis://:" + password + "@" + host + ":" + port
print_verbose(f"redis semantic-cache redis_url: {redis_url}")
if use_async == False:
if use_async is False:
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url)
try:
self.index.create(overwrite=False) # don't overwrite existing index
except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}")
elif use_async == True:
elif use_async is True:
schema["index"]["name"] = "litellm_semantic_cache_index_async"
self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url, use_async=True)
@ -1027,7 +1024,7 @@ class RedisSemanticCache(BaseCache):
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
@ -1060,7 +1057,7 @@ class RedisSemanticCache(BaseCache):
]
# Add more data
keys = self.index.load(new_data)
self.index.load(new_data)
return
@ -1092,7 +1089,7 @@ class RedisSemanticCache(BaseCache):
)
results = self.index.query(query)
if results == None:
if results is None:
return None
if isinstance(results, list):
if len(results) == 0:
@ -1173,7 +1170,7 @@ class RedisSemanticCache(BaseCache):
]
# Add more data
keys = await self.index.aload(new_data)
await self.index.aload(new_data)
return
async def async_get_cache(self, key, **kwargs):
@ -1222,7 +1219,7 @@ class RedisSemanticCache(BaseCache):
return_fields=["response", "prompt", "vector_distance"],
)
results = await self.index.aquery(query)
if results == None:
if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
@ -1396,7 +1393,7 @@ class QdrantSemanticCache(BaseCache):
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except:
except Exception:
cached_response = ast.literal_eval(cached_response)
return cached_response
@ -1435,7 +1432,7 @@ class QdrantSemanticCache(BaseCache):
},
]
}
keys = self.sync_client.put(
self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
@ -1481,7 +1478,7 @@ class QdrantSemanticCache(BaseCache):
)
results = search_response.json()["result"]
if results == None:
if results is None:
return None
if isinstance(results, list):
if len(results) == 0:
@ -1563,7 +1560,7 @@ class QdrantSemanticCache(BaseCache):
]
}
keys = await self.async_client.put(
await self.async_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers,
json=data,
@ -1629,7 +1626,7 @@ class QdrantSemanticCache(BaseCache):
results = search_response.json()["result"]
if results == None:
if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None
if isinstance(results, list):
@ -1767,7 +1764,7 @@ class S3Cache(BaseCache):
cached_response = json.loads(
cached_response
) # Convert string to dictionary
except Exception as e:
except Exception:
cached_response = ast.literal_eval(cached_response)
if type(cached_response) is not dict:
cached_response = dict(cached_response)
@ -1845,7 +1842,7 @@ class DualCache(BaseCache):
self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
if self.redis_cache is not None and local_only is False:
self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e:
print_verbose(e)
@ -1865,7 +1862,7 @@ class DualCache(BaseCache):
if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
if self.redis_cache is not None and local_only is False:
result = self.redis_cache.increment_cache(key, value, **kwargs)
return result
@ -1887,7 +1884,7 @@ class DualCache(BaseCache):
if (
(self.always_read_redis is True)
and self.redis_cache is not None
and local_only == False
and local_only is False
):
# If not found in in-memory cache or always_read_redis is True, try fetching from Redis
redis_result = self.redis_cache.get_cache(key, **kwargs)
@ -1900,7 +1897,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}")
return result
except Exception as e:
except Exception:
verbose_logger.error(traceback.format_exc())
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
@ -1913,7 +1910,7 @@ class DualCache(BaseCache):
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
if None in result and self.redis_cache is not None and local_only is False:
"""
- for the none values in the result
- check the redis cache
@ -1933,7 +1930,7 @@ class DualCache(BaseCache):
print_verbose(f"async batch get cache: cache result: {result}")
return result
except Exception as e:
except Exception:
verbose_logger.error(traceback.format_exc())
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
@ -1952,7 +1949,7 @@ class DualCache(BaseCache):
if in_memory_result is not None:
result = in_memory_result
if result is None and self.redis_cache is not None and local_only == False:
if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
@ -1966,7 +1963,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}")
return result
except Exception as e:
except Exception:
verbose_logger.error(traceback.format_exc())
async def async_batch_get_cache(
@ -1981,7 +1978,7 @@ class DualCache(BaseCache):
if in_memory_result is not None:
result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False:
if None in result and self.redis_cache is not None and local_only is False:
"""
- for the none values in the result
- check the redis cache
@ -2006,7 +2003,7 @@ class DualCache(BaseCache):
result[index] = value
return result
except Exception as e:
except Exception:
verbose_logger.error(traceback.format_exc())
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
@ -2017,7 +2014,7 @@ class DualCache(BaseCache):
if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False:
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache(key, value, **kwargs)
except Exception as e:
verbose_logger.exception(
@ -2039,7 +2036,7 @@ class DualCache(BaseCache):
cache_list=cache_list, **kwargs
)
if self.redis_cache is not None and local_only == False:
if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache_pipeline(
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
)
@ -2459,7 +2456,7 @@ class Cache:
cached_response = json.loads(
cached_response # type: ignore
) # Convert string to dictionary
except:
except Exception:
cached_response = ast.literal_eval(cached_response) # type: ignore
return cached_response
return cached_result
@ -2492,7 +2489,7 @@ class Cache:
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception as e:
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
@ -2506,7 +2503,7 @@ class Cache:
if self.should_use_cache(*args, **kwargs) is not True:
return
messages = kwargs.get("messages", [])
kwargs.get("messages", [])
if "cache_key" in kwargs:
cache_key = kwargs["cache_key"]
else:
@ -2522,7 +2519,7 @@ class Cache:
return self._get_cache_logic(
cached_result=cached_result, max_age=max_age
)
except Exception as e:
except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None
@ -2701,7 +2698,7 @@ class DiskCache(BaseCache):
if original_cached_response:
try:
cached_response = json.loads(original_cached_response) # type: ignore
except:
except Exception:
cached_response = original_cached_response
return cached_response
return None
@ -2803,7 +2800,7 @@ def enable_cache(
if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache")
if litellm.cache == None:
if litellm.cache is None:
litellm.cache = Cache(
type=type,
host=host,


@ -57,7 +57,7 @@
# config = yaml.safe_load(file)
# else:
# pass
# except:
# except Exception:
# pass
# ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')


@ -9,12 +9,12 @@ import asyncio
import contextvars
import os
from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
import httpx
import litellm
from litellm import client, get_secret
from litellm import client, get_secret_str
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI
from litellm.types.llms.openai import (
@ -39,7 +39,7 @@ async def afile_retrieve(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> Coroutine[Any, Any, FileObject]:
):
"""
Async: Get file contents
@ -66,7 +66,7 @@ async def afile_retrieve(
if asyncio.iscoroutine(init_response):
response = await init_response
else:
response = init_response # type: ignore
response = init_response
return response
except Exception as e:
@ -137,27 +137,26 @@ def file_retrieve(
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.retrieve_file(
_is_async=_is_async,
@ -181,7 +180,7 @@ def file_retrieve(
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
return cast(FileObject, response)
except Exception as e:
raise e
@ -222,7 +221,7 @@ async def afile_delete(
else:
response = init_response # type: ignore
return response
return cast(FileDeleted, response) # type: ignore
except Exception as e:
raise e
@ -248,7 +247,7 @@ def file_delete(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -288,27 +287,26 @@ def file_delete(
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.delete_file(
_is_async=_is_async,
@ -332,7 +330,7 @@ def file_delete(
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
),
)
return response
return cast(FileDeleted, response)
except Exception as e:
raise e
@ -399,7 +397,7 @@ def file_list(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -441,27 +439,26 @@ def file_list(
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.list_files(
_is_async=_is_async,
@ -556,7 +553,7 @@ def create_file(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -603,27 +600,26 @@ def create_file(
create_file_data=_create_file_request,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.create_file(
_is_async=_is_async,
@ -713,7 +709,7 @@ def file_content(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -761,27 +757,26 @@ def file_content(
organization=organization,
)
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.file_content(
_is_async=_is_async,


@ -25,7 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
OpenAIFineTuningAPI,
)
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
from litellm.secret_managers.main import get_secret
from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import Hyperparameters
from litellm.types.router import *
from litellm.utils import supports_httpx_timeout
@ -119,7 +119,7 @@ def create_fine_tuning_job(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -177,28 +177,27 @@ def create_fine_tuning_job(
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
create_fine_tuning_job_data = FineTuningJobCreate(
model=model,
@ -228,14 +227,14 @@ def create_fine_tuning_job(
vertex_ai_project = (
optional_params.vertex_project
or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT")
or get_secret_str("VERTEXAI_PROJECT")
)
vertex_ai_location = (
optional_params.vertex_location
or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION")
or get_secret_str("VERTEXAI_LOCATION")
)
vertex_credentials = optional_params.vertex_credentials or get_secret(
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS"
)
create_fine_tuning_job_data = FineTuningJobCreate(
@ -315,7 +314,7 @@ async def acancel_fine_tuning_job(
def cancel_fine_tuning_job(
fine_tuning_job_id: str,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@ -335,7 +334,7 @@ def cancel_fine_tuning_job(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -386,23 +385,22 @@ def cancel_fine_tuning_job(
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base,
@ -438,7 +436,7 @@ async def alist_fine_tuning_jobs(
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
) -> FineTuningJob:
):
"""
Async: List your organization's fine-tuning jobs
"""
@ -473,7 +471,7 @@ async def alist_fine_tuning_jobs(
def list_fine_tuning_jobs(
after: Optional[str] = None,
limit: Optional[int] = None,
custom_llm_provider: Literal["openai"] = "openai",
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None,
**kwargs,
@ -495,7 +493,7 @@ def list_fine_tuning_jobs(
if (
timeout is not None
and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False
and supports_httpx_timeout(custom_llm_provider) is False
):
read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout
@ -542,28 +540,27 @@ def list_fine_tuning_jobs(
)
# Azure OpenAI
elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = (
optional_params.api_version
or litellm.api_version
or get_secret("AZURE_API_VERSION")
or get_secret_str("AZURE_API_VERSION")
) # type: ignore
api_key = (
optional_params.api_key
or litellm.api_key
or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY")
or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret_str("AZURE_API_KEY")
) # type: ignore
extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None)
extra_body.pop("azure_ad_token", None)
else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base,


@ -23,6 +23,9 @@ import litellm.types
from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.caching import DualCache
from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.litellm_core_utils.exception_mapping_utils import (
_add_key_name_and_team_to_alert,
)
from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler,
@ -219,7 +222,7 @@ class SlackAlerting(CustomBatchLogger):
and "metadata" in kwargs["litellm_params"]
):
_metadata: dict = kwargs["litellm_params"]["metadata"]
request_info = litellm.utils._add_key_name_and_team_to_alert(
request_info = _add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata
)
@ -281,7 +284,7 @@ class SlackAlerting(CustomBatchLogger):
return_val += 1
return return_val
except Exception as e:
except Exception:
return 0
async def send_daily_reports(self, router) -> bool:
@ -455,7 +458,7 @@ class SlackAlerting(CustomBatchLogger):
try:
messages = str(messages)
messages = messages[:100]
except:
except Exception:
messages = ""
if (
@ -508,7 +511,7 @@ class SlackAlerting(CustomBatchLogger):
_metadata: dict = request_data["metadata"]
_api_base = _metadata.get("api_base", "")
request_info = litellm.utils._add_key_name_and_team_to_alert(
request_info = _add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata
)
@ -846,7 +849,7 @@ class SlackAlerting(CustomBatchLogger):
## MINOR OUTAGE ALERT SENT ##
if (
outage_value["minor_alert_sent"] == False
outage_value["minor_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.minor_outage_alert_threshold
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
@ -871,7 +874,7 @@ class SlackAlerting(CustomBatchLogger):
## MAJOR OUTAGE ALERT SENT ##
elif (
outage_value["major_alert_sent"] == False
outage_value["major_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.major_outage_alert_threshold
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
@ -941,7 +944,7 @@ class SlackAlerting(CustomBatchLogger):
if provider is None:
try:
model, provider, _, _ = litellm.get_llm_provider(model=model)
except Exception as e:
except Exception:
provider = ""
api_base = litellm.get_api_base(
model=model, optional_params=deployment.litellm_params
@ -976,7 +979,7 @@ class SlackAlerting(CustomBatchLogger):
## MINOR OUTAGE ALERT SENT ##
if (
outage_value["minor_alert_sent"] == False
outage_value["minor_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.minor_outage_alert_threshold
):
@ -998,7 +1001,7 @@ class SlackAlerting(CustomBatchLogger):
# set to true
outage_value["minor_alert_sent"] = True
elif (
outage_value["major_alert_sent"] == False
outage_value["major_alert_sent"] is False
and len(outage_value["alerts"])
>= self.alerting_args.major_outage_alert_threshold
):
@ -1024,7 +1027,7 @@ class SlackAlerting(CustomBatchLogger):
await self.internal_usage_cache.async_set_cache(
key=deployment_id, value=outage_value
)
except Exception as e:
except Exception:
pass
async def model_added_alert(
@ -1177,7 +1180,6 @@ Model Info:
if user_row is not None:
recipient_email = user_row.user_email
key_name = webhook_event.key_alias
key_token = webhook_event.token
key_budget = webhook_event.max_budget
base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
@ -1221,14 +1223,14 @@ Model Info:
extra=webhook_event.model_dump(),
)
payload = webhook_event.model_dump_json()
webhook_event.model_dump_json()
email_event = {
"to": recipient_email,
"subject": f"LiteLLM: {event_name}",
"html": email_html_content,
}
response = await send_email(
await send_email(
receiver_email=email_event["to"],
subject=email_event["subject"],
html=email_event["html"],
@ -1292,14 +1294,14 @@ Model Info:
The LiteLLM team <br />
"""
payload = webhook_event.model_dump_json()
webhook_event.model_dump_json()
email_event = {
"to": recipient_email,
"subject": f"LiteLLM: {event_name}",
"html": email_html_content,
}
response = await send_email(
await send_email(
receiver_email=email_event["to"],
subject=email_event["subject"],
html=email_event["html"],
@ -1446,7 +1448,6 @@ Model Info:
response_s: timedelta = end_time - start_time
final_value = response_s
total_tokens = 0
if isinstance(response_obj, litellm.ModelResponse) and (
hasattr(response_obj, "usage")
@ -1505,7 +1506,7 @@ Model Info:
await self.region_outage_alerts(
exception=kwargs["exception"], deployment_id=model_id
)
except Exception as e:
except Exception:
pass
async def _run_scheduler_helper(self, llm_router) -> bool:


@ -35,7 +35,7 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs): # type: ignore
try:
return self.model_dump() # noqa
except:
except Exception:
# if using pydantic v1
return self.dict()


@ -1,8 +1,10 @@
#### What this does ####
# On success + failure, log events to aispend.io
import dotenv, os
import traceback
import datetime
import os
import traceback
import dotenv
model_cost = {
"gpt-3.5-turbo": {
@ -118,8 +120,6 @@ class AISpendLogger:
for model in model_cost:
input_cost_sum += model_cost[model]["input_cost_per_token"]
output_cost_sum += model_cost[model]["output_cost_per_token"]
avg_input_cost = input_cost_sum / len(model_cost.keys())
avg_output_cost = output_cost_sum / len(model_cost.keys())
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
@ -137,12 +137,6 @@ class AISpendLogger:
f"AISpend Logging - Enters logging function for model {model}"
)
url = f"https://aispend.io/api/v1/accounts/{self.account_id}/data"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
response_timestamp = datetime.datetime.fromtimestamp(
int(response_obj["created"])
).strftime("%Y-%m-%d")
@ -168,6 +162,6 @@ class AISpendLogger:
]
print_verbose(f"AISpend Logging - final data object: {data}")
except:
except Exception:
print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
pass


@ -23,7 +23,7 @@ def set_arize_ai_attributes(span: Span, kwargs, response_obj):
)
optional_params = kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {}
# litellm_params = kwargs.get("litellm_params", {}) or {}
#############################################
############ LLM CALL METADATA ##############


@ -1,5 +1,6 @@
import datetime
class AthinaLogger:
def __init__(self):
import os
@ -23,17 +24,20 @@ class AthinaLogger:
]
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
import requests # type: ignore
import json
import traceback
import requests # type: ignore
try:
is_stream = kwargs.get("stream", False)
if is_stream:
if "complete_streaming_response" in kwargs:
# Log the completion response in streaming mode
completion_response = kwargs["complete_streaming_response"]
response_json = completion_response.model_dump() if completion_response else {}
response_json = (
completion_response.model_dump() if completion_response else {}
)
else:
# Skip logging if the completion response is not available
return
@ -52,8 +56,8 @@ class AthinaLogger:
}
if (
type(end_time) == datetime.datetime
and type(start_time) == datetime.datetime
type(end_time) is datetime.datetime
and type(start_time) is datetime.datetime
):
data["response_time"] = int(
(end_time - start_time).total_seconds() * 1000

View file

@ -1,10 +1,11 @@
#### What this does ####
# On success + failure, log events to aispend.io
import dotenv, os
import requests # type: ignore
import traceback
import datetime
import os
import traceback
import dotenv
import requests # type: ignore
model_cost = {
"gpt-3.5-turbo": {
@ -92,91 +93,12 @@ class BerriSpendLogger:
self.account_id = os.getenv("BERRISPEND_ACCOUNT_ID")
def price_calculator(self, model, response_obj, start_time, end_time):
# try and find if the model is in the model_cost map
# else default to the average of the costs
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = 0
if model in model_cost:
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
)
completion_tokens_cost_usd_dollar = (
model_cost[model]["output_cost_per_token"]
* response_obj["usage"]["completion_tokens"]
)
elif "replicate" in model:
# replicate models are charged based on time
# llama 2 runs on an nvidia a100 which costs $0.0032 per second - https://replicate.com/replicate/llama-2-70b-chat
model_run_time = end_time - start_time # assuming time in seconds
cost_usd_dollar = model_run_time * 0.0032
prompt_tokens_cost_usd_dollar = cost_usd_dollar / 2
completion_tokens_cost_usd_dollar = cost_usd_dollar / 2
else:
# calculate average input cost
input_cost_sum = 0
output_cost_sum = 0
for model in model_cost:
input_cost_sum += model_cost[model]["input_cost_per_token"]
output_cost_sum += model_cost[model]["output_cost_per_token"]
avg_input_cost = input_cost_sum / len(model_cost.keys())
avg_output_cost = output_cost_sum / len(model_cost.keys())
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
)
completion_tokens_cost_usd_dollar = (
model_cost[model]["output_cost_per_token"]
* response_obj["usage"]["completion_tokens"]
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
return
def log_event(
self, model, messages, response_obj, start_time, end_time, print_verbose
):
# Method definition
try:
print_verbose(
f"BerriSpend Logging - Enters logging function for model {model}"
)
url = f"https://berrispend.berri.ai/spend"
headers = {"Content-Type": "application/json"}
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = self.price_calculator(model, response_obj, start_time, end_time)
total_cost = (
prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
)
response_time = (end_time - start_time).total_seconds()
if "response" in response_obj:
data = [
{
"response_time": response_time,
"model_id": response_obj["model"],
"total_cost": total_cost,
"messages": messages,
"response": response_obj["choices"][0]["message"]["content"],
"account_id": self.account_id,
}
]
elif "error" in response_obj:
data = [
{
"response_time": response_time,
"model_id": response_obj["model"],
"total_cost": total_cost,
"messages": messages,
"error": response_obj["error"],
"account_id": self.account_id,
}
]
print_verbose(f"BerriSpend Logging - final data object: {data}")
response = requests.post(url, headers=headers, json=data)
except:
print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
pass
"""
This integration is not implemented yet.
"""
return


@ -136,27 +136,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id
prompt = {"messages": kwargs.get("messages")}
output = None
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
@ -169,7 +165,7 @@ class BraintrustLogger(CustomLogger):
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except:
except Exception:
new_metadata = {}
for key, value in metadata.items():
if (
@ -210,16 +206,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost
metrics: Optional[dict] = None
if (
response_obj is not None
and hasattr(response_obj, "usage")
and isinstance(response_obj.usage, litellm.Usage)
):
generation_id = litellm.utils.get_logging_id(start_time, response_obj)
usage_obj = getattr(response_obj, "usage", None)
if usage_obj and isinstance(usage_obj, litellm.Usage):
litellm.utils.get_logging_id(start_time, response_obj)
metrics = {
"prompt_tokens": response_obj.usage.prompt_tokens,
"completion_tokens": response_obj.usage.completion_tokens,
"total_tokens": response_obj.usage.total_tokens,
"prompt_tokens": usage_obj.prompt_tokens,
"completion_tokens": usage_obj.completion_tokens,
"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
}
@ -255,27 +248,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id
prompt = {"messages": kwargs.get("messages")}
output = None
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {})
@ -331,16 +320,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost
metrics: Optional[dict] = None
if (
response_obj is not None
and hasattr(response_obj, "usage")
and isinstance(response_obj.usage, litellm.Usage)
):
generation_id = litellm.utils.get_logging_id(start_time, response_obj)
usage_obj = getattr(response_obj, "usage", None)
if usage_obj and isinstance(usage_obj, litellm.Usage):
litellm.utils.get_logging_id(start_time, response_obj)
metrics = {
"prompt_tokens": response_obj.usage.prompt_tokens,
"completion_tokens": response_obj.usage.completion_tokens,
"total_tokens": response_obj.usage.total_tokens,
"prompt_tokens": usage_obj.prompt_tokens,
"completion_tokens": usage_obj.completion_tokens,
"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
}


@ -2,25 +2,24 @@
#### What this does ####
# On success, logs events to Promptlayer
import dotenv, os
from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache
from typing import Literal, Union
import datetime
import json
import os
import traceback
from typing import Literal, Optional, Union
import dotenv
import requests
import litellm
from litellm._logging import verbose_logger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import StandardLoggingPayload
#### What this does ####
# On success + failure, log events to Supabase
import dotenv, os
import requests
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger
def create_client():
try:
@ -260,18 +259,12 @@ class ClickhouseLogger:
f"ClickhouseLogger Logging - Enters logging function for model {kwargs}"
)
# follows the same params as langfuse.py
from litellm.proxy.utils import get_logging_payload
payload = get_logging_payload(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
metadata = payload.get("metadata", "") or ""
request_tags = payload.get("request_tags", "") or ""
payload["metadata"] = str(metadata)
payload["request_tags"] = str(request_tags)
if payload is None:
return
# Build the initial payload
verbose_logger.debug(f"\nClickhouse Logger - Logging payload = {payload}")


@ -12,7 +12,12 @@ from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
EmbeddingResponse,
ImageResponse,
ModelResponse,
)
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@ -140,8 +145,8 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
) -> Any:
pass
async def async_logging_hook(
@ -188,7 +193,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs,
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event(
@ -202,7 +207,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs,
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event(
@ -217,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time,
end_time,
)
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
@ -233,6 +238,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time,
end_time,
)
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass

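Most of the edits in this file swap bare `except:` clauses for `except Exception:`, which is what ruff's E722 rule flags. A small sketch of the difference; the helper name is made up for illustration:

import traceback


def run_callback_safely(callback, payload: dict) -> None:
    # Catching Exception (rather than using a bare except) lets
    # KeyboardInterrupt and SystemExit propagate instead of being swallowed.
    try:
        callback(payload)
    except Exception:
        print(f"Custom Logger Error - {traceback.format_exc()}")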
View file

@ -54,7 +54,7 @@ class DataDogLogger(CustomBatchLogger):
`DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`
"""
try:
verbose_logger.debug(f"Datadog: in init datadog logger")
verbose_logger.debug("Datadog: in init datadog logger")
# check if the correct env variables are set
if os.getenv("DD_API_KEY", None) is None:
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")
@ -245,12 +245,12 @@ class DataDogLogger(CustomBatchLogger):
usage = dict(usage)
try:
response_time = (end_time - start_time).total_seconds() * 1000
except:
except Exception:
response_time = None
try:
response_obj = dict(response_obj)
except:
except Exception:
response_obj = response_obj
# Clean Metadata before logging - never log raw metadata

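The DataDog hunks drop the `f` prefix from strings that contain no placeholders (ruff F541) and tighten the bare excepts. For reference, a minimal sketch; the `site` parameter is illustrative:

from litellm._logging import verbose_logger


def debug_init(site: str) -> None:
    # No placeholders -> use a plain string literal, not an f-string.
    verbose_logger.debug("Datadog: in init datadog logger")
    # Placeholders present -> an f-string is appropriate.
    verbose_logger.debug(f"Datadog: sending logs to {site}")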
View file

@ -7,7 +7,7 @@ def make_json_serializable(payload):
elif not isinstance(value, (str, int, float, bool, type(None))):
# everything else becomes a string
payload[key] = str(value)
except:
except Exception:
# non blocking if it can't cast to a str
pass
return payload

View file

@ -1,12 +1,16 @@
#### What this does ####
# On success + failure, log events to Supabase
import dotenv, os
import requests # type: ignore
import datetime
import os
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose
import uuid
from typing import Any
import dotenv
import requests # type: ignore
import litellm
class DyanmoDBLogger:
@ -16,7 +20,7 @@ class DyanmoDBLogger:
# Instance variables
import boto3
self.dynamodb = boto3.resource(
self.dynamodb: Any = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
)
if litellm.dynamodb_table_name is None:
@ -67,7 +71,7 @@ class DyanmoDBLogger:
for key, value in payload.items():
try:
payload[key] = str(value)
except:
except Exception:
# non blocking if it can't cast to a str
pass
@ -84,6 +88,6 @@ class DyanmoDBLogger:
f"DynamoDB Layer Logging - final response object: {response_obj}"
)
return response
except:
except Exception:
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass

View file

@ -9,6 +9,7 @@ import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)
@ -41,7 +42,7 @@ class GalileoObserve(CustomLogger):
self.batch_size = 1
self.base_url = os.getenv("GALILEO_BASE_URL", None)
self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
self.headers = None
self.headers: Optional[Dict[str, str]] = None
self.async_httpx_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)
@ -54,7 +55,7 @@ class GalileoObserve(CustomLogger):
"accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
galileo_login_response = self.async_httpx_handler.post(
galileo_login_response = litellm.module_level_client.post(
url=f"{self.base_url}/login",
headers=headers,
data={
@ -94,13 +95,9 @@ class GalileoObserve(CustomLogger):
return output
async def async_log_success_event(
self,
kwargs,
start_time,
end_time,
response_obj,
self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
):
verbose_logger.debug(f"On Async Success")
verbose_logger.debug("On Async Success")
_latency_ms = int((end_time - start_time).total_seconds() * 1000)
_call_type = kwargs.get("call_type", "litellm")
@ -116,6 +113,7 @@ class GalileoObserve(CustomLogger):
response_obj=response_obj, kwargs=kwargs
)
if output_text is not None:
request_record = LLMResponse(
latency_ms=_latency_ms,
status_code=200,
@ -159,4 +157,4 @@ class GalileoObserve(CustomLogger):
)
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
verbose_logger.debug(f"On Async Failure")
verbose_logger.debug("On Async Failure")

View file

@ -56,8 +56,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)
start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
@ -103,8 +103,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)
start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()
logging_payload: Optional[StandardLoggingPayload] = kwargs.get(

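In the GCS bucket hunks, the formatted timestamps were bound to names that were never read, so the F841 fix drops the assignments while keeping the calls. A minimal sketch of when a binding is worth keeping; names are illustrative:

from datetime import datetime


def describe_window(start_time: datetime, end_time: datetime) -> str:
    # Keep a binding only when the value is used afterwards; otherwise drop
    # the assignment (or the call itself, if it has no side effects).
    start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
    end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
    return f"{start_time_str} -> {end_time_str}"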
View file

@ -1,8 +1,9 @@
import requests # type: ignore
import json
import traceback
from datetime import datetime, timezone
import requests # type: ignore
class GreenscaleLogger:
def __init__(self):
@ -29,7 +30,7 @@ class GreenscaleLogger:
"%Y-%m-%dT%H:%M:%SZ"
)
if type(end_time) == datetime and type(start_time) == datetime:
if type(end_time) is datetime and type(start_time) is datetime:
data["invocationLatency"] = int(
(end_time - start_time).total_seconds() * 1000
)
@ -50,6 +51,9 @@ class GreenscaleLogger:
data["tags"] = tags
if self.greenscale_logging_url is None:
raise Exception("Greenscale Logger Error - No logging URL found")
response = requests.post(
self.greenscale_logging_url,
headers=self.headers,

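The Greenscale hunk changes `type(x) == datetime` to `type(x) is datetime`, satisfying ruff's E721 rule; `isinstance` is usually the clearer spelling when subclasses are acceptable. A small sketch; the function name is illustrative:

from datetime import datetime
from typing import Optional


def latency_ms(start_time, end_time) -> Optional[int]:
    # isinstance also accepts datetime subclasses, which is normally fine here.
    if isinstance(start_time, datetime) and isinstance(end_time, datetime):
        return int((end_time - start_time).total_seconds() * 1000)
    return None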
View file

@ -1,15 +1,28 @@
#### What this does ####
# On success, logs events to Helicone
import dotenv, os
import requests # type: ignore
import litellm
import os
import traceback
import dotenv
import requests # type: ignore
import litellm
from litellm._logging import verbose_logger
class HeliconeLogger:
# Class variables or attributes
helicone_model_list = ["gpt", "claude", "command-r", "command-r-plus", "command-light", "command-medium", "command-medium-beta", "command-xlarge-nightly", "command-nightly"]
helicone_model_list = [
"gpt",
"claude",
"command-r",
"command-r-plus",
"command-light",
"command-medium",
"command-medium-beta",
"command-xlarge-nightly",
"command-nightly",
]
def __init__(self):
# Instance variables
@ -17,7 +30,7 @@ class HeliconeLogger:
self.key = os.getenv("HELICONE_API_KEY")
def claude_mapping(self, model, messages, response_obj):
from anthropic import HUMAN_PROMPT, AI_PROMPT
from anthropic import AI_PROMPT, HUMAN_PROMPT
prompt = f"{HUMAN_PROMPT}"
for message in messages:
@ -29,7 +42,6 @@ class HeliconeLogger:
else:
prompt += f"{HUMAN_PROMPT}{message['content']}"
prompt += f"{AI_PROMPT}"
claude_provider_request = {"model": model, "prompt": prompt}
choice = response_obj["choices"][0]
message = choice["message"]
@ -37,12 +49,14 @@ class HeliconeLogger:
content = []
if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]:
content.append({
content.append(
{
"type": "tool_use",
"id": tool_call["id"],
"name": tool_call["function"]["name"],
"input": tool_call["function"]["arguments"]
})
"input": tool_call["function"]["arguments"],
}
)
elif "content" in message and message["content"]:
content = [{"type": "text", "text": message["content"]}]
@ -56,8 +70,8 @@ class HeliconeLogger:
"stop_sequence": None,
"usage": {
"input_tokens": response_obj["usage"]["prompt_tokens"],
"output_tokens": response_obj["usage"]["completion_tokens"]
}
"output_tokens": response_obj["usage"]["completion_tokens"],
},
}
return claude_response_obj
@ -99,10 +113,8 @@ class HeliconeLogger:
f"Helicone Logging - Enters logging function for model {model}"
)
litellm_params = kwargs.get("litellm_params", {})
litellm_call_id = kwargs.get("litellm_call_id", None)
metadata = (
litellm_params.get("metadata", {}) or {}
)
kwargs.get("litellm_call_id", None)
metadata = litellm_params.get("metadata", {}) or {}
metadata = self.add_metadata_from_header(litellm_params, metadata)
model = (
model
@ -175,6 +187,6 @@ class HeliconeLogger:
f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
)
print_verbose(f"Helicone Logging - Error {response.text}")
except:
except Exception:
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
pass

View file

@ -11,7 +11,7 @@ import dotenv
import httpx
import litellm
from litellm import verbose_logger
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
HTTPHandler,
@ -65,8 +65,8 @@ class LagoLogger(CustomLogger):
raise Exception("Missing keys={} in environment.".format(missing_keys))
def _common_logic(self, kwargs: dict, response_obj) -> dict:
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
dt = get_utc_datetime().isoformat()
response_obj.get("id", kwargs.get("litellm_call_id"))
get_utc_datetime().isoformat()
cost = kwargs.get("response_cost", None)
model = kwargs.get("model")
usage = {}
@ -86,7 +86,7 @@ class LagoLogger(CustomLogger):
end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
litellm_params["metadata"].get("user_api_key_org_id", None)
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
external_customer_id: Optional[str] = None
@ -158,8 +158,9 @@ class LagoLogger(CustomLogger):
response.raise_for_status()
except Exception as e:
if hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
error_response = getattr(e, "response", None)
if error_response is not None and hasattr(error_response, "text"):
verbose_logger.debug(f"\nError Message: {error_response.text}")
raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -199,5 +200,5 @@ class LagoLogger(CustomLogger):
verbose_logger.debug(f"Logged Lago Object: {response.text}")
except Exception as e:
if response is not None and hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}")
verbose_logger.debug(f"\nError Message: {response.text}")
raise e

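The Lago hunks replace `litellm.print_verbose` with the shared `verbose_logger` and only read `.text` after confirming the exception actually carries a response. A minimal sketch of that guard, assuming an httpx-style exception; the function name is illustrative:

from litellm._logging import verbose_logger


def log_error_response(exc: Exception) -> None:
    # Connection errors carry no response object, so check before reading .text.
    error_response = getattr(exc, "response", None)
    if error_response is not None and hasattr(error_response, "text"):
        verbose_logger.debug(f"\nError Message: {error_response.text}")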
View file

@ -67,7 +67,7 @@ class LangFuseLogger:
try:
project_id = self.Langfuse.client.projects.get().data[0].id
os.environ["LANGFUSE_PROJECT_ID"] = project_id
except:
except Exception:
project_id = None
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
@ -184,7 +184,7 @@ class LangFuseLogger:
if not isinstance(value, (str, int, bool, float)):
try:
optional_params[param] = str(value)
except:
except Exception:
# if casting value to str fails don't block logging
pass
@ -275,7 +275,7 @@ class LangFuseLogger:
print_verbose(
f"Langfuse Layer Logging - final response object: {response_obj}"
)
verbose_logger.info(f"Langfuse Layer Logging - logging success")
verbose_logger.info("Langfuse Layer Logging - logging success")
return {"trace_id": trace_id, "generation_id": generation_id}
except Exception as e:
@ -492,7 +492,7 @@ class LangFuseLogger:
output if not mask_output else "redacted-by-litellm"
)
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
if debug is True or (isinstance(debug, str) and debug.lower() == "true"):
if "metadata" in trace_params:
# log the raw_metadata in the trace
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
@ -535,8 +535,8 @@ class LangFuseLogger:
proxy_server_request = litellm_params.get("proxy_server_request", None)
if proxy_server_request:
method = proxy_server_request.get("method", None)
url = proxy_server_request.get("url", None)
proxy_server_request.get("method", None)
proxy_server_request.get("url", None)
headers = proxy_server_request.get("headers", None)
clean_headers = {}
if headers:
@ -625,7 +625,7 @@ class LangFuseLogger:
generation_client = trace.generation(**generation_params)
return generation_client.trace_id, generation_id
except Exception as e:
except Exception:
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
return None, None

View file

@ -404,7 +404,7 @@ class LangsmithLogger(CustomBatchLogger):
verbose_logger.exception(
f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
)
except Exception as e:
except Exception:
verbose_logger.exception(
f"Langsmith Layer Error - {traceback.format_exc()}"
)

View file

@ -1,6 +1,10 @@
import requests, traceback, json, os
import json
import os
import traceback
import types
import requests
class LiteDebugger:
user_email = None
@ -17,23 +21,17 @@ class LiteDebugger:
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
)
if (
self.user_email == None
self.user_email is None
): # if users are trying to use_client=True but token not set
raise ValueError(
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
)
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
try:
print(
f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
)
except:
print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
if self.user_email == None:
if self.user_email is None:
raise ValueError(
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
)
except Exception as e:
except Exception:
raise ValueError(
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
)
@ -49,123 +47,18 @@ class LiteDebugger:
litellm_params,
optional_params,
):
print_verbose(
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
)
try:
print_verbose(
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
)
def remove_key_value(dictionary, key):
new_dict = dictionary.copy() # Create a copy of the original dictionary
new_dict.pop(key) # Remove the specified key-value pair from the copy
return new_dict
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
if call_type == "embedding":
for (
message
) in (
messages
): # assuming the input is a list as required by the embedding function
litellm_data_obj = {
"model": model,
"messages": [{"role": "user", "content": message}],
"end_user": end_user,
"status": "initiated",
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
"litellm_params": updated_litellm_params,
"optional_params": optional_params,
}
print_verbose(
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: embedding api response - {response.text}")
elif call_type == "completion":
litellm_data_obj = {
"model": model,
"messages": messages
if isinstance(messages, list)
else [{"role": "user", "content": messages}],
"end_user": end_user,
"status": "initiated",
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
"litellm_params": updated_litellm_params,
"optional_params": optional_params,
}
print_verbose(
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(
f"LiteDebugger: completion api response - {response.text}"
)
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
pass
"""
This integration is not implemented yet.
"""
return
def post_call_log_event(
self, original_response, litellm_call_id, print_verbose, call_type, stream
):
print_verbose(
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
)
try:
if call_type == "embedding":
litellm_data_obj = {
"status": "received",
"additional_details": {
"original_response": str(
original_response["data"][0]["embedding"][:5]
)
}, # don't store the entire vector
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
elif call_type == "completion" and not stream:
litellm_data_obj = {
"status": "received",
"additional_details": {"original_response": original_response},
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
elif call_type == "completion" and stream:
litellm_data_obj = {
"status": "received",
"additional_details": {
"original_response": "Streamed response"
if isinstance(original_response, types.GeneratorType)
else original_response
},
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
print_verbose(f"litedebugger post-call data object - {litellm_data_obj}")
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: api response - {response.text}")
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
"""
This integration is not implemented yet.
"""
return
def log_event(
self,
@ -178,85 +71,7 @@ class LiteDebugger:
call_type,
stream=False,
):
print_verbose(
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
)
try:
print_verbose(
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
)
total_cost = 0 # [TODO] implement cost tracking
response_time = (end_time - start_time).total_seconds()
if call_type == "completion" and stream == False:
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,
"response": response_obj["choices"][0]["message"]["content"],
"litellm_call_id": litellm_call_id,
"status": "success",
}
print_verbose(
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif call_type == "embedding":
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,
"response": str(response_obj["data"][0]["embedding"][:5]),
"litellm_call_id": litellm_call_id,
"status": "success",
}
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif call_type == "completion" and stream == True:
if len(response_obj["content"]) > 0: # don't log the empty strings
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,
"response": response_obj["content"],
"litellm_call_id": litellm_call_id,
"status": "success",
}
print_verbose(
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif "error" in response_obj:
if "Unable to map your input to a model." in response_obj["error"]:
total_cost = 0
litellm_data_obj = {
"response_time": response_time,
"model": response_obj["model"],
"total_cost": total_cost,
"error": response_obj["error"],
"end_user": end_user,
"litellm_call_id": litellm_call_id,
"status": "failure",
"user_email": self.user_email,
}
print_verbose(
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: api response - {response.text}")
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
pass
"""
This integration is not implemented yet.
"""
return

View file

@ -27,7 +27,7 @@ class LogfireLogger:
# Class variables or attributes
def __init__(self):
try:
verbose_logger.debug(f"in init logfire logger")
verbose_logger.debug("in init logfire logger")
import logfire
# only setting up logfire if we are sending to logfire
@ -116,7 +116,7 @@ class LogfireLogger:
id = response_obj.get("id", str(uuid.uuid4()))
try:
response_time = (end_time - start_time).total_seconds()
except:
except Exception:
response_time = None
# Clean Metadata before logging - never log raw metadata

View file

@ -1,8 +1,8 @@
#### What this does ####
# On success + failure, log events to lunary.ai
from datetime import datetime, timezone
import traceback
import importlib
import traceback
from datetime import datetime, timezone
import packaging
@ -74,9 +74,9 @@ class LunaryLogger:
try:
import lunary
version = importlib.metadata.version("lunary")
version = importlib.metadata.version("lunary") # type: ignore
# if version < 0.1.43 then raise ImportError
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
if packaging.version.Version(version) < packaging.version.Version("0.1.43"): # type: ignore
print( # noqa
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
)
@ -97,7 +97,7 @@ class LunaryLogger:
run_id,
model,
print_verbose,
extra=None,
extra={},
input=None,
user_id=None,
response_obj=None,
@ -128,7 +128,7 @@ class LunaryLogger:
if not isinstance(value, (str, int, bool, float)) and param != "tools":
try:
extra[param] = str(value)
except:
except Exception:
pass
if response_obj:
@ -175,6 +175,6 @@ class LunaryLogger:
token_usage=usage,
)
except:
except Exception:
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
pass

View file

@ -98,7 +98,7 @@ class OpenTelemetry(CustomLogger):
import logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
logging.getLogger(__name__)
# Enable OpenTelemetry logging
otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export")
@ -520,7 +520,7 @@ class OpenTelemetry(CustomLogger):
def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
from litellm.proxy._types import SpanAttributes
optional_params = kwargs.get("optional_params", {})
kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {}
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
@ -769,6 +769,6 @@ class OpenTelemetry(CustomLogger):
management_endpoint_span.set_attribute(f"request.{key}", value)
_exception = logging_payload.exception
management_endpoint_span.set_attribute(f"exception", str(_exception))
management_endpoint_span.set_attribute("exception", str(_exception))
management_endpoint_span.set_status(Status(StatusCode.ERROR))
management_endpoint_span.end(end_time=_end_time_ns)

View file

@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger):
user_api_team_alias = standard_logging_payload["metadata"][
"user_api_key_team_alias"
]
exception = kwargs.get("exception", None)
kwargs.get("exception", None)
try:
self.litellm_llm_api_failed_requests_metric.labels(
@ -679,7 +679,7 @@ class PrometheusLogger(CustomLogger):
).inc()
pass
except:
except Exception:
pass
def set_llm_deployment_success_metrics(
@ -800,7 +800,7 @@ class PrometheusLogger(CustomLogger):
if (
request_kwargs.get("stream", None) is not None
and request_kwargs["stream"] == True
and request_kwargs["stream"] is True
):
# only log ttft for streaming request
time_to_first_token_response_time = (

View file

@ -3,11 +3,17 @@
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
import dotenv, os
import requests # type: ignore
import datetime
import os
import subprocess
import sys
import traceback
import datetime, subprocess, sys
import litellm, uuid
import uuid
import dotenv
import requests # type: ignore
import litellm
from litellm._logging import print_verbose, verbose_logger
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@ -23,7 +29,7 @@ class PrometheusServicesLogger:
):
try:
try:
from prometheus_client import Counter, Histogram, REGISTRY
from prometheus_client import REGISTRY, Counter, Histogram
except ImportError:
raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`"
@ -33,7 +39,7 @@ class PrometheusServicesLogger:
self.Counter = Counter
self.REGISTRY = REGISTRY
verbose_logger.debug(f"in init prometheus services metrics")
verbose_logger.debug("in init prometheus services metrics")
self.services = [item.value for item in ServiceTypes]

View file

@ -1,9 +1,11 @@
#### What this does ####
# On success, logs events to Promptlayer
import dotenv, os
import os
import traceback
import dotenv
import requests # type: ignore
from pydantic import BaseModel
import traceback
class PromptLayerLogger:
@ -84,6 +86,6 @@ class PromptLayerLogger:
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
)
except:
except Exception:
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
pass

View file

@ -1,10 +1,15 @@
#### What this does ####
# On success + failure, log events to Supabase
import dotenv, os
import requests # type: ignore
import datetime
import os
import subprocess
import sys
import traceback
import datetime, subprocess, sys
import dotenv
import requests # type: ignore
import litellm
@ -21,7 +26,12 @@ class Supabase:
except ImportError:
subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"])
import supabase
self.supabase_client = supabase.create_client(
if self.supabase_url is None or self.supabase_key is None:
raise ValueError(
"LiteLLM Error, trying to use Supabase but url or key not passed. Create a table and set `litellm.supabase_url=<your-url>` and `litellm.supabase_key=<your-key>`"
)
self.supabase_client = supabase.create_client( # type: ignore
self.supabase_url, self.supabase_key
)
@ -45,7 +55,7 @@ class Supabase:
.execute()
)
print_verbose(f"data: {data}")
except:
except Exception:
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
pass
@ -109,6 +119,6 @@ class Supabase:
.execute()
)
except:
except Exception:
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
pass

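The Supabase hunk adds an explicit check that both the URL and key are configured before `create_client` is called, so a missing setting fails with an actionable message instead of an opaque client error. A minimal sketch, assuming the same `litellm.supabase_url` and `litellm.supabase_key` settings:

import litellm


def get_supabase_client():
    # Fail fast with a clear message when configuration is incomplete.
    if litellm.supabase_url is None or litellm.supabase_key is None:
        raise ValueError(
            "Supabase logging enabled but litellm.supabase_url or "
            "litellm.supabase_key is not set."
        )
    import supabase  # imported lazily, mirroring the integration

    return supabase.create_client(litellm.supabase_url, litellm.supabase_key)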
View file

@ -167,18 +167,17 @@ try:
trace = self.results_to_trace_tree(request, response, results, time_elapsed)
return trace
except:
except Exception:
imported_openAIResponse = False
#### What this does ####
# On success, logs events to Langfuse
import os
import requests
import requests
import traceback
from datetime import datetime
import traceback
import requests
class WeightsBiasesLogger:
@ -186,11 +185,11 @@ class WeightsBiasesLogger:
def __init__(self):
try:
import wandb
except:
except Exception:
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
if imported_openAIResponse == False:
if imported_openAIResponse is False:
raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
)
@ -209,13 +208,14 @@ class WeightsBiasesLogger:
kwargs, response_obj, (end_time - start_time).total_seconds()
)
if trace is not None:
if trace is not None and run is not None:
run.log({"trace": trace})
if run is not None:
run.finish()
print_verbose(
f"W&B Logging Logging - final response object: {response_obj}"
)
except:
except Exception:
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
pass

View file

@ -62,7 +62,7 @@ def get_error_message(error_obj) -> Optional[str]:
# If all else fails, return None
return None
except Exception as e:
except Exception:
return None
@ -910,7 +910,7 @@ def exception_type( # type: ignore
):
exception_mapping_worked = True
raise BadRequestError(
message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
model=model,
llm_provider="sagemaker",
response=original_exception.response,
@ -1122,7 +1122,7 @@ def exception_type( # type: ignore
# 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
exception_mapping_worked = True
raise BadRequestError(
message=f"GeminiException - Invalid api key",
message="GeminiException - Invalid api key",
model=model,
llm_provider="palm",
response=original_exception.response,
@ -2067,12 +2067,34 @@ def exception_logging(
logger_fn(
model_call_details
) # Expectation: any logger function passed in by the user should accept a dict object
except Exception as e:
except Exception:
verbose_logger.debug(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
)
except Exception as e:
except Exception:
verbose_logger.debug(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
)
pass
def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
"""
Internal helper function for litellm proxy
Add the Key Name + Team Name to the error
Only gets added if the metadata contains the user_api_key_alias and user_api_key_team_alias
[Non-Blocking helper function]
"""
try:
_api_key_name = metadata.get("user_api_key_alias", None)
_user_api_key_team_alias = metadata.get("user_api_key_team_alias", None)
if _api_key_name is not None:
request_info = (
f"\n\nKey Name: `{_api_key_name}`\nTeam: `{_user_api_key_team_alias}`"
+ request_info
)
return request_info
except Exception:
return request_info

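The new `_add_key_name_and_team_to_alert` helper above prepends the key alias and team alias to an alert string when they are present in the request metadata, and returns the original string untouched on any error. A hypothetical call, assuming the helper shown above is in scope:

metadata = {
    "user_api_key_alias": "prod-key",
    "user_api_key_team_alias": "ml-platform",
}
request_info = "Request failed with a 429 from the upstream provider."

# Non-blocking: on any exception the helper simply returns request_info unchanged.
alert_text = _add_key_name_and_team_to_alert(request_info, metadata)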
View file

@ -476,7 +476,7 @@ def get_llm_provider(
elif model == "*":
custom_llm_provider = "openai"
if custom_llm_provider is None or custom_llm_provider == "":
if litellm.suppress_debug_info == False:
if litellm.suppress_debug_info is False:
print() # noqa
print( # noqa
"\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa

View file

@ -52,18 +52,8 @@ from litellm.types.utils import (
)
from litellm.utils import (
_get_base_model_from_metadata,
add_breadcrumb,
capture_exception,
customLogger,
liteDebuggerClient,
logfireLogger,
lunaryLogger,
print_verbose,
prometheusLogger,
prompt_token_calculator,
promptLayerLogger,
supabaseClient,
weightsBiasesLogger,
)
from ..integrations.aispend import AISpendLogger
@ -71,7 +61,6 @@ from ..integrations.athina import AthinaLogger
from ..integrations.berrispend import BerriSpendLogger
from ..integrations.braintrust_logging import BraintrustLogger
from ..integrations.clickhouse import ClickhouseLogger
from ..integrations.custom_logger import CustomLogger
from ..integrations.datadog.datadog import DataDogLogger
from ..integrations.dynamodb import DyanmoDBLogger
from ..integrations.galileo import GalileoObserve
@ -423,7 +412,7 @@ class Logging:
elif callback == "sentry" and add_breadcrumb:
try:
details_to_log = copy.deepcopy(self.model_call_details)
except:
except Exception:
details_to_log = self.model_call_details
if litellm.turn_off_message_logging:
# make a copy of the _model_Call_details and log it
@ -528,7 +517,7 @@ class Logging:
verbose_logger.debug("reaches sentry breadcrumbing")
try:
details_to_log = copy.deepcopy(self.model_call_details)
except:
except Exception:
details_to_log = self.model_call_details
if litellm.turn_off_message_logging:
# make a copy of the _model_Call_details and log it
@ -1326,7 +1315,7 @@ class Logging:
and customLogger is not None
): # custom logger functions
print_verbose(
f"success callbacks: Running Custom Callback Function"
"success callbacks: Running Custom Callback Function"
)
customLogger.log_event(
kwargs=self.model_call_details,
@ -1400,7 +1389,7 @@ class Logging:
self.model_call_details["response_cost"] = 0.0
else:
# check if base_model set on azure
base_model = _get_base_model_from_metadata(
_get_base_model_from_metadata(
model_call_details=self.model_call_details
)
# base_model defaults to None if not set on model_info
@ -1483,7 +1472,7 @@ class Logging:
for callback in callbacks:
# check if callback can run for this request
litellm_params = self.model_call_details.get("litellm_params", {})
if litellm_params.get("no-log", False) == True:
if litellm_params.get("no-log", False) is True:
# proxy cost tracking cal backs should run
if not (
isinstance(callback, CustomLogger)
@ -1492,7 +1481,7 @@ class Logging:
print_verbose("no-log request, skipping logging")
continue
try:
if kwargs.get("no-log", False) == True:
if kwargs.get("no-log", False) is True:
print_verbose("no-log request, skipping logging")
continue
if (
@ -1641,7 +1630,7 @@ class Logging:
end_time=end_time,
print_verbose=print_verbose,
)
except Exception as e:
except Exception:
verbose_logger.error(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
)
@ -2433,7 +2422,7 @@ def get_standard_logging_object_payload(
call_type = kwargs.get("call_type")
cache_hit = kwargs.get("cache_hit", False)
usage = response_obj.get("usage", None) or {}
if type(usage) == litellm.Usage:
if type(usage) is litellm.Usage:
usage = dict(usage)
id = response_obj.get("id", kwargs.get("litellm_call_id"))
@ -2656,3 +2645,11 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
litellm_params["metadata"] = metadata
return litellm_params
# integration helper function
def modify_integration(integration_name, integration_params):
global supabaseClient
if integration_name == "supabase":
if "table_name" in integration_params:
Supabase.supabase_table_name = integration_params["table_name"]

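The `modify_integration` helper added above lets callers override an integration's settings by name; for Supabase it swaps the table used for logging. A hypothetical call, assuming the function shown above is in scope and using an invented table name:

# Route Supabase logging to a custom table instead of the default one.
modify_integration("supabase", {"table_name": "litellm_request_logs"})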
View file

@ -45,7 +45,7 @@ def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider
)
except:
except Exception:
continue
if model_info.get("mode") != "chat":
continue

View file

@ -123,7 +123,7 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
@ -165,7 +165,7 @@ def completion(
)
if response.status_code != 200:
raise AI21Error(status_code=response.status_code, message=response.text)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
@ -191,7 +191,7 @@ def completion(
)
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except Exception as e:
except Exception:
raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code
)

View file

@ -151,7 +151,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
@ -260,7 +260,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
if a_add_message is not None and a_add_message is True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
@ -365,7 +365,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
if aget_messages is not None and aget_messages is True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
@ -483,7 +483,7 @@ class AzureAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
if acreate_thread is not None and acreate_thread is True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
@ -586,7 +586,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
if aget_thread is not None and aget_thread is True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
@ -774,8 +774,8 @@ class AzureAssistantsAPI(BaseLLM):
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
if arun_thread is not None and arun_thread is True:
if stream is not None and stream is True:
azure_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,
@ -823,7 +823,7 @@ class AzureAssistantsAPI(BaseLLM):
client=client,
)
if stream is not None and stream == True:
if stream is not None and stream is True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,
@ -887,7 +887,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
async_create_assistants=None,
):
if async_create_assistants is not None and async_create_assistants == True:
if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants(
api_key=api_key,
api_base=api_base,
@ -950,7 +950,7 @@ class AzureAssistantsAPI(BaseLLM):
async_delete_assistants: Optional[bool] = None,
client=None,
):
if async_delete_assistants is not None and async_delete_assistants == True:
if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant(
api_key=api_key,
api_base=api_base,

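The assistants hunks replace `== True` comparisons with `is True`, per ruff's E712 rule; when a flag may also be None, the identity check keeps "explicitly True" distinct from merely truthy. A small sketch with illustrative parameter names:

from typing import Optional


def should_stream(arun_thread: Optional[bool], stream: Optional[bool]) -> bool:
    # `is True` treats None (not set) and False the same way, so only an
    # explicit True selects the streaming path.
    return arun_thread is True and stream is True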
View file

@ -317,7 +317,7 @@ class AzureOpenAIAssistantsAPIConfig:
if "file_id" in item:
file_ids.append(item["file_id"])
else:
if litellm.drop_params == True:
if litellm.drop_params is True:
pass
else:
raise litellm.utils.UnsupportedParamsError(
@ -580,7 +580,7 @@ class AzureChatCompletion(BaseLLM):
try:
if model is None or messages is None:
raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
status_code=422, message="Missing model or messages"
)
max_retries = optional_params.pop("max_retries", 2)
@ -1240,12 +1240,6 @@ class AzureChatCompletion(BaseLLM):
)
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout_msg = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}
raise AzureOpenAIError(
status_code=408, message="Operation polling timed out."
@ -1493,7 +1487,6 @@ class AzureChatCompletion(BaseLLM):
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
try:
if model and len(model) > 0:
model = model
@ -1534,7 +1527,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True:
if aimg_generation is True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response

View file

@ -1263,7 +1263,6 @@ class OpenAIChatCompletion(BaseLLM):
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
data = {}
try:
model = model
@ -1272,7 +1271,7 @@ class OpenAIChatCompletion(BaseLLM):
if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int")
if aimg_generation == True:
if aimg_generation is True:
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response
@ -1311,7 +1310,6 @@ class OpenAIChatCompletion(BaseLLM):
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e:
exception_mapping_worked = True
## LOGGING
logging_obj.post_call(
input=prompt,
@ -1543,7 +1541,7 @@ class OpenAITextCompletion(BaseLLM):
if (
len(messages) > 0
and "content" in messages[0]
and type(messages[0]["content"]) == list
and isinstance(messages[0]["content"], list)
):
prompt = messages[0]["content"]
else:
@ -2413,7 +2411,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,
@ -2470,7 +2468,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None,
async_create_assistants=None,
):
if async_create_assistants is not None and async_create_assistants == True:
if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants(
api_key=api_key,
api_base=api_base,
@ -2527,7 +2525,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None,
async_delete_assistants=None,
):
if async_delete_assistants is not None and async_delete_assistants == True:
if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant(
api_key=api_key,
api_base=api_base,
@ -2629,7 +2627,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
if a_add_message is not None and a_add_message is True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,
@ -2727,7 +2725,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
if aget_messages is not None and aget_messages is True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,
@ -2838,7 +2836,7 @@ class OpenAIAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
if acreate_thread is not None and acreate_thread is True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,
@ -2934,7 +2932,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
if aget_thread is not None and aget_thread is True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,
@ -3117,8 +3115,8 @@ class OpenAIAssistantsAPI(BaseLLM):
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
if arun_thread is not None and arun_thread is True:
if stream is not None and stream is True:
_client = self.async_get_openai_client(
api_key=api_key,
api_base=api_base,
@ -3163,7 +3161,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client,
)
if stream is not None and stream == True:
if stream is not None and stream is True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,

View file

@ -191,7 +191,7 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
default_max_tokens_to_sample=None,
@ -246,7 +246,7 @@ def completion(
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
@ -279,7 +279,7 @@ def completion(
)
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except:
except Exception:
raise AlephAlphaError(
message=json.dumps(completion_response),
status_code=response.status_code,

View file

@ -607,7 +607,6 @@ class ModelResponseIterator:
def _handle_usage(
self, anthropic_usage_chunk: Union[dict, UsageDelta]
) -> AnthropicChatCompletionUsageBlock:
special_fields = ["input_tokens", "output_tokens"]
usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
@ -683,7 +682,7 @@ class ModelResponseIterator:
"index": self.tool_index,
}
elif type_chunk == "content_block_stop":
content_block_stop = ContentBlockStop(**chunk) # type: ignore
ContentBlockStop(**chunk) # type: ignore
# check if tool call content block
is_empty = self.check_empty_tool_call_args()
if is_empty:

View file

@ -114,7 +114,7 @@ class AnthropicTextCompletion(BaseLLM):
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
except Exception:
raise AnthropicError(
message=response.text, status_code=response.status_code
)
@ -229,7 +229,7 @@ class AnthropicTextCompletion(BaseLLM):
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
headers={},
@ -276,8 +276,8 @@ class AnthropicTextCompletion(BaseLLM):
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
if acompletion == True:
if "stream" in optional_params and optional_params["stream"] is True:
if acompletion is True:
return self.async_streaming(
model=model,
api_base=api_base,
@ -309,7 +309,7 @@ class AnthropicTextCompletion(BaseLLM):
logging_obj=logging_obj,
)
return stream_response
elif acompletion == True:
elif acompletion is True:
return self.async_completion(
model=model,
model_response=model_response,

View file

@ -233,7 +233,7 @@ class AzureTextCompletion(BaseLLM):
client=client,
logging_obj=logging_obj,
)
elif "stream" in optional_params and optional_params["stream"] == True:
elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming(
logging_obj=logging_obj,
api_base=api_base,

View file

@ -36,7 +36,7 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
@ -59,7 +59,7 @@ def completion(
"parameters": optional_params,
"stream": (
True
if "stream" in optional_params and optional_params["stream"] == True
if "stream" in optional_params and optional_params["stream"] is True
else False
),
}
@ -77,12 +77,12 @@ def completion(
data=json.dumps(data),
stream=(
True
if "stream" in optional_params and optional_params["stream"] == True
if "stream" in optional_params and optional_params["stream"] is True
else False
),
)
if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True
"stream" in optional_params and optional_params["stream"] is True
):
return response.iter_lines()
else:

View file

@ -183,7 +183,7 @@ class BedrockConverseLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return litellm.AmazonConverseConfig()._transform_response(

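Here, and in several of the Bedrock and Databricks hunks that follow, `except httpx.TimeoutException as e:` becomes `except httpx.TimeoutException:` because the bound name was never used. A minimal sketch of the pattern with a hypothetical error class:

import httpx


class UpstreamError(Exception):
    def __init__(self, status_code: int, message: str):
        self.status_code = status_code
        super().__init__(message)


def post_json(url: str, body: dict) -> httpx.Response:
    try:
        return httpx.post(url, json=body, timeout=10.0)
    except httpx.TimeoutException:
        # Drop the `as e` binding when the exception object is not re-used.
        raise UpstreamError(status_code=408, message="Timeout error occurred.")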
View file

@ -251,9 +251,7 @@ class AmazonConverseConfig:
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
supported_tool_call_params = ["tools", "tool_choice"]
supported_guardrail_params = ["guardrailConfig"]
json_mode: Optional[bool] = inference_params.pop(
"json_mode", None
) # used for handling json_schema
inference_params.pop("json_mode", None) # used for handling json_schema
## TRANSFORMATION ##
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(

View file

@ -234,7 +234,7 @@ async def make_call(
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise BedrockError(status_code=500, message=str(e))
@ -335,7 +335,7 @@ class BedrockLLM(BaseAWSLLM):
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
except Exception:
raise BedrockError(message=response.text, status_code=422)
outputText: Optional[str] = None
@ -394,12 +394,12 @@ class BedrockLLM(BaseAWSLLM):
outputText # allow user to access raw anthropic tool calling response
)
if (
_is_function_call == True
_is_function_call is True
and stream is not None
and stream == True
and stream is True
):
print_verbose(
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
)
# return an iterator
streaming_model_response = ModelResponse(stream=True)
@ -440,7 +440,7 @@ class BedrockLLM(BaseAWSLLM):
model_response=streaming_model_response
)
print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
)
return litellm.CustomStreamWrapper(
completion_stream=completion_stream,
@ -597,7 +597,7 @@ class BedrockLLM(BaseAWSLLM):
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ##
@ -700,7 +700,7 @@ class BedrockLLM(BaseAWSLLM):
k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v
if stream == True:
if stream is True:
inference_params["stream"] = (
True # cohere requires stream = True in inference params
)
@ -845,7 +845,7 @@ class BedrockLLM(BaseAWSLLM):
if acompletion:
if isinstance(client, HTTPHandler):
client = None
if stream == True and provider != "ai21":
if stream is True and provider != "ai21":
return self.async_streaming(
model=model,
messages=messages,
@ -891,7 +891,7 @@ class BedrockLLM(BaseAWSLLM):
self.client = _get_httpx_client(_params) # type: ignore
else:
self.client = client
if (stream is not None and stream == True) and provider != "ai21":
if (stream is not None and stream is True) and provider != "ai21":
response = self.client.post(
url=proxy_endpoint_url,
headers=prepped.headers, # type: ignore
@ -929,7 +929,7 @@ class BedrockLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(
@ -980,7 +980,7 @@ class BedrockLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response(

View file

@ -260,7 +260,7 @@ class AmazonAnthropicConfig:
optional_params["top_p"] = value
if param == "stop":
optional_params["stop_sequences"] = value
if param == "stream" and value == True:
if param == "stream" and value is True:
optional_params["stream"] = value
return optional_params

View file

@ -39,7 +39,7 @@ class BedrockEmbedding(BaseAWSLLM):
) -> Tuple[Any, str]:
try:
from botocore.credentials import Credentials
except ImportError as e:
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them

View file

@ -130,7 +130,7 @@ def process_response(
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except Exception as e:
except Exception:
raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model
)
@ -219,7 +219,7 @@ async def async_completion(
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except Exception as e:
except Exception:
raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model
)
@ -251,9 +251,9 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params: dict,
custom_prompt_dict={},
acompletion=False,
optional_params=None,
litellm_params=None,
logger_fn=None,
):
@ -268,20 +268,12 @@ def completion(
optional_params[k] = v
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
if custom_llm_provider == "anthropic":
prompt = prompt_factory(
prompt: str = prompt_factory( # type: ignore
model=orig_model_name,
messages=messages,
api_key=api_key,
custom_llm_provider="clarifai",
)
else:
prompt = prompt_factory(
model=orig_model_name,
messages=messages,
api_key=api_key,
custom_llm_provider=custom_llm_provider,
)
# print(prompt); exit(0)
data = {
@ -300,7 +292,7 @@ def completion(
"api_base": model,
},
)
if acompletion == True:
if acompletion is True:
return async_completion(
model=model,
prompt=prompt,
@ -331,7 +323,7 @@ def completion(
status_code=response.status_code, message=response.text, url=model
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper(
completion_stream=completion_stream,

View file

@ -80,8 +80,8 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params: dict,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
):
@ -97,7 +97,7 @@ def completion(
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
custom_prompt(
role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
@ -126,7 +126,7 @@ def completion(
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
response = requests.post(
api_base,
headers=headers,

View file

@ -268,7 +268,7 @@ def completion(
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
@ -283,12 +283,12 @@ def completion(
completion_response = response.json()
try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e:
except Exception:
raise CohereError(message=response.text, status_code=response.status_code)
## Tool calling response
cohere_tools_response = completion_response.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response is not []:
if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format
tool_calls = []
for tool in cohere_tools_response:

View file

@ -146,7 +146,7 @@ def completion(
api_key,
logging_obj,
headers: dict,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
@ -198,7 +198,7 @@ def completion(
if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
@ -231,7 +231,7 @@ def completion(
)
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except Exception as e:
except Exception:
raise CohereError(
message=response.text, status_code=response.status_code
)

View file

@ -17,7 +17,7 @@ else:
try:
from litellm._version import version
except:
except Exception:
version = "0.0.0"
headers = {

View file

@ -1,15 +1,17 @@
from typing import Optional
import httpx
try:
from litellm._version import version
except:
except Exception:
version = "0.0.0"
headers = {
"User-Agent": f"litellm/{version}",
}
class HTTPHandler:
def __init__(self, concurrent_limit=1000):
# Create a client with a connection pool

View file

@ -113,7 +113,7 @@ class DatabricksConfig:
optional_params["max_tokens"] = value
if param == "n":
optional_params["n"] = value
if param == "stream" and value == True:
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "temperature":
optional_params["temperature"] = value
@ -564,7 +564,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code,
message=e.response.text,
)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise DatabricksError(
status_code=408, message="Timeout error occurred."
)
@ -614,7 +614,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code,
message=response.text if response else str(e),
)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise DatabricksError(
status_code=408, message="Timeout error occurred."
)
@ -669,7 +669,7 @@ class DatabricksChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base},
)
if aembedding == True:
if aembedding is True:
return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler):
self.client = HTTPHandler(timeout=timeout) # type: ignore
@ -692,7 +692,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code,
message=e.response.text,
)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise DatabricksError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise DatabricksError(status_code=500, message=str(e))
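
Several Databricks hunks above drop the `as e` binding when the exception object is never used (`except httpx.TimeoutException:`), which clears ruff's unused-variable complaint. A small sketch of when to bind and when not to; `fetch` and its URL handling are illustrative only:

import httpx

def fetch(url: str) -> httpx.Response:
    try:
        return httpx.get(url, timeout=5.0)
    except httpx.TimeoutException:
        # no `as e` needed -- the exception object is never referenced
        raise RuntimeError("Timeout error occurred.")
    except httpx.HTTPError as e:
        # keep the binding when the message actually uses it
        raise RuntimeError(f"Request failed: {e}")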

View file

@ -71,7 +71,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self,
_is_async: bool,
create_file_data: CreateFileRequest,
api_base: str,
api_base: Optional[str],
api_key: Optional[str],
api_version: Optional[str],
timeout: Union[float, httpx.Timeout],
@ -117,7 +117,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self,
_is_async: bool,
file_content_request: FileContentRequest,
api_base: str,
api_base: Optional[str],
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
@ -168,7 +168,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self,
_is_async: bool,
file_id: str,
api_base: str,
api_base: Optional[str],
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
@ -220,7 +220,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self,
_is_async: bool,
file_id: str,
api_base: str,
api_base: Optional[str],
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],
@ -275,7 +275,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
def list_files(
self,
_is_async: bool,
api_base: str,
api_base: Optional[str],
api_key: Optional[str],
timeout: Union[float, httpx.Timeout],
max_retries: Optional[int],

View file

@ -41,7 +41,7 @@ class VertexFineTuningAPI(VertexLLM):
created_at = int(create_time_datetime.timestamp())
return created_at
except Exception as e:
except Exception:
return 0
def convert_vertex_response_to_open_ai_response(

View file

@ -136,7 +136,7 @@ class GeminiConfig:
# ):
# try:
# import google.generativeai as genai # type: ignore
# except:
# except Exception:
# raise Exception(
# "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
# )
@ -282,7 +282,7 @@ class GeminiConfig:
# completion_response = model_response["choices"][0]["message"].get("content")
# if completion_response is None:
# raise Exception
# except:
# except Exception:
# original_response = f"response: {response}"
# if hasattr(response, "candidates"):
# original_response = f"response: {response.candidates}"
@ -374,7 +374,7 @@ class GeminiConfig:
# completion_response = model_response["choices"][0]["message"].get("content")
# if completion_response is None:
# raise Exception
# except:
# except Exception:
# original_response = f"response: {response}"
# if hasattr(response, "candidates"):
# original_response = f"response: {response.candidates}"

View file

@ -13,6 +13,7 @@ import requests
import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@ -181,7 +182,7 @@ class HuggingfaceConfig:
return optional_params
def get_hf_api_key(self) -> Optional[str]:
return litellm.utils.get_secret("HUGGINGFACE_API_KEY")
return get_secret_str("HUGGINGFACE_API_KEY")
def output_parser(generated_text: str):
@ -240,7 +241,7 @@ def read_tgi_conv_models():
# Cache the set for future use
conv_models_cache = conv_models
return tgi_models, conv_models
except:
except Exception:
return set(), set()
@ -372,7 +373,7 @@ class Huggingface(BaseLLM):
]["finish_reason"]
sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1:
@ -386,7 +387,7 @@ class Huggingface(BaseLLM):
):
sum_logprob = 0
for token in item["tokens"]:
if token["logprob"] != None:
if token["logprob"] is not None:
sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0:
message_obj = Message(
@ -417,7 +418,7 @@ class Huggingface(BaseLLM):
prompt_tokens = len(
encoding.encode(input_text)
) ##[TODO] use the llama2 tokenizer here
except:
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
output_text = model_response["choices"][0]["message"].get("content", "")
@ -429,7 +430,7 @@ class Huggingface(BaseLLM):
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use the llama2 tokenizer here
except:
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:
@ -559,7 +560,7 @@ class Huggingface(BaseLLM):
True
if "stream" in optional_params
and isinstance(optional_params["stream"], bool)
and optional_params["stream"] == True # type: ignore
and optional_params["stream"] is True # type: ignore
else False
),
}
@ -595,7 +596,7 @@ class Huggingface(BaseLLM):
data["stream"] = ( # type: ignore
True # type: ignore
if "stream" in optional_params
and optional_params["stream"] == True
and optional_params["stream"] is True
else False
)
input_text = prompt
@ -631,7 +632,7 @@ class Huggingface(BaseLLM):
### ASYNC COMPLETION
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
response = requests.post(
completion_url,
headers=headers,
@ -691,7 +692,7 @@ class Huggingface(BaseLLM):
completion_response = response.json()
if isinstance(completion_response, dict):
completion_response = [completion_response]
except:
except Exception:
import traceback
raise HuggingfaceError(

View file

@ -101,7 +101,7 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
@ -135,7 +135,7 @@ def completion(
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
@ -159,7 +159,7 @@ def completion(
model_response.choices[0].message.content = completion_response[ # type: ignore
"answer"
]
except Exception as e:
except Exception:
raise MaritalkError(
message=response.text, status_code=response.status_code
)

View file

@ -120,7 +120,7 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
default_max_tokens_to_sample=None,
@ -164,7 +164,7 @@ def completion(
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return clean_and_iterate_chunks(response)
else:
## LOGGING
@ -178,7 +178,7 @@ def completion(
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
except Exception:
raise NLPCloudError(message=response.text, status_code=response.status_code)
if "error" in completion_response:
raise NLPCloudError(
@ -191,7 +191,7 @@ def completion(
model_response.choices[0].message.content = ( # type: ignore
completion_response["generated_text"]
)
except:
except Exception:
raise NLPCloudError(
message=json.dumps(completion_response),
status_code=response.status_code,
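
The signature changes above (maritalk, nlpcloud, and most handlers in this commit) turn `optional_params=None` into a required `optional_params: dict`, so the body can read from it without a None guard and without the implicit-Optional type the checker flags. A toy version of the shape, not the real handler:

# Sketch: optional_params is always a dict, so .get() is safe at every call site.
def completion(model: str, messages: list, optional_params: dict, logger_fn=None) -> dict:
    stream = optional_params.get("stream", False)
    return {"model": model, "stream": stream, "n_messages": len(messages)}

# usage
completion("my-model", [{"role": "user", "content": "hi"}], optional_params={"stream": True})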

View file

@ -14,7 +14,7 @@ import requests # type: ignore
import litellm
from litellm import verbose_logger
from litellm.types.utils import ProviderField
from litellm.types.utils import ProviderField, StreamingChoices
from .prompt_templates.factory import custom_prompt, prompt_factory
@ -172,7 +172,7 @@ def _convert_image(image):
try:
from PIL import Image
except:
except Exception:
raise Exception(
"ollama image conversion failed please run `pip install Pillow`"
)
@ -184,7 +184,7 @@ def _convert_image(image):
image_data = Image.open(io.BytesIO(base64.b64decode(image)))
if image_data.format in ["JPEG", "PNG"]:
return image
except:
except Exception:
return orig
jpeg_image = io.BytesIO()
image_data.convert("RGB").save(jpeg_image, "JPEG")
@ -195,13 +195,13 @@ def _convert_image(image):
# ollama implementation
def get_ollama_response(
model_response: litellm.ModelResponse,
api_base="http://localhost:11434",
model="llama2",
prompt="Why is the sky blue?",
optional_params=None,
logging_obj=None,
model: str,
prompt: str,
optional_params: dict,
logging_obj: Any,
encoding: Any,
acompletion: bool = False,
encoding=None,
api_base="http://localhost:11434",
):
if api_base.endswith("/api/generate"):
url = api_base
@ -242,7 +242,7 @@ def get_ollama_response(
},
)
if acompletion is True:
if stream == True:
if stream is True:
response = ollama_async_streaming(
url=url,
data=data,
@ -340,11 +340,16 @@ def ollama_completion_stream(url, data, logging_obj):
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = next(streamwrapper)
response_content = "".join(
chunk.choices[0].delta.content
for chunk in chain([first_chunk], streamwrapper)
if chunk.choices[0].delta.content
)
content_chunks = []
for chunk in chain([first_chunk], streamwrapper):
content_chunk = chunk.choices[0]
if (
isinstance(content_chunk, StreamingChoices)
and hasattr(content_chunk, "delta")
and hasattr(content_chunk.delta, "content")
):
content_chunks.append(content_chunk.delta.content)
response_content = "".join(content_chunks)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
@ -392,15 +397,27 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper)
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content
]
)
first_chunk = await anext(streamwrapper) # noqa F821
chunk_choice = first_chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
first_chunk_content = chunk_choice.delta.content or ""
else:
first_chunk_content = ""
content_chunks = []
async for chunk in streamwrapper:
chunk_choice = chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
content_chunks.append(chunk_choice.delta.content)
response_content = first_chunk_content + "".join(content_chunks)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
@ -501,8 +518,8 @@ async def ollama_aembeddings(
prompts: List[str],
model_response: litellm.EmbeddingResponse,
optional_params: dict,
logging_obj=None,
encoding=None,
logging_obj: Any,
encoding: Any,
):
if api_base.endswith("/api/embed"):
url = api_base
@ -581,9 +598,9 @@ def ollama_embeddings(
api_base: str,
model: str,
prompts: list,
optional_params=None,
logging_obj=None,
model_response=None,
optional_params: dict,
model_response: litellm.EmbeddingResponse,
logging_obj: Any,
encoding=None,
):
return asyncio.run(
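
The rewritten aggregation loops above replace the one-line generator with an explicit loop that checks `isinstance(choice, StreamingChoices)` before touching `.delta.content`, which is what lets the type-checker accept the attribute access. A condensed sketch of that guard; `stream_chunks` is a stand-in for the wrapped Ollama stream:

from litellm.types.utils import StreamingChoices  # same type the hunks above import

def collect_json_content(stream_chunks) -> str:
    # gather delta content only from choices that actually carry a delta
    parts = []
    for chunk in stream_chunks:
        choice = chunk.choices[0]
        if isinstance(choice, StreamingChoices) and choice.delta.content:
            parts.append(choice.delta.content)
    return "".join(parts)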

View file

@ -4,7 +4,7 @@ import traceback
import types
import uuid
from itertools import chain
from typing import List, Optional
from typing import Any, List, Optional
import aiohttp
import httpx
@ -15,6 +15,7 @@ import litellm
from litellm import verbose_logger
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import StreamingChoices
class OllamaError(Exception):
@ -216,10 +217,10 @@ def get_ollama_response(
model_response: litellm.ModelResponse,
messages: list,
optional_params: dict,
model: str,
logging_obj: Any,
api_base="http://localhost:11434",
api_key: Optional[str] = None,
model="llama2",
logging_obj=None,
acompletion: bool = False,
encoding=None,
):
@ -252,10 +253,13 @@ def get_ollama_response(
for tool in m["tool_calls"]:
typed_tool = ChatCompletionAssistantToolCall(**tool) # type: ignore
if typed_tool["type"] == "function":
arguments = {}
if "arguments" in typed_tool["function"]:
arguments = json.loads(typed_tool["function"]["arguments"])
ollama_tool_call = OllamaToolCall(
function=OllamaToolCallFunction(
name=typed_tool["function"]["name"],
arguments=json.loads(typed_tool["function"]["arguments"]),
name=typed_tool["function"].get("name") or "",
arguments=arguments,
)
)
new_tools.append(ollama_tool_call)
@ -401,12 +405,16 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = next(streamwrapper)
response_content = "".join(
chunk.choices[0].delta.content
for chunk in chain([first_chunk], streamwrapper)
if chunk.choices[0].delta.content
)
content_chunks = []
for chunk in streamwrapper:
chunk_choice = chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
content_chunks.append(chunk_choice.delta.content)
response_content = "".join(content_chunks)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
@ -422,7 +430,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
}
],
)
model_response = first_chunk
model_response = content_chunks[0]
model_response.choices[0].delta = delta # type: ignore
model_response.choices[0].finish_reason = "tool_calls"
yield model_response
@ -462,15 +470,28 @@ async def ollama_async_streaming(
# If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper)
first_chunk_content = first_chunk.choices[0].delta.content or ""
response_content = first_chunk_content + "".join(
[
chunk.choices[0].delta.content
async for chunk in streamwrapper
if chunk.choices[0].delta.content
]
)
first_chunk = await anext(streamwrapper) # noqa F821
chunk_choice = first_chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
first_chunk_content = chunk_choice.delta.content or ""
else:
first_chunk_content = ""
content_chunks = []
async for chunk in streamwrapper:
chunk_choice = chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
content_chunks.append(chunk_choice.delta.content)
response_content = first_chunk_content + "".join(content_chunks)
function_call = json.loads(response_content)
delta = litellm.utils.Delta(
content=None,
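
The tool-call hunk above stops assuming `typed_tool["function"]["arguments"]` is always present: arguments default to `{}` and the name falls back to an empty string. A self-contained sketch of that defensive parse; the sample payload is invented:

import json

def parse_tool_call(tool: dict) -> tuple:
    # mirror the guarded access used above instead of indexing blindly
    function = tool.get("function") or {}
    arguments = {}
    if "arguments" in function and function["arguments"]:
        arguments = json.loads(function["arguments"])
    return function.get("name") or "", arguments

print(parse_tool_call({"function": {"name": "get_weather", "arguments": '{"city": "Paris"}'}}))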

View file

@ -39,8 +39,8 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params: dict,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
default_max_tokens_to_sample=None,
@ -77,7 +77,7 @@ def completion(
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING
@ -91,7 +91,7 @@ def completion(
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
except Exception:
raise OobaboogaError(
message=response.text, status_code=response.status_code
)
@ -103,7 +103,7 @@ def completion(
else:
try:
model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
except:
except Exception:
raise OobaboogaError(
message=json.dumps(completion_response),
status_code=response.status_code,

View file

@ -96,13 +96,13 @@ def completion(
api_key,
encoding,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
):
try:
import google.generativeai as palm # type: ignore
except:
except Exception:
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
@ -167,14 +167,14 @@ def completion(
choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except Exception as e:
except Exception:
raise PalmError(
message=traceback.format_exc(), status_code=response.status_code
)
try:
completion_response = model_response["choices"][0]["message"].get("content")
except:
except Exception:
raise PalmError(
status_code=400,
message=f"No response received. Original response - {response}",

View file

@ -98,7 +98,7 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
optional_params=None,
optional_params: dict,
stream=False,
litellm_params=None,
logger_fn=None,
@ -123,6 +123,7 @@ def completion(
else:
prompt = prompt_factory(model=model, messages=messages)
output_text: Optional[str] = None
if api_base:
## LOGGING
logging_obj.pre_call(
@ -157,7 +158,7 @@ def completion(
import torch
from petals import AutoDistributedModelForCausalLM # type: ignore
from transformers import AutoTokenizer
except:
except Exception:
raise Exception(
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
)
@ -192,7 +193,7 @@ def completion(
## RESPONSE OBJECT
output_text = tokenizer.decode(outputs[0])
if len(output_text) > 0:
if output_text is not None and len(output_text) > 0:
model_response.choices[0].message.content = output_text # type: ignore
prompt_tokens = len(encoding.encode(prompt))
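
The petals hunk above pre-declares `output_text: Optional[str] = None` and then checks `is not None` before calling `len()`, which is how it clears the possibly-unbound / possibly-None warnings. The same idea in isolation; `render` and its message are made up:

from typing import Optional

def render(api_base: Optional[str]) -> str:
    output_text: Optional[str] = None  # typed default before conditional assignment
    if api_base:
        output_text = f"called {api_base}"
    if output_text is not None and len(output_text) > 0:
        return output_text
    return "no output"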

View file

@ -265,7 +265,7 @@ class PredibaseChatCompletion(BaseLLM):
## RESPONSE OBJECT
try:
completion_response = response.json()
except:
except Exception:
raise PredibaseError(message=response.text, status_code=422)
if "error" in completion_response:
raise PredibaseError(
@ -348,7 +348,7 @@ class PredibaseChatCompletion(BaseLLM):
model_response["choices"][0]["message"].get("content", "")
)
) ##[TODO] use a model-specific tokenizer
except:
except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails
pass
else:

View file

@ -5,7 +5,7 @@ import traceback
import uuid
import xml.etree.ElementTree as ET
from enum import Enum
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, cast
from jinja2 import BaseLoader, Template, exceptions, meta
from jinja2.sandbox import ImmutableSandboxedEnvironment
@ -26,11 +26,14 @@ from litellm.types.completion import (
)
from litellm.types.llms.anthropic import *
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.llms.ollama import OllamaVisionModelObject
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall,
ChatCompletionFunctionMessage,
ChatCompletionImageObject,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk,
ChatCompletionToolMessage,
ChatCompletionUserMessage,
@ -164,7 +167,9 @@ def convert_to_ollama_image(openai_image_url: str):
def ollama_pt(
model, messages
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
) -> Union[
str, OllamaVisionModelObject
]: # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model:
prompt = custom_prompt(
role_dict={
@ -438,7 +443,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
def _is_system_in_template():
try:
# Try rendering the template with a system message
response = template.render(
template.render(
messages=[{"role": "system", "content": "test"}],
eos_token="<eos>",
bos_token="<bos>",
@ -446,10 +451,11 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
return True
# This will be raised if Jinja attempts to render the system message and it can't
except:
except Exception:
return False
try:
rendered_text = ""
# Render the template with the provided values
if _is_system_in_template():
rendered_text = template.render(
@ -460,8 +466,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
)
else:
# treat a system message as a user message, if system not in template
try:
reformatted_messages = []
try:
for message in messages:
if message["role"] == "system":
reformatted_messages.append(
@ -556,30 +562,31 @@ def get_model_info(token, model):
return None, None
else:
return None, None
except Exception as e: # safely fail a prompt template request
except Exception: # safely fail a prompt template request
return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template):
if prompt_format is None:
return default_pt(messages)
## OLD TOGETHER AI FLOW
# def format_prompt_togetherai(messages, prompt_format, chat_template):
# if prompt_format is None:
# return default_pt(messages)
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
# human_prompt, assistant_prompt = prompt_format.split("{prompt}")
if chat_template is not None:
prompt = hf_chat_template(
model=None, messages=messages, chat_template=chat_template
)
elif prompt_format is not None:
prompt = custom_prompt(
role_dict={},
messages=messages,
initial_prompt_value=human_prompt,
final_prompt_value=assistant_prompt,
)
else:
prompt = default_pt(messages)
return prompt
# if chat_template is not None:
# prompt = hf_chat_template(
# model=None, messages=messages, chat_template=chat_template
# )
# elif prompt_format is not None:
# prompt = custom_prompt(
# role_dict={},
# messages=messages,
# initial_prompt_value=human_prompt,
# final_prompt_value=assistant_prompt,
# )
# else:
# prompt = default_pt(messages)
# return prompt
### IBM Granite
@ -1063,7 +1070,7 @@ def convert_to_gemini_tool_call_invoke(
else: # don't silently drop params. Make it clear to user what's happening.
raise Exception(
"function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
tool
message
)
)
return _parts_list
@ -1216,12 +1223,14 @@ def convert_function_to_anthropic_tool_invoke(
function_call: Union[dict, ChatCompletionToolCallFunctionChunk],
) -> List[AnthropicMessagesToolUseParam]:
try:
_name = get_attribute_or_key(function_call, "name") or ""
_arguments = get_attribute_or_key(function_call, "arguments")
anthropic_tool_invoke = [
AnthropicMessagesToolUseParam(
type="tool_use",
id=str(uuid.uuid4()),
name=get_attribute_or_key(function_call, "name"),
input=json.loads(get_attribute_or_key(function_call, "arguments")),
name=_name,
input=json.loads(_arguments) if _arguments else {},
)
]
return anthropic_tool_invoke
@ -1349,8 +1358,9 @@ def anthropic_messages_pt(
):
for m in user_message_types_block["content"]:
if m.get("type", "") == "image_url":
m = cast(ChatCompletionImageObject, m)
image_chunk = convert_to_anthropic_image_obj(
m["image_url"]["url"]
openai_image_url=m["image_url"]["url"] # type: ignore
)
_anthropic_content_element = AnthropicMessagesImageParam(
@ -1362,21 +1372,31 @@ def anthropic_messages_pt(
),
)
anthropic_content_element = add_cache_control_to_content(
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_element,
orignal_content_element=m,
orignal_content_element=dict(m),
)
user_content.append(anthropic_content_element)
if "cache_control" in _content_element:
_anthropic_content_element["cache_control"] = (
_content_element["cache_control"]
)
user_content.append(_anthropic_content_element)
elif m.get("type", "") == "text":
_anthropic_text_content_element = {
"type": "text",
"text": m["text"],
}
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=m,
m = cast(ChatCompletionTextObject, m)
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=m["text"],
)
user_content.append(anthropic_content_element)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=dict(m),
)
_content_element = cast(
AnthropicMessagesTextParam, _content_element
)
user_content.append(_content_element)
elif (
user_message_types_block["role"] == "tool"
or user_message_types_block["role"] == "function"
@ -1390,12 +1410,17 @@ def anthropic_messages_pt(
"type": "text",
"text": user_message_types_block["content"],
}
anthropic_content_element = add_cache_control_to_content(
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_text_element,
orignal_content_element=user_message_types_block,
orignal_content_element=dict(user_message_types_block),
)
user_content.append(anthropic_content_element)
if "cache_control" in _content_element:
_anthropic_content_text_element["cache_control"] = _content_element[
"cache_control"
]
user_content.append(_anthropic_content_text_element)
msg_i += 1
@ -1417,11 +1442,14 @@ def anthropic_messages_pt(
anthropic_message = AnthropicMessagesTextParam(
type="text", text=m.get("text")
)
anthropic_message = add_cache_control_to_content(
_cached_message = add_cache_control_to_content(
anthropic_content_element=anthropic_message,
orignal_content_element=m,
orignal_content_element=dict(m),
)
assistant_content.append(
cast(AnthropicMessagesTextParam, _cached_message)
)
assistant_content.append(anthropic_message)
elif (
"content" in assistant_content_block
and isinstance(assistant_content_block["content"], str)
@ -1430,16 +1458,22 @@ def anthropic_messages_pt(
] # don't pass empty text blocks. anthropic api raises errors.
):
_anthropic_text_content_element = {
"type": "text",
"text": assistant_content_block["content"],
}
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=assistant_content_block,
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=assistant_content_block["content"],
)
assistant_content.append(anthropic_content_element)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=dict(assistant_content_block),
)
if "cache_control" in _content_element:
_anthropic_text_content_element["cache_control"] = _content_element[
"cache_control"
]
assistant_content.append(_anthropic_text_content_element)
assistant_tool_calls = assistant_content_block.get("tool_calls")
if (
@ -1566,30 +1600,6 @@ def get_system_prompt(messages):
return system_prompt, messages
def convert_to_documents(
observations: Any,
) -> List[MutableMapping]:
"""Converts observations into a 'document' dict"""
documents: List[MutableMapping] = []
if isinstance(observations, str):
# strings are turned into a key/value pair and a key of 'output' is added.
observations = [{"output": observations}]
elif isinstance(observations, Mapping):
# single mappings are transformed into a list to simplify the rest of the code.
observations = [observations]
elif not isinstance(observations, Sequence):
# all other types are turned into a key/value pair within a list
observations = [{"output": observations}]
for doc in observations:
if not isinstance(doc, Mapping):
# types that aren't Mapping are turned into a key/value pair.
doc = {"output": doc}
documents.append(doc)
return documents
from litellm.types.llms.cohere import (
CallObject,
ChatHistory,
@ -1943,7 +1953,7 @@ def amazon_titan_pt(
def _load_image_from_url(image_url):
try:
from PIL import Image
except:
except Exception:
raise Exception("image conversion failed please run `pip install Pillow`")
from io import BytesIO
@ -2008,7 +2018,7 @@ def _gemini_vision_convert_messages(messages: list):
else:
try:
from PIL import Image
except:
except Exception:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
@ -2056,7 +2066,7 @@ def gemini_text_image_pt(messages: list):
"""
try:
import google.generativeai as genai # type: ignore
except:
except Exception:
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
@ -2331,7 +2341,7 @@ def _convert_to_bedrock_tool_call_result(
for content in content_list:
if content["type"] == "text":
content_str += content["text"]
name = message.get("name", "")
message.get("name", "")
id = str(message.get("tool_call_id", str(uuid.uuid4())))
tool_result_content_block = BedrockToolResultContentBlock(text=content_str)
@ -2575,7 +2585,7 @@ def function_call_prompt(messages: list, functions: list):
message["content"] += f""" {function_prompt}"""
function_added_to_prompt = True
if function_added_to_prompt == False:
if function_added_to_prompt is False:
messages.append({"role": "system", "content": f"""{function_prompt}"""})
return messages
@ -2692,11 +2702,6 @@ def prompt_factory(
)
elif custom_llm_provider == "anthropic_xml":
return anthropic_messages_pt_xml(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(
messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini":
if (
model == "gemini-pro-vision"
@ -2810,7 +2815,7 @@ def prompt_factory(
)
else:
return hf_chat_template(original_model_name, messages)
except Exception as e:
except Exception:
return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
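
The Anthropic message hunks above no longer reassign the typed content element from `add_cache_control_to_content`; they copy only the `cache_control` key back onto it, so the element keeps its declared type. Roughly this shape, shown here on plain dicts rather than the real TypedDicts:

def carry_cache_control(content_block: dict, original_message: dict) -> dict:
    # forward the caller's prompt-caching hint without replacing the typed block
    if "cache_control" in original_message:
        content_block["cache_control"] = original_message["cache_control"]
    return content_block

block = carry_cache_control(
    {"type": "text", "text": "hello"},
    {"role": "user", "content": "hello", "cache_control": {"type": "ephemeral"}},
)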

View file

@ -61,7 +61,7 @@ async def async_convert_url_to_base64(url: str) -> str:
try:
response = await client.get(url, follow_redirects=True)
return _process_image_response(response, url)
except:
except Exception:
pass
raise Exception(
f"Error: Unable to fetch image from URL after 3 attempts. url={url}"

View file

@ -297,7 +297,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
if "output" in response_data:
try:
output_string = "".join(response_data["output"])
except Exception as e:
except Exception:
raise ReplicateError(
status_code=422,
message="Unable to parse response. Got={}".format(
@ -344,7 +344,7 @@ async def async_handle_prediction_response_streaming(
if "output" in response_data:
try:
output_string = "".join(response_data["output"])
except Exception as e:
except Exception:
raise ReplicateError(
status_code=422,
message="Unable to parse response. Got={}".format(
@ -479,7 +479,7 @@ def completion(
else:
input_data = {"prompt": prompt, **optional_params}
if acompletion is not None and acompletion == True:
if acompletion is not None and acompletion is True:
return async_completion(
model_response=model_response,
model=model,
@ -513,7 +513,7 @@ def completion(
print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
print_verbose("streaming request")
_response = handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
@ -571,7 +571,7 @@ async def async_completion(
http_handler=http_handler,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
_response = async_handle_prediction_response_streaming(
prediction_url, api_key, print_verbose
)

View file

@ -8,7 +8,7 @@ import types
from copy import deepcopy
from enum import Enum
from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Union
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
import httpx # type: ignore
import requests # type: ignore
@ -112,7 +112,7 @@ class SagemakerLLM(BaseAWSLLM):
):
try:
from botocore.credentials import Credentials
except ImportError as e:
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
@ -123,7 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop(
optional_params.pop(
"aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
@ -175,7 +175,7 @@ class SagemakerLLM(BaseAWSLLM):
from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials
except ImportError as e:
except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
@ -244,7 +244,7 @@ class SagemakerLLM(BaseAWSLLM):
hf_model_name = (
hf_model_name or model
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages)
prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore
return prompt
@ -256,10 +256,10 @@ class SagemakerLLM(BaseAWSLLM):
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
timeout: Optional[Union[float, httpx.Timeout]] = None,
custom_prompt_dict={},
hf_model_name=None,
optional_params=None,
litellm_params=None,
logger_fn=None,
acompletion: bool = False,
@ -277,7 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
openai_like_chat_completions = DatabricksChatCompletion()
inference_params["stream"] = True if stream is True else False
_data = {
_data: Dict[str, Any] = {
"model": model,
"messages": messages,
**inference_params,
@ -310,7 +310,7 @@ class SagemakerLLM(BaseAWSLLM):
logger_fn=logger_fn,
timeout=timeout,
encoding=encoding,
headers=prepared_request.headers,
headers=prepared_request.headers, # type: ignore
custom_endpoint=True,
custom_llm_provider="sagemaker_chat",
streaming_decoder=custom_stream_decoder, # type: ignore
@ -474,7 +474,7 @@ class SagemakerLLM(BaseAWSLLM):
try:
sync_response = sync_handler.post(
url=prepared_request.url,
headers=prepared_request.headers,
headers=prepared_request.headers, # type: ignore
json=_data,
timeout=timeout,
)
@ -559,7 +559,7 @@ class SagemakerLLM(BaseAWSLLM):
self,
api_base: str,
headers: dict,
data: str,
data: dict,
logging_obj,
client=None,
):
@ -598,7 +598,7 @@ class SagemakerLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise SagemakerError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise SagemakerError(status_code=408, message="Timeout error occurred.")
except Exception as e:
raise SagemakerError(status_code=500, message=str(e))
@ -638,7 +638,7 @@ class SagemakerLLM(BaseAWSLLM):
make_call=partial(
self.make_async_call,
api_base=prepared_request.url,
headers=prepared_request.headers,
headers=prepared_request.headers, # type: ignore
data=data,
logging_obj=logging_obj,
),
@ -716,7 +716,7 @@ class SagemakerLLM(BaseAWSLLM):
try:
response = await async_handler.post(
url=prepared_request.url,
headers=prepared_request.headers,
headers=prepared_request.headers, # type: ignore
json=data,
timeout=timeout,
)
@ -794,8 +794,8 @@ class SagemakerLLM(BaseAWSLLM):
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
):
@ -1032,7 +1032,7 @@ class AWSEventStreamDecoder:
yield self._chunk_parser_messages_api(chunk_data=_data)
else:
yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError as e:
except json.JSONDecodeError:
# Handle or log any unparseable data at the end
verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}"

View file

@ -17,6 +17,7 @@ import requests # type: ignore
import litellm
from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.utils import (
@ -157,7 +158,7 @@ class MistralTextCompletionConfig:
optional_params["top_p"] = value
if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value
if param == "stream" and value == True:
if param == "stream" and value is True:
optional_params["stream"] = value
if param == "stop":
optional_params["stop"] = value
@ -249,7 +250,7 @@ class CodestralTextCompletion(BaseLLM):
response: Union[requests.Response, httpx.Response],
model_response: TextCompletionResponse,
stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
logging_obj: LiteLLMLogging,
optional_params: dict,
api_key: str,
data: Union[dict, str],
@ -273,7 +274,7 @@ class CodestralTextCompletion(BaseLLM):
)
try:
completion_response = response.json()
except:
except Exception:
raise TextCompletionCodestralError(message=response.text, status_code=422)
_original_choices = completion_response.get("choices", [])

View file

@ -176,7 +176,7 @@ class VertexAIConfig:
if param == "top_p":
optional_params["top_p"] = value
if (
param == "stream" and value == True
param == "stream" and value is True
): # sending stream = False, can cause it to get passed unchecked and raise issues
optional_params["stream"] = value
if param == "n":
@ -1313,7 +1313,6 @@ class ModelResponseIterator:
text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None
_candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")

View file

@ -268,7 +268,7 @@ def completion(
):
try:
import vertexai
except:
except Exception:
raise VertexAIError(
status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",

View file

@ -5,7 +5,7 @@ import time
import types
import uuid
from enum import Enum
from typing import Any, Callable, List, Literal, Optional, Union
from typing import Any, Callable, List, Literal, Optional, Union, cast
import httpx # type: ignore
import requests # type: ignore
@ -25,7 +25,12 @@ from litellm.types.files import (
is_gemini_1_5_accepted_file_type,
is_video_file_type,
)
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionImageObject,
ChatCompletionTextObject,
)
from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@ -150,30 +155,34 @@ def _gemini_convert_messages_with_history(
while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
):
if messages[msg_i]["content"] is not None and isinstance(
messages[msg_i]["content"], list
):
_message_content = messages[msg_i].get("content")
if _message_content is not None and isinstance(_message_content, list):
_parts: List[PartType] = []
for element in messages[msg_i]["content"]: # type: ignore
if isinstance(element, dict):
if element["type"] == "text" and len(element["text"]) > 0: # type: ignore
_part = PartType(text=element["text"]) # type: ignore
for element in _message_content:
if (
element["type"] == "text"
and "text" in element
and len(element["text"]) > 0
):
element = cast(ChatCompletionTextObject, element)
_part = PartType(text=element["text"])
_parts.append(_part)
elif element["type"] == "image_url":
img_element: ChatCompletionImageObject = element # type: ignore
element = cast(ChatCompletionImageObject, element)
img_element = element
if isinstance(img_element["image_url"], dict):
image_url = img_element["image_url"]["url"]
else:
image_url = img_element["image_url"]
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part) # type: ignore
_parts.append(_part)
user_content.extend(_parts)
elif (
messages[msg_i]["content"] is not None
and isinstance(messages[msg_i]["content"], str)
and len(messages[msg_i]["content"]) > 0 # type: ignore
_message_content is not None
and isinstance(_message_content, str)
and len(_message_content) > 0
):
_part = PartType(text=messages[msg_i]["content"]) # type: ignore
_part = PartType(text=_message_content)
user_content.append(_part)
msg_i += 1
@ -201,22 +210,21 @@ def _gemini_convert_messages_with_history(
else:
msg_dict = messages[msg_i] # type: ignore
assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
if assistant_msg.get("content", None) is not None and isinstance(
assistant_msg["content"], list
):
_message_content = assistant_msg.get("content", None)
if _message_content is not None and isinstance(_message_content, list):
_parts = []
for element in assistant_msg["content"]:
for element in _message_content:
if isinstance(element, dict):
if element["type"] == "text":
_part = PartType(text=element["text"]) # type: ignore
_part = PartType(text=element["text"])
_parts.append(_part)
assistant_content.extend(_parts)
elif (
assistant_msg.get("content", None) is not None
and isinstance(assistant_msg["content"], str)
and assistant_msg["content"]
_message_content is not None
and isinstance(_message_content, str)
and _message_content
):
assistant_text = assistant_msg["content"] # either string or none
assistant_text = _message_content # either string or none
assistant_content.append(PartType(text=assistant_text)) # type: ignore
## HANDLE ASSISTANT FUNCTION CALL
@ -256,7 +264,9 @@ def _gemini_convert_messages_with_history(
raise e
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
def _get_client_cache_key(
model: str, vertex_project: Optional[str], vertex_location: Optional[str]
):
_cache_key = f"{model}-{vertex_project}-{vertex_location}"
return _cache_key
@ -294,7 +304,7 @@ def completion(
"""
try:
import vertexai
except:
except Exception:
raise VertexAIError(
status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
@ -339,6 +349,8 @@ def completion(
_vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
if _vertex_llm_model_object is None:
from google.auth.credentials import Credentials
if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account
@ -356,7 +368,9 @@ def completion(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
)
vertexai.init(
project=vertex_project, location=vertex_location, credentials=creds
project=vertex_project,
location=vertex_location,
credentials=cast(Credentials, creds),
)
## Load Config
@ -391,7 +405,6 @@ def completion(
request_str = ""
response_obj = None
async_client = None
instances = None
client_options = {
"api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
@ -400,7 +413,7 @@ def completion(
model in litellm.vertex_language_models
or model in litellm.vertex_vision_models
):
llm_model = _vertex_llm_model_object or GenerativeModel(model)
llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
mode = "vision"
request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_chat_models:
@ -459,7 +472,6 @@ def completion(
"model_response": model_response,
"encoding": encoding,
"messages": messages,
"request_str": request_str,
"print_verbose": print_verbose,
"client_options": client_options,
"instances": instances,
@ -474,6 +486,7 @@ def completion(
return async_completion(**data)
completion_response = None
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
@ -529,7 +542,7 @@ def completion(
# Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
):
# If so, convert to list
args_dict[key] = [v for v in val]
@ -560,9 +573,9 @@ def completion(
optional_params["tools"] = tools
elif mode == "chat":
chat = llm_model.start_chat()
request_str += f"chat = llm_model.start_chat()\n"
request_str += "chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
@ -597,7 +610,7 @@ def completion(
)
completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
optional_params.pop(
"stream", None
) # See note above on handling streaming for vertex ai
@ -632,6 +645,12 @@ def completion(
"""
Vertex AI Model Garden
"""
if vertex_project is None or vertex_location is None:
raise ValueError(
"Vertex project and location are required for custom endpoint"
)
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -661,13 +680,17 @@ def completion(
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
response = TextStreamer(completion_response)
return response
elif mode == "private":
"""
Vertex AI Model Garden deployed on private endpoint
"""
if instances is None:
raise ValueError("instances are required for private endpoint")
if llm_model is None:
raise ValueError("Unable to pick client for private endpoint")
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -686,7 +709,7 @@ def completion(
and "\nOutput:\n" in completion_response
):
completion_response = completion_response.split("\nOutput:\n", 1)[1]
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
response = TextStreamer(completion_response)
return response
@ -715,7 +738,7 @@ def completion(
else:
# init prompt tokens
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
prompt_tokens, completion_tokens, _ = 0, 0, 0
if response_obj is not None:
if hasattr(response_obj, "usage_metadata") and hasattr(
response_obj.usage_metadata, "prompt_token_count"
@ -771,11 +794,13 @@ async def async_completion(
try:
import proto # type: ignore
response_obj = None
completion_response = None
if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None)
stream = optional_params.pop("stream", False)
optional_params.pop("stream", False)
content = _gemini_convert_messages_with_history(messages=messages)
@ -817,7 +842,7 @@ async def async_completion(
# Check if it's a RepeatedComposite instance
for key, val in function_call.args.items():
if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
):
# If so, convert to list
args_dict[key] = [v for v in val]
@ -880,6 +905,11 @@ async def async_completion(
"""
from google.cloud import aiplatform # type: ignore
if vertex_project is None or vertex_location is None:
raise ValueError(
"Vertex project and location are required for custom endpoint"
)
## LOGGING
logging_obj.pre_call(
input=prompt,
@ -953,7 +983,7 @@ async def async_completion(
else:
# init prompt tokens
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
prompt_tokens, completion_tokens, _ = 0, 0, 0
if response_obj is not None and (
hasattr(response_obj, "usage_metadata")
and hasattr(response_obj.usage_metadata, "prompt_token_count")
@ -1001,6 +1031,7 @@ async def async_streaming(
"""
Add support for async streaming calls for gemini-pro
"""
response: Any = None
if mode == "vision":
stream = optional_params.pop("stream")
tools = optional_params.pop("tools", None)
@ -1065,6 +1096,11 @@ async def async_streaming(
elif mode == "custom":
from google.cloud import aiplatform # type: ignore
if vertex_project is None or vertex_location is None:
raise ValueError(
"Vertex project and location are required for custom endpoint"
)
stream = optional_params.pop("stream", None)
## LOGGING
@ -1102,6 +1138,8 @@ async def async_streaming(
response = TextStreamer(completion_response)
elif mode == "private":
if instances is None:
raise ValueError("Instances are required for private endpoint")
stream = optional_params.pop("stream", None)
_ = instances[0].pop("stream", None)
request_str += f"llm_model.predict_async(instances={instances})\n"
@ -1118,6 +1156,9 @@ async def async_streaming(
if stream:
response = TextStreamer(completion_response)
if response is None:
raise ValueError("Unable to generate response")
logging_obj.post_call(input=prompt, api_key=None, original_response=response)
streamwrapper = CustomStreamWrapper(
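
The Vertex hunks above add explicit `raise ValueError(...)` checks once `vertex_project`, `vertex_location`, and `instances` are typed as Optional: the early raise narrows the types for the checker and fails fast with a readable message. A minimal sketch of the narrowing; the path appended after the hostname is illustrative:

from typing import Optional

def build_endpoint(vertex_project: Optional[str], vertex_location: Optional[str]) -> str:
    if vertex_project is None or vertex_location is None:
        raise ValueError("Vertex project and location are required for custom endpoint")
    # both names are narrowed to str from here on
    return f"{vertex_location}-aiplatform.googleapis.com/projects/{vertex_project}"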

View file

@ -1,7 +1,7 @@
import json
import os
import types
from typing import Literal, Optional, Union
from typing import Any, Literal, Optional, Union, cast
import httpx
from pydantic import BaseModel
@ -53,7 +53,7 @@ class VertexEmbedding(VertexBase):
gemini_api_key: Optional[str] = None,
extra_headers: Optional[dict] = None,
):
if aembedding == True:
if aembedding is True:
return self.async_embedding(
model=model,
input=input,

View file

@ -45,8 +45,8 @@ def completion(
print_verbose: Callable,
encoding,
logging_obj,
optional_params: dict,
custom_prompt_dict={},
optional_params=None,
litellm_params=None,
logger_fn=None,
):
@ -83,7 +83,7 @@ def completion(
)
## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return iter(outputs)
else:
## LOGGING
@ -144,9 +144,6 @@ def batch_completions(
llm, SamplingParams = validate_environment(model=model)
except Exception as e:
error_str = str(e)
if "data parallel group is already initialized" in error_str:
pass
else:
raise VLLMError(status_code=0, message=error_str)
sampling_params = SamplingParams(**optional_params)
prompts = []

View file

@ -106,6 +106,7 @@ from .llms.prompt_templates.factory import (
custom_prompt,
function_call_prompt,
map_system_message_pt,
ollama_pt,
prompt_factory,
stringify_json_tool_call_content,
)
@ -150,7 +151,6 @@ from .types.utils import (
encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import (
Choices,
CustomStreamWrapper,
EmbeddingResponse,
ImageResponse,
Message,
@ -159,8 +159,6 @@ from litellm.utils import (
TextCompletionResponse,
TextCompletionStreamWrapper,
TranscriptionResponse,
get_secret,
read_config_args,
)
####### ENVIRONMENT VARIABLES ###################
@ -214,7 +212,7 @@ class LiteLLM:
class Chat:
def __init__(self, params, router_obj: Optional[Any]):
self.params = params
if self.params.get("acompletion", False) == True:
if self.params.get("acompletion", False) is True:
self.params.pop("acompletion")
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
self.params, router_obj=router_obj
@ -837,10 +835,10 @@ def completion(
model_response = ModelResponse()
setattr(model_response, "usage", litellm.Usage())
if (
kwargs.get("azure", False) == True
kwargs.get("azure", False) is True
): # don't remove flag check, to remain backwards compatible for repos like Codium
custom_llm_provider = "azure"
if deployment_id != None: # azure llms
if deployment_id is not None: # azure llms
model = deployment_id
custom_llm_provider = "azure"
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
@ -1156,7 +1154,7 @@ def completion(
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
)
if optional_params.get("stream", False) or acompletion == True:
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
input=messages,
@ -1278,7 +1276,7 @@ def completion(
if (
len(messages) > 0
and "content" in messages[0]
and type(messages[0]["content"]) == list
and isinstance(messages[0]["content"], list)
):
# text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
# https://platform.openai.com/docs/api-reference/completions/create
@ -1304,16 +1302,16 @@ def completion(
)
if (
optional_params.get("stream", False) == False
and acompletion == False
and text_completion == False
optional_params.get("stream", False) is False
and acompletion is False
and text_completion is False
):
# convert to chat completion response
_response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
response_object=_response, model_response_object=model_response
)
if optional_params.get("stream", False) or acompletion == True:
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
input=messages,
@ -1519,7 +1517,7 @@ def completion(
acompletion=acompletion,
)
if optional_params.get("stream", False) == True:
if optional_params.get("stream", False) is True:
## LOGGING
logging.post_call(
input=messages,
@ -1566,7 +1564,7 @@ def completion(
custom_prompt_dict=custom_prompt_dict,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
## LOGGING
logging.post_call(
@ -1575,7 +1573,7 @@ def completion(
original_response=model_response,
)
if optional_params.get("stream", False) or acompletion == True:
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
input=messages,
@ -1654,7 +1652,7 @@ def completion(
timeout=timeout,
client=client,
)
if optional_params.get("stream", False) or acompletion == True:
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
input=messages,
@ -1691,7 +1689,7 @@ def completion(
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
response,
@ -1700,7 +1698,7 @@ def completion(
logging_obj=logging,
)
if optional_params.get("stream", False) or acompletion == True:
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
input=messages,
@ -1740,7 +1738,7 @@ def completion(
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
@ -1788,7 +1786,7 @@ def completion(
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
@ -1836,7 +1834,7 @@ def completion(
logging_obj=logging, # model call logging done inside the class as we may need to modify I/O to fit aleph alpha's requirements
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
@ -1875,7 +1873,7 @@ def completion(
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
@ -1916,7 +1914,7 @@ def completion(
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and optional_params["stream"] is True
and acompletion is False
):
# don't try to access stream object,
@ -1943,7 +1941,7 @@ def completion(
encoding=encoding,
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
@ -2095,7 +2093,7 @@ def completion(
logging_obj=logging,
)
# fake palm streaming
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# fake streaming for palm
resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper(
@ -2390,7 +2388,7 @@ def completion(
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
model_response,
@ -2527,7 +2525,7 @@ def completion(
)
if (
"stream" in optional_params
and optional_params["stream"] == True
and optional_params["stream"] is True
and not isinstance(response, CustomStreamWrapper)
):
# don't try to access stream object,
@ -2563,7 +2561,7 @@ def completion(
)
if (
"stream" in optional_params and optional_params["stream"] == True
"stream" in optional_params and optional_params["stream"] is True
): ## [BETA]
# don't try to access stream object,
response = CustomStreamWrapper(
@ -2587,38 +2585,38 @@ def completion(
if model in custom_prompt_dict:
# check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt(
ollama_prompt = custom_prompt(
role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages,
)
else:
prompt = prompt_factory(
model=model,
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if isinstance(prompt, dict):
modified_prompt = ollama_pt(model=model, messages=messages)
if isinstance(modified_prompt, dict):
# for multimode models - ollama/llava prompt_factory returns a dict {
# "prompt": prompt,
# "images": images
# }
prompt, images = prompt["prompt"], prompt["images"]
ollama_prompt, images = (
modified_prompt["prompt"],
modified_prompt["images"],
)
optional_params["images"] = images
else:
ollama_prompt = modified_prompt
## LOGGING
generator = ollama.get_ollama_response(
api_base=api_base,
model=model,
prompt=prompt,
prompt=ollama_prompt,
optional_params=optional_params,
logging_obj=logging,
acompletion=acompletion,
model_response=model_response,
encoding=encoding,
)
if acompletion is True or optional_params.get("stream", False) == True:
if acompletion is True or optional_params.get("stream", False) is True:
return generator
response = generator
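The ollama branch above now calls ollama_pt() directly and has to cope with multimodal models (for example ollama/llava), where the prompt factory returns a dict of prompt plus images instead of a plain string. A rough standalone sketch of that branching; the key names come from the diff, while the helper function and its signature are illustrative only:

from typing import Any, Dict, Union

def split_ollama_prompt(modified_prompt: Union[str, Dict[str, Any]],
                        optional_params: Dict[str, Any]) -> str:
    # Multimodal models return {"prompt": ..., "images": [...]}
    if isinstance(modified_prompt, dict):
        ollama_prompt, images = modified_prompt["prompt"], modified_prompt["images"]
        optional_params["images"] = images  # images travel separately in optional_params
    else:
        ollama_prompt = modified_prompt
    return ollama_prompt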
@ -2701,7 +2699,7 @@ def completion(
api_key=api_key,
logging_obj=logging,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object,
response = CustomStreamWrapper(
response,
@ -2710,7 +2708,7 @@ def completion(
logging_obj=logging,
)
if optional_params.get("stream", False) or acompletion == True:
if optional_params.get("stream", False) or acompletion is True:
## LOGGING
logging.post_call(
input=messages,
@ -2743,7 +2741,7 @@ def completion(
logging_obj=logging,
)
if inspect.isgenerator(model_response) or (
"stream" in optional_params and optional_params["stream"] == True
"stream" in optional_params and optional_params["stream"] is True
):
# don't try to access stream object,
response = CustomStreamWrapper(
@ -2771,7 +2769,7 @@ def completion(
encoding=encoding,
logging_obj=logging,
)
if stream == True: ## [BETA]
if stream is True: ## [BETA]
# Fake streaming for petals
resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper(
@ -2786,7 +2784,7 @@ def completion(
import requests
url = litellm.api_base or api_base or ""
if url == None or url == "":
if url is None or url == "":
raise ValueError(
"api_base not set. Set api_base or litellm.api_base for custom endpoints"
)
@ -3145,10 +3143,10 @@ def batch_completion_models(*args, **kwargs):
try:
result = future.result()
return result
except Exception as e:
except Exception:
# if model 1 fails, continue with response from model 2, model3
print_verbose(
f"\n\ngot an exception, ignoring, removing from futures"
"\n\ngot an exception, ignoring, removing from futures"
)
print_verbose(futures)
new_futures = {}
@ -3189,9 +3187,6 @@ def batch_completion_models_all_responses(*args, **kwargs):
import concurrent.futures
# ANSI escape codes for colored output
GREEN = "\033[92m"
RED = "\033[91m"
RESET = "\033[0m"
if "model" in kwargs:
kwargs.pop("model")
@ -3520,7 +3515,7 @@ def embedding(
if api_base is None:
raise ValueError(
f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
)
## EMBEDDING CALL
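Two smaller ruff rules recur in the surrounding hunks: F541 (an f-string with no placeholders becomes a plain string) and F841-style cleanups, where assignments that are never read, such as tags = kwargs.pop("tags", []) or unused color constants, are reduced to the bare call or removed. A tiny hypothetical illustration of both, not code from this repo:

kwargs = {"tags": ["prod"], "model_info": None}

# F541 - an f-string with no placeholders is just a string:
message = "No API Base provided for Azure OpenAI LLM provider."  # was an f-string

# F841-style cleanup - keep the side effect, drop the never-read name:
kwargs.pop("tags", [])          # was: tags = kwargs.pop("tags", [])
kwargs.get("model_info", None)  # was: model_info = kwargs.get("model_info", None)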
@ -4106,7 +4101,6 @@ def text_completion(
*args,
**kwargs,
):
global print_verbose
import copy
"""
@ -4136,7 +4130,7 @@ def text_completion(
Your example of how to use this function goes here.
"""
if "engine" in kwargs:
if model == None:
if model is None:
# only use engine when model not passed
model = kwargs["engine"]
kwargs.pop("engine")
@ -4189,18 +4183,18 @@ def text_completion(
if custom_llm_provider == "huggingface":
# if echo == True, for TGI llms we need to set top_n_tokens to 3
if echo == True:
if echo is True:
# for tgi llms
if "top_n_tokens" not in kwargs:
kwargs["top_n_tokens"] = 3
# processing prompt - users can pass raw tokens to OpenAI Completion()
if type(prompt) == list:
if isinstance(prompt, list):
import concurrent.futures
tokenizer = tiktoken.encoding_for_model("text-davinci-003")
## if it's a 2d list - each element in the list is a text_completion() request
if len(prompt) > 0 and type(prompt[0]) == list:
if len(prompt) > 0 and isinstance(prompt[0], list):
responses = [None for x in prompt] # init responses
def process_prompt(i, individual_prompt):
@ -4299,7 +4293,7 @@ def text_completion(
raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
except Exception as e:
print_verbose(f"LiteLLM non blocking exception: {e}")
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
if isinstance(response, TextCompletionResponse):
return response
@ -4813,12 +4807,12 @@ def transcription(
Allows router to load balance between them
"""
atranscription = kwargs.get("atranscription", False)
litellm_call_id = kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
tags = kwargs.pop("tags", [])
kwargs.get("litellm_call_id", None)
kwargs.get("logger_fn", None)
kwargs.get("proxy_server_request", None)
kwargs.get("model_info", None)
kwargs.get("metadata", {})
kwargs.pop("tags", [])
drop_params = kwargs.get("drop_params", None)
client: Optional[
@ -4996,7 +4990,7 @@ def speech(
model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {})
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
tags = kwargs.pop("tags", [])
kwargs.pop("tags", [])
optional_params = {}
if response_format is not None:
@ -5345,12 +5339,12 @@ def print_verbose(print_statement):
verbose_logger.debug(print_statement)
if litellm.set_verbose:
print(print_statement) # noqa
except:
except Exception:
pass
def config_completion(**kwargs):
if litellm.config_path != None:
if litellm.config_path is not None:
config_args = read_config_args(litellm.config_path)
# overwrite any args passed in with config args
return completion(**kwargs, **config_args)
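print_verbose above is one of many spots where a bare except: is tightened to except Exception: (ruff E722), and where except Exception as e: loses the as e when the variable is never used. A minimal sketch of the pattern with an assumed helper name:

def safe_print(statement: str) -> None:
    try:
        print(statement)  # noqa
    except Exception:  # was a bare `except:`, which also swallows KeyboardInterrupt/SystemExit
        pass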
@ -5408,16 +5402,18 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
response["choices"][0]["text"] = combined_content
if len(combined_content) > 0:
completion_output = combined_content
pass
else:
completion_output = ""
pass
# # Update usage information if needed
try:
response["usage"]["prompt_tokens"] = token_counter(
model=model, messages=messages
)
except: # don't allow this failing to block a complete streaming response from being returned
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
except (
Exception
): # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = token_counter(
model=model,

View file

@ -128,14 +128,14 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs): # type: ignore
try:
return self.model_dump(**kwargs) # noqa
except Exception as e:
except Exception:
# if using pydantic v1
return self.dict(**kwargs)
def fields_set(self):
try:
return self.model_fields_set # noqa
except:
except Exception:
# if using pydantic v1
return self.__fields_set__
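The LiteLLMBase helpers above keep working on both pydantic v2 and v1 by trying the v2 method first and falling back. A self-contained sketch of the same compatibility idea; the class and field names are placeholders, not part of litellm:

from pydantic import BaseModel

class ExampleModel(BaseModel):
    name: str = "litellm"

    def as_dict(self) -> dict:
        try:
            return self.model_dump()   # pydantic v2
        except Exception:
            return self.dict()         # pydantic v1 fallback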

View file

@ -1,202 +0,0 @@
def add_new_model():
import streamlit as st
import json, requests, uuid
model_name = st.text_input(
"Model Name - user-facing model name", placeholder="gpt-3.5-turbo"
)
st.subheader("LiteLLM Params")
litellm_model_name = st.text_input(
"Model", placeholder="azure/gpt-35-turbo-us-east"
)
litellm_api_key = st.text_input("API Key")
litellm_api_base = st.text_input(
"API Base",
placeholder="https://my-endpoint.openai.azure.com",
)
litellm_api_version = st.text_input("API Version", placeholder="2023-07-01-preview")
litellm_params = json.loads(
st.text_area(
"Additional Litellm Params (JSON dictionary). [See all possible inputs](https://github.com/BerriAI/litellm/blob/3f15d7230fe8e7492c95a752963e7fbdcaf7bf98/litellm/main.py#L293)",
value={},
)
)
st.subheader("Model Info")
mode_options = ("completion", "embedding", "image generation")
mode_selected = st.selectbox("Mode", mode_options)
model_info = json.loads(
st.text_area(
"Additional Model Info (JSON dictionary)",
value={},
)
)
if st.button("Submit"):
try:
model_info = {
"model_name": model_name,
"litellm_params": {
"model": litellm_model_name,
"api_key": litellm_api_key,
"api_base": litellm_api_base,
"api_version": litellm_api_version,
},
"model_info": {
"id": str(uuid.uuid4()),
"mode": mode_selected,
},
}
# Make the POST request to the specified URL
complete_url = ""
if st.session_state["api_url"].endswith("/"):
complete_url = f"{st.session_state['api_url']}model/new"
else:
complete_url = f"{st.session_state['api_url']}/model/new"
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
response = requests.post(complete_url, json=model_info, headers=headers)
if response.status_code == 200:
st.success("Model added successfully!")
else:
st.error(f"Failed to add model. Status code: {response.status_code}")
st.success("Form submitted successfully!")
except Exception as e:
raise e
def list_models():
import streamlit as st
import requests
# Check if the necessary configuration is available
if (
st.session_state.get("api_url", None) is not None
and st.session_state.get("proxy_key", None) is not None
):
# Make the GET request
try:
complete_url = ""
if isinstance(st.session_state["api_url"], str) and st.session_state[
"api_url"
].endswith("/"):
complete_url = f"{st.session_state['api_url']}models"
else:
complete_url = f"{st.session_state['api_url']}/models"
response = requests.get(
complete_url,
headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
)
# Check if the request was successful
if response.status_code == 200:
models = response.json()
st.write(models) # or st.json(models) to pretty print the JSON
else:
st.error(f"Failed to get models. Status code: {response.status_code}")
except Exception as e:
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def create_key():
import streamlit as st
import json, requests, uuid
if (
st.session_state.get("api_url", None) is not None
and st.session_state.get("proxy_key", None) is not None
):
duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
models = st.text_input("Models it can access (separated by comma)", value="")
models = models.split(",") if models else []
additional_params = json.loads(
st.text_area(
"Additional Key Params (JSON dictionary). [See all possible inputs](https://litellm-api.up.railway.app/#/key%20management/generate_key_fn_key_generate_post)",
value={},
)
)
if st.button("Submit"):
try:
key_post_body = {
"duration": duration,
"models": models,
**additional_params,
}
# Make the POST request to the specified URL
complete_url = ""
if st.session_state["api_url"].endswith("/"):
complete_url = f"{st.session_state['api_url']}key/generate"
else:
complete_url = f"{st.session_state['api_url']}/key/generate"
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
response = requests.post(
complete_url, json=key_post_body, headers=headers
)
if response.status_code == 200:
st.success(f"Key added successfully! - {response.json()}")
else:
st.error(f"Failed to add Key. Status code: {response.status_code}")
st.success("Form submitted successfully!")
except Exception as e:
raise e
else:
st.warning(
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def streamlit_ui():
import streamlit as st
st.header("Admin Configuration")
# Add a navigation sidebar
st.sidebar.title("Navigation")
page = st.sidebar.radio(
"Go to", ("Proxy Setup", "Add Models", "List Models", "Create Key")
)
# Initialize session state variables if not already present
if "api_url" not in st.session_state:
st.session_state["api_url"] = None
if "proxy_key" not in st.session_state:
st.session_state["proxy_key"] = None
# Display different pages based on navigation selection
if page == "Proxy Setup":
# Use text inputs with intermediary variables
input_api_url = st.text_input(
"Proxy Endpoint",
value=st.session_state.get("api_url", ""),
placeholder="http://0.0.0.0:8000",
)
input_proxy_key = st.text_input(
"Proxy Key",
value=st.session_state.get("proxy_key", ""),
placeholder="sk-...",
)
# When the "Save" button is clicked, update the session state
if st.button("Save"):
st.session_state["api_url"] = input_api_url
st.session_state["proxy_key"] = input_proxy_key
st.success("Configuration saved!")
elif page == "Add Models":
add_new_model()
elif page == "List Models":
list_models()
elif page == "Create Key":
create_key()
if __name__ == "__main__":
streamlit_ui()

View file

@ -69,7 +69,7 @@ async def get_global_activity(
try:
if prisma_client is None:
raise ValueError(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
)
sql_query = """

View file

@ -132,7 +132,7 @@ def common_checks(
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
if (
general_settings.get("enforce_user_param", None) is not None
and general_settings["enforce_user_param"] == True
and general_settings["enforce_user_param"] is True
):
if is_llm_api_route(route=route) and "user" not in request_body:
raise Exception(
@ -557,7 +557,7 @@ async def get_team_object(
)
return _response
except Exception as e:
except Exception:
raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
)
@ -664,7 +664,7 @@ async def get_org_object(
raise Exception
return response
except Exception as e:
except Exception:
raise Exception(
f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
)

View file

@ -98,7 +98,7 @@ class LicenseCheck:
elif self._verify(license_str=self.license_str) is True:
return True
return False
except Exception as e:
except Exception:
return False
def verify_license_without_api_request(self, public_key, license_key):

View file

@ -112,7 +112,6 @@ async def user_api_key_auth(
),
) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import (
custom_db_client,
general_settings,
jwt_handler,
litellm_proxy_admin_name,
@ -476,7 +475,7 @@ async def user_api_key_auth(
)
if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True:
if general_settings.get("allow_user_auth", False) is True:
return UserAPIKeyAuth()
else:
raise HTTPException(
@ -597,7 +596,7 @@ async def user_api_key_auth(
## VALIDATE MASTER KEY ##
try:
assert isinstance(master_key, str)
except Exception as e:
except Exception:
raise HTTPException(
status_code=500,
detail={
@ -648,7 +647,7 @@ async def user_api_key_auth(
)
if (
prisma_client is None and custom_db_client is None
prisma_client is None
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.")
@ -722,9 +721,9 @@ async def user_api_key_auth(
if config != {}:
model_list = config.get("model_list", [])
llm_model_list = model_list
new_model_list = model_list
verbose_proxy_logger.debug(
f"\n new llm router model list {llm_model_list}"
f"\n new llm router model list {new_model_list}"
)
if (
len(valid_token.models) == 0

View file

@ -2,6 +2,7 @@ import sys
from typing import Any, Dict, List, Optional, get_args
import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.utils import get_instance_fn
@ -59,9 +60,15 @@ def initialize_callbacks_on_proxy(
presidio_logging_only
) # validate boolean given
params = {
_presidio_params = {}
if "presidio" in callback_specific_params and isinstance(
callback_specific_params["presidio"], dict
):
_presidio_params = callback_specific_params["presidio"]
params: Dict[str, Any] = {
"logging_only": presidio_logging_only,
**callback_specific_params.get("presidio", {}),
**_presidio_params,
}
pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
imported_list.append(pii_masking_object)
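The presidio branch above now pulls its callback-specific settings out defensively before unpacking them, and gives the merged dict an explicit Dict[str, Any] annotation so the type checker is satisfied. A rough sketch of that shape with a hypothetical helper standing in for the real guardrail wiring:

from typing import Any, Dict

def build_presidio_params(callback_specific_params: Dict[str, Any],
                          logging_only: bool) -> Dict[str, Any]:
    _presidio_params: Dict[str, Any] = {}
    # Only trust the nested config if it is present *and* actually a dict.
    if "presidio" in callback_specific_params and isinstance(
        callback_specific_params["presidio"], dict
    ):
        _presidio_params = callback_specific_params["presidio"]

    params: Dict[str, Any] = {"logging_only": logging_only, **_presidio_params}
    return params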
@ -70,7 +77,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_LlamaGuard,
)
if premium_user != True:
if premium_user is not True:
raise Exception(
"Trying to use Llama Guard"
+ CommonProxyErrors.not_premium_user.value
@ -83,7 +90,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_SecretDetection,
)
if premium_user != True:
if premium_user is not True:
raise Exception(
"Trying to use secret hiding"
+ CommonProxyErrors.not_premium_user.value
@ -96,7 +103,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_OpenAI_Moderation,
)
if premium_user != True:
if premium_user is not True:
raise Exception(
"Trying to use OpenAI Moderations Check"
+ CommonProxyErrors.not_premium_user.value
@ -126,7 +133,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_GoogleTextModeration,
)
if premium_user != True:
if premium_user is not True:
raise Exception(
"Trying to use Google Text Moderation"
+ CommonProxyErrors.not_premium_user.value
@ -137,7 +144,7 @@ def initialize_callbacks_on_proxy(
elif isinstance(callback, str) and callback == "llmguard_moderations":
from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard
if premium_user != True:
if premium_user is not True:
raise Exception(
"Trying to use Llm Guard"
+ CommonProxyErrors.not_premium_user.value
@ -150,7 +157,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_BlockedUserList,
)
if premium_user != True:
if premium_user is not True:
raise Exception(
"Trying to use ENTERPRISE BlockedUser"
+ CommonProxyErrors.not_premium_user.value
@ -165,7 +172,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_BannedKeywords,
)
if premium_user != True:
if premium_user is not True:
raise Exception(
"Trying to use ENTERPRISE BannedKeyword"
+ CommonProxyErrors.not_premium_user.value
@ -212,7 +219,7 @@ def initialize_callbacks_on_proxy(
and isinstance(v, str)
and v.startswith("os.environ/")
):
azure_content_safety_params[k] = litellm.get_secret(v)
azure_content_safety_params[k] = get_secret(v)
azure_content_safety_obj = _PROXY_AzureContentSafety(
**azure_content_safety_params,

View file

@ -6,13 +6,14 @@ import tracemalloc
from fastapi import APIRouter
import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_proxy_logger
router = APIRouter()
if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
try:
import objgraph
import objgraph # type: ignore
print("growth of objects") # noqa
objgraph.show_growth()
@ -21,8 +22,10 @@ if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
roots = objgraph.get_leaking_objects()
print("\n\nLeaking objects") # noqa
objgraph.show_most_common_types(objects=roots)
except:
pass
except ImportError:
raise ImportError(
"objgraph not found. Please install objgraph to use this feature."
)
tracemalloc.start(10)
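The profiling block above no longer hides a missing objgraph behind except: pass; it raises an explicit ImportError instead. A generic sketch of that optional-dependency guard (the lazy-import helper is an illustration, not how the proxy actually loads it):

def require_objgraph():
    try:
        import objgraph  # type: ignore  # optional profiling dependency
    except ImportError as err:
        # The diff simply raises ImportError; chaining with `from err` is this sketch's addition.
        raise ImportError(
            "objgraph not found. Please install objgraph to use this feature."
        ) from err
    return objgraph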
@ -57,15 +60,20 @@ async def memory_usage_in_mem_cache():
user_api_key_cache,
)
num_items_in_user_api_key_cache = len(
user_api_key_cache.in_memory_cache.cache_dict
) + len(user_api_key_cache.in_memory_cache.ttl_dict)
if llm_router is None:
num_items_in_llm_router_cache = 0
else:
num_items_in_llm_router_cache = len(
llm_router.cache.in_memory_cache.cache_dict
) + len(llm_router.cache.in_memory_cache.ttl_dict)
num_items_in_user_api_key_cache = len(
user_api_key_cache.in_memory_cache.cache_dict
) + len(user_api_key_cache.in_memory_cache.ttl_dict)
num_items_in_proxy_logging_obj_cache = len(
proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict
) + len(proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict)
proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict)
return {
"num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
@ -89,13 +97,20 @@ async def memory_usage_in_mem_cache_items():
user_api_key_cache,
)
if llm_router is None:
llm_router_in_memory_cache_dict = {}
llm_router_in_memory_ttl_dict = {}
else:
llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict
llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict
return {
"user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
"user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
"llm_router_cache": llm_router.cache.in_memory_cache.cache_dict,
"llm_router_ttl": llm_router.cache.in_memory_cache.ttl_dict,
"proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict,
"proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict,
"llm_router_cache": llm_router_in_memory_cache_dict,
"llm_router_ttl": llm_router_in_memory_ttl_dict,
"proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
"proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict,
}
@ -104,9 +119,18 @@ async def get_otel_spans():
from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.proxy.proxy_server import open_telemetry_logger
open_telemetry_logger: OpenTelemetry = open_telemetry_logger
if open_telemetry_logger is None:
return {
"otel_spans": [],
"spans_grouped_by_parent": {},
"most_recent_parent": None,
}
otel_exporter = open_telemetry_logger.OTEL_EXPORTER
recorded_spans = otel_exporter.get_finished_spans()
if hasattr(otel_exporter, "get_finished_spans"):
recorded_spans = otel_exporter.get_finished_spans() # type: ignore
else:
recorded_spans = []
print("Spans: ", recorded_spans) # noqa
@ -137,11 +161,13 @@ async def get_otel_spans():
# Helper functions for debugging
def init_verbose_loggers():
try:
worker_config = litellm.get_secret("WORKER_CONFIG")
worker_config = get_secret_str("WORKER_CONFIG")
# if not, assume it's a json string
if worker_config is None:
return
if os.path.isfile(worker_config):
return
# if not, assume it's a json string
_settings = json.loads(os.getenv("WORKER_CONFIG"))
_settings = json.loads(worker_config)
if not isinstance(_settings, dict):
return
@ -162,7 +188,7 @@ def init_verbose_loggers():
level=logging.INFO
) # set router logs to info
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
if detailed_debug == True:
if detailed_debug is True:
import logging
from litellm._logging import (
@ -178,10 +204,10 @@ def init_verbose_loggers():
verbose_proxy_logger.setLevel(
level=logging.DEBUG
) # set proxy logs to debug
elif debug == False and detailed_debug == False:
elif debug is False and detailed_debug is False:
# users can control proxy debugging using env variable = 'LITELLM_LOG'
litellm_log_setting = os.environ.get("LITELLM_LOG", "")
if litellm_log_setting != None:
if litellm_log_setting is not None:
if litellm_log_setting.upper() == "INFO":
import logging
@ -213,4 +239,6 @@ def init_verbose_loggers():
level=logging.DEBUG
) # set proxy logs to debug
except Exception as e:
verbose_logger.warning(f"Failed to init verbose loggers: {str(e)}")
import logging
logging.warning(f"Failed to init verbose loggers: {str(e)}")
Some files were not shown because too many files have changed in this diff.