Litellm ruff linting enforcement (#5992)

* ci(config.yml): add a 'check_code_quality' step

Addresses https://github.com/BerriAI/litellm/issues/5991

* ci(config.yml): check why CircleCI doesn't pick up this test

* ci(config.yml): fix to run 'check_code_quality' tests

* fix(__init__.py): fix unprotected import

* fix(__init__.py): don't remove unused imports

* build(ruff.toml): update ruff.toml to ignore unused imports

* fix: ruff + pyright - fix linting + type-checking errors

* fix: fix linting errors

* fix(lago.py): fix module init error

* fix: fix linting errors

* ci(config.yml): cd into correct dir for checks

* fix(proxy_server.py): fix linting error

* fix(utils.py): fix bare except, which causes ruff linting errors

* fix: ruff - fix remaining linting errors

* fix(clickhouse.py): use standard logging object

* fix(__init__.py): fix unprotected import

* fix: ruff - fix linting errors

* fix: fix linting errors

* ci(config.yml): cleanup code qa step (formatting handled in local_testing)

* fix(_health_endpoints.py): fix ruff linting errors

* ci(config.yml): just use ruff in check_code_quality pipeline for now

* build(custom_guardrail.py): include missing file

* style(embedding_handler.py): fix ruff check
Krish Dholakia 2024-10-01 16:44:20 -07:00 committed by GitHub
parent 3fc4ae0d65
commit d57be47b0f
263 changed files with 1687 additions and 3320 deletions
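
The new check_code_quality job added to config.yml below probes for unprotected imports by running python -c "from litellm import *", which fails the build if any module-level import of an optional dependency is missing. Below is a minimal sketch of the guarded-import pattern the "fix unprotected import" commits refer to; the prometheus_client dependency and the start_metrics_server helper are illustrative assumptions, not code from this diff.

# Hypothetical sketch: guard an optional dependency so that
# `from litellm import *` never raises at import time.
try:
    import prometheus_client  # optional; only needed when metrics are enabled
except ImportError:
    prometheus_client = None

def start_metrics_server(port: int = 9090) -> None:
    # Fail at call time with an actionable message instead of at import time.
    if prometheus_client is None:
        raise RuntimeError(
            "prometheus_client is not installed. Run `pip install prometheus_client`."
        )
    prometheus_client.start_http_server(port)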


@@ -299,6 +299,27 @@ jobs:
 ls
 python -m pytest -vv tests/local_testing/test_python_38.py
+check_code_quality:
+docker:
+- image: cimg/python:3.11
+auth:
+username: ${DOCKERHUB_USERNAME}
+password: ${DOCKERHUB_PASSWORD}
+working_directory: ~/project/litellm
+steps:
+- checkout
+- run:
+name: Install Dependencies
+command: |
+python -m pip install --upgrade pip
+pip install ruff
+pip install pylint
+pip install .
+- run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
+- run: ruff check ./litellm
 build_and_test:
 machine:
 image: ubuntu-2204:2023.10.1
@@ -806,6 +827,12 @@ workflows:
 only:
 - main
 - /litellm_.*/
+- check_code_quality:
+filters:
+branches:
+only:
+- main
+- /litellm_.*/
 - ui_endpoint_testing:
 filters:
 branches:
@@ -867,6 +894,7 @@ workflows:
 - installing_litellm_on_python
 - proxy_logging_guardrails_model_info_tests
 - proxy_pass_through_endpoint_tests
+- check_code_quality
 filters:
 branches:
 only:


@@ -630,18 +630,6 @@ general_settings:
 "database_url": "string",
 "database_connection_pool_limit": 0, # default 100
 "database_connection_timeout": 0, # default 60s
-"database_type": "dynamo_db",
-"database_args": {
-"billing_mode": "PROVISIONED_THROUGHPUT",
-"read_capacity_units": 0,
-"write_capacity_units": 0,
-"ssl_verify": true,
-"region_name": "string",
-"user_table_name": "LiteLLM_UserTable",
-"key_table_name": "LiteLLM_VerificationToken",
-"config_table_name": "LiteLLM_Config",
-"spend_table_name": "LiteLLM_SpendLogs"
-},
 "otel": true,
 "custom_auth": "string",
 "max_parallel_requests": 0, # the max parallel requests allowed per deployment


@@ -97,7 +97,7 @@ class GenericAPILogger:
 for key, value in payload.items():
 try:
 payload[key] = str(value)
-except:
+except Exception:
 # non blocking if it can't cast to a str
 pass
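
Most of the Python changes in this commit, like the hunk above, replace bare except clauses (ruff rule E722) with except Exception, which keeps the non-blocking fallback but no longer swallows KeyboardInterrupt or SystemExit. A minimal before/after sketch follows; the to_str helper is illustrative, not taken from the diff.

def to_str(value) -> str:
    try:
        return str(value)
    # was: `except:` (ruff E722; a bare except also catches KeyboardInterrupt/SystemExit)
    except Exception:
        # non blocking if it can't cast to a str
        return ""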


@@ -49,7 +49,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
 def __init__(self):
 try:
 from google.cloud import language_v1
-except:
+except Exception:
 raise Exception(
 "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
 )
@@ -90,7 +90,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
 verbose_proxy_logger.debug(print_statement)
 if litellm.set_verbose:
 print(print_statement) # noqa
-except:
+except Exception:
 pass
 async def async_moderation_hook(


@@ -58,7 +58,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
 verbose_proxy_logger.debug(print_statement)
 if litellm.set_verbose:
 print(print_statement) # noqa
-except:
+except Exception:
 pass
 def set_custom_prompt_template(self, messages: list):


@@ -49,7 +49,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
 verbose_proxy_logger.debug(print_statement)
 if litellm.set_verbose:
 print(print_statement) # noqa
-except:
+except Exception:
 pass
 async def moderation_check(self, text: str):


@@ -3,7 +3,8 @@ import warnings
 warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
 ### INIT VARIABLES ###
-import threading, requests, os
+import threading
+import os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.caching import Cache
@@ -308,13 +309,13 @@ def get_model_cost_map(url: str):
 return content
 try:
-with requests.get(
+response = httpx.get(
 url, timeout=5
-) as response: # set a 5 second timeout for the get request
+) # set a 5 second timeout for the get request
 response.raise_for_status() # Raise an exception if the request is unsuccessful
 content = response.json()
 return content
-except Exception as e:
+except Exception:
 import importlib.resources
 import json
@@ -839,7 +840,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
 from .timeout import timeout
 from .cost_calculator import completion_cost
-from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.litellm_core_utils.litellm_logging import Logging, modify_integration
 from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.core_helpers import remove_index_from_tool_calls
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
@@ -848,7 +849,6 @@ from .utils import (
 exception_type,
 get_optional_params,
 get_response_string,
-modify_integration,
 token_counter,
 create_pretrained_tokenizer,
 create_tokenizer,


@@ -98,5 +98,5 @@ def print_verbose(print_statement):
 try:
 if set_verbose:
 print(print_statement) # noqa
-except:
+except Exception:
 pass


@@ -2,5 +2,5 @@ import importlib_metadata
 try:
 version = importlib_metadata.version("litellm")
-except:
+except Exception:
 pass


@ -12,7 +12,6 @@ from openai.types.beta.assistant import Assistant
from openai.types.beta.assistant_deleted import AssistantDeleted from openai.types.beta.assistant_deleted import AssistantDeleted
import litellm import litellm
from litellm import client
from litellm.llms.AzureOpenAI import assistants from litellm.llms.AzureOpenAI import assistants
from litellm.types.router import GenericLiteLLMParams from litellm.types.router import GenericLiteLLMParams
from litellm.utils import ( from litellm.utils import (
@ -96,7 +95,7 @@ def get_assistants(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -280,7 +279,7 @@ def create_assistants(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -464,7 +463,7 @@ def delete_assistant(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -649,7 +648,7 @@ def create_thread(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -805,7 +804,7 @@ def get_thread(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -991,7 +990,7 @@ def add_message(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -1149,7 +1148,7 @@ def get_messages(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -1347,7 +1346,7 @@ def run_thread(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
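
The repeated changes of supports_httpx_timeout(custom_llm_provider) == False to ... is False in the hunks above address ruff rule E712 (and the later == None to is None changes address E711): True, False and None are singletons, so identity checks are preferred. A small sketch under that convention follows; normalize_timeout is an illustrative helper, not a function from this diff.

from typing import Optional, Union

import httpx

def normalize_timeout(
    timeout: Optional[Union[float, httpx.Timeout]],
    provider_supports_httpx_timeout: bool,
):
    if timeout is None:  # E711: identity check, not `timeout == None`
        return None
    if isinstance(timeout, httpx.Timeout):
        # E712: identity check, not `provider_supports_httpx_timeout == False`
        if provider_supports_httpx_timeout is False:
            return timeout.read or 600  # fall back to the 10 minute default read timeout
        return timeout
    return float(timeout)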


@ -22,7 +22,7 @@ import litellm
from litellm import client from litellm import client
from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
from litellm.llms.OpenAI.openai import OpenAIBatchesAPI from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret, get_secret_str
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
Batch, Batch,
CancelBatchRequest, CancelBatchRequest,
@ -131,7 +131,7 @@ def create_batch(
extra_headers=extra_headers, extra_headers=extra_headers,
extra_body=extra_body, extra_body=extra_body,
) )
api_base: Optional[str] = None
if custom_llm_provider == "openai": if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@ -165,27 +165,30 @@ def create_batch(
_is_async=_is_async, _is_async=_is_async,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore )
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore )
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.create_batch( response = azure_batches_instance.create_batch(
_is_async=_is_async, _is_async=_is_async,
@ -293,7 +296,7 @@ def retrieve_batch(
) )
_is_async = kwargs.pop("aretrieve_batch", False) is True _is_async = kwargs.pop("aretrieve_batch", False) is True
api_base: Optional[str] = None
if custom_llm_provider == "openai": if custom_llm_provider == "openai":
# for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@ -327,27 +330,30 @@ def retrieve_batch(
max_retries=optional_params.max_retries, max_retries=optional_params.max_retries,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = (
optional_params.api_base
or litellm.api_base
or get_secret_str("AZURE_API_BASE")
)
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore )
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore )
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.retrieve_batch( response = azure_batches_instance.retrieve_batch(
_is_async=_is_async, _is_async=_is_async,
@ -384,7 +390,7 @@ async def alist_batches(
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
) -> Batch: ):
""" """
Async: List your organization's batches. Async: List your organization's batches.
""" """
@ -482,27 +488,26 @@ def list_batches(
max_retries=optional_params.max_retries, max_retries=optional_params.max_retries,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore )
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore )
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_batches_instance.list_batches( response = azure_batches_instance.list_batches(
_is_async=_is_async, _is_async=_is_async,


@ -7,11 +7,16 @@
# #
# Thank you users! We ❤️ you! - Krrish & Ishaan # Thank you users! We ❤️ you! - Krrish & Ishaan
import os, json, time import json
import os
import threading
import time
from typing import Literal, Optional, Union
import requests # type: ignore
import litellm import litellm
from litellm.utils import ModelResponse from litellm.utils import ModelResponse
import requests, threading # type: ignore
from typing import Optional, Union, Literal
class BudgetManager: class BudgetManager:
@ -35,7 +40,7 @@ class BudgetManager:
import logging import logging
logging.info(print_statement) logging.info(print_statement)
except: except Exception:
pass pass
def load_data(self): def load_data(self):
@ -52,7 +57,6 @@ class BudgetManager:
elif self.client_type == "hosted": elif self.client_type == "hosted":
# Load the user_dict from hosted db # Load the user_dict from hosted db
url = self.api_base + "/get_budget" url = self.api_base + "/get_budget"
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name} data = {"project_name": self.project_name}
response = requests.post(url, headers=self.headers, json=data) response = requests.post(url, headers=self.headers, json=data)
response = response.json() response = response.json()
@ -210,7 +214,6 @@ class BudgetManager:
return {"status": "success"} return {"status": "success"}
elif self.client_type == "hosted": elif self.client_type == "hosted":
url = self.api_base + "/set_budget" url = self.api_base + "/set_budget"
headers = {"Content-Type": "application/json"}
data = {"project_name": self.project_name, "user_dict": self.user_dict} data = {"project_name": self.project_name, "user_dict": self.user_dict}
response = requests.post(url, headers=self.headers, json=data) response = requests.post(url, headers=self.headers, json=data)
response = response.json() response = response.json()


@ -33,7 +33,7 @@ def print_verbose(print_statement):
verbose_logger.debug(print_statement) verbose_logger.debug(print_statement)
if litellm.set_verbose: if litellm.set_verbose:
print(print_statement) # noqa print(print_statement) # noqa
except: except Exception:
pass pass
@ -96,15 +96,13 @@ class InMemoryCache(BaseCache):
""" """
for key in list(self.ttl_dict.keys()): for key in list(self.ttl_dict.keys()):
if time.time() > self.ttl_dict[key]: if time.time() > self.ttl_dict[key]:
removed_item = self.cache_dict.pop(key, None) self.cache_dict.pop(key, None)
removed_ttl_item = self.ttl_dict.pop(key, None) self.ttl_dict.pop(key, None)
# de-reference the removed item # de-reference the removed item
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/ # https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used. # One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
# This can occur when an object is referenced by another object, but the reference is never removed. # This can occur when an object is referenced by another object, but the reference is never removed.
removed_item = None
removed_ttl_item = None
def set_cache(self, key, value, **kwargs): def set_cache(self, key, value, **kwargs):
print_verbose( print_verbose(
@ -150,7 +148,7 @@ class InMemoryCache(BaseCache):
original_cached_response = self.cache_dict[key] original_cached_response = self.cache_dict[key]
try: try:
cached_response = json.loads(original_cached_response) cached_response = json.loads(original_cached_response)
except: except Exception:
cached_response = original_cached_response cached_response = original_cached_response
return cached_response return cached_response
return None return None
@ -251,7 +249,7 @@ class RedisCache(BaseCache):
self.redis_version = "Unknown" self.redis_version = "Unknown"
try: try:
self.redis_version = self.redis_client.info()["redis_version"] self.redis_version = self.redis_client.info()["redis_version"]
except Exception as e: except Exception:
pass pass
### ASYNC HEALTH PING ### ### ASYNC HEALTH PING ###
@ -688,7 +686,7 @@ class RedisCache(BaseCache):
cached_response = json.loads( cached_response = json.loads(
cached_response cached_response
) # Convert string to dictionary ) # Convert string to dictionary
except: except Exception:
cached_response = ast.literal_eval(cached_response) cached_response = ast.literal_eval(cached_response)
return cached_response return cached_response
@ -844,7 +842,7 @@ class RedisCache(BaseCache):
""" """
Tests if the sync redis client is correctly setup. Tests if the sync redis client is correctly setup.
""" """
print_verbose(f"Pinging Sync Redis Cache") print_verbose("Pinging Sync Redis Cache")
start_time = time.time() start_time = time.time()
try: try:
response = self.redis_client.ping() response = self.redis_client.ping()
@ -878,7 +876,7 @@ class RedisCache(BaseCache):
_redis_client = self.init_async_client() _redis_client = self.init_async_client()
start_time = time.time() start_time = time.time()
async with _redis_client as redis_client: async with _redis_client as redis_client:
print_verbose(f"Pinging Async Redis Cache") print_verbose("Pinging Async Redis Cache")
try: try:
response = await redis_client.ping() response = await redis_client.ping()
## LOGGING ## ## LOGGING ##
@ -973,7 +971,6 @@ class RedisSemanticCache(BaseCache):
}, },
"fields": { "fields": {
"text": [{"name": "response"}], "text": [{"name": "response"}],
"text": [{"name": "prompt"}],
"vector": [ "vector": [
{ {
"name": "litellm_embedding", "name": "litellm_embedding",
@ -999,14 +996,14 @@ class RedisSemanticCache(BaseCache):
redis_url = "redis://:" + password + "@" + host + ":" + port redis_url = "redis://:" + password + "@" + host + ":" + port
print_verbose(f"redis semantic-cache redis_url: {redis_url}") print_verbose(f"redis semantic-cache redis_url: {redis_url}")
if use_async == False: if use_async is False:
self.index = SearchIndex.from_dict(schema) self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url) self.index.connect(redis_url=redis_url)
try: try:
self.index.create(overwrite=False) # don't overwrite existing index self.index.create(overwrite=False) # don't overwrite existing index
except Exception as e: except Exception as e:
print_verbose(f"Got exception creating semantic cache index: {str(e)}") print_verbose(f"Got exception creating semantic cache index: {str(e)}")
elif use_async == True: elif use_async is True:
schema["index"]["name"] = "litellm_semantic_cache_index_async" schema["index"]["name"] = "litellm_semantic_cache_index_async"
self.index = SearchIndex.from_dict(schema) self.index = SearchIndex.from_dict(schema)
self.index.connect(redis_url=redis_url, use_async=True) self.index.connect(redis_url=redis_url, use_async=True)
@ -1027,7 +1024,7 @@ class RedisSemanticCache(BaseCache):
cached_response = json.loads( cached_response = json.loads(
cached_response cached_response
) # Convert string to dictionary ) # Convert string to dictionary
except: except Exception:
cached_response = ast.literal_eval(cached_response) cached_response = ast.literal_eval(cached_response)
return cached_response return cached_response
@ -1060,7 +1057,7 @@ class RedisSemanticCache(BaseCache):
] ]
# Add more data # Add more data
keys = self.index.load(new_data) self.index.load(new_data)
return return
@ -1092,7 +1089,7 @@ class RedisSemanticCache(BaseCache):
) )
results = self.index.query(query) results = self.index.query(query)
if results == None: if results is None:
return None return None
if isinstance(results, list): if isinstance(results, list):
if len(results) == 0: if len(results) == 0:
@ -1173,7 +1170,7 @@ class RedisSemanticCache(BaseCache):
] ]
# Add more data # Add more data
keys = await self.index.aload(new_data) await self.index.aload(new_data)
return return
async def async_get_cache(self, key, **kwargs): async def async_get_cache(self, key, **kwargs):
@ -1222,7 +1219,7 @@ class RedisSemanticCache(BaseCache):
return_fields=["response", "prompt", "vector_distance"], return_fields=["response", "prompt", "vector_distance"],
) )
results = await self.index.aquery(query) results = await self.index.aquery(query)
if results == None: if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None return None
if isinstance(results, list): if isinstance(results, list):
@ -1396,7 +1393,7 @@ class QdrantSemanticCache(BaseCache):
cached_response = json.loads( cached_response = json.loads(
cached_response cached_response
) # Convert string to dictionary ) # Convert string to dictionary
except: except Exception:
cached_response = ast.literal_eval(cached_response) cached_response = ast.literal_eval(cached_response)
return cached_response return cached_response
@ -1435,7 +1432,7 @@ class QdrantSemanticCache(BaseCache):
}, },
] ]
} }
keys = self.sync_client.put( self.sync_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points", url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers, headers=self.headers,
json=data, json=data,
@ -1481,7 +1478,7 @@ class QdrantSemanticCache(BaseCache):
) )
results = search_response.json()["result"] results = search_response.json()["result"]
if results == None: if results is None:
return None return None
if isinstance(results, list): if isinstance(results, list):
if len(results) == 0: if len(results) == 0:
@ -1563,7 +1560,7 @@ class QdrantSemanticCache(BaseCache):
] ]
} }
keys = await self.async_client.put( await self.async_client.put(
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points", url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
headers=self.headers, headers=self.headers,
json=data, json=data,
@ -1629,7 +1626,7 @@ class QdrantSemanticCache(BaseCache):
results = search_response.json()["result"] results = search_response.json()["result"]
if results == None: if results is None:
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0 kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
return None return None
if isinstance(results, list): if isinstance(results, list):
@ -1767,7 +1764,7 @@ class S3Cache(BaseCache):
cached_response = json.loads( cached_response = json.loads(
cached_response cached_response
) # Convert string to dictionary ) # Convert string to dictionary
except Exception as e: except Exception:
cached_response = ast.literal_eval(cached_response) cached_response = ast.literal_eval(cached_response)
if type(cached_response) is not dict: if type(cached_response) is not dict:
cached_response = dict(cached_response) cached_response = dict(cached_response)
@ -1845,7 +1842,7 @@ class DualCache(BaseCache):
self.in_memory_cache.set_cache(key, value, **kwargs) self.in_memory_cache.set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False: if self.redis_cache is not None and local_only is False:
self.redis_cache.set_cache(key, value, **kwargs) self.redis_cache.set_cache(key, value, **kwargs)
except Exception as e: except Exception as e:
print_verbose(e) print_verbose(e)
@ -1865,7 +1862,7 @@ class DualCache(BaseCache):
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
result = self.in_memory_cache.increment_cache(key, value, **kwargs) result = self.in_memory_cache.increment_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False: if self.redis_cache is not None and local_only is False:
result = self.redis_cache.increment_cache(key, value, **kwargs) result = self.redis_cache.increment_cache(key, value, **kwargs)
return result return result
@ -1887,7 +1884,7 @@ class DualCache(BaseCache):
if ( if (
(self.always_read_redis is True) (self.always_read_redis is True)
and self.redis_cache is not None and self.redis_cache is not None
and local_only == False and local_only is False
): ):
# If not found in in-memory cache or always_read_redis is True, try fetching from Redis # If not found in in-memory cache or always_read_redis is True, try fetching from Redis
redis_result = self.redis_cache.get_cache(key, **kwargs) redis_result = self.redis_cache.get_cache(key, **kwargs)
@ -1900,7 +1897,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}") print_verbose(f"get cache: cache result: {result}")
return result return result
except Exception as e: except Exception:
verbose_logger.error(traceback.format_exc()) verbose_logger.error(traceback.format_exc())
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs): def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
@ -1913,7 +1910,7 @@ class DualCache(BaseCache):
if in_memory_result is not None: if in_memory_result is not None:
result = in_memory_result result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False: if None in result and self.redis_cache is not None and local_only is False:
""" """
- for the none values in the result - for the none values in the result
- check the redis cache - check the redis cache
@ -1933,7 +1930,7 @@ class DualCache(BaseCache):
print_verbose(f"async batch get cache: cache result: {result}") print_verbose(f"async batch get cache: cache result: {result}")
return result return result
except Exception as e: except Exception:
verbose_logger.error(traceback.format_exc()) verbose_logger.error(traceback.format_exc())
async def async_get_cache(self, key, local_only: bool = False, **kwargs): async def async_get_cache(self, key, local_only: bool = False, **kwargs):
@ -1952,7 +1949,7 @@ class DualCache(BaseCache):
if in_memory_result is not None: if in_memory_result is not None:
result = in_memory_result result = in_memory_result
if result is None and self.redis_cache is not None and local_only == False: if result is None and self.redis_cache is not None and local_only is False:
# If not found in in-memory cache, try fetching from Redis # If not found in in-memory cache, try fetching from Redis
redis_result = await self.redis_cache.async_get_cache(key, **kwargs) redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
@ -1966,7 +1963,7 @@ class DualCache(BaseCache):
print_verbose(f"get cache: cache result: {result}") print_verbose(f"get cache: cache result: {result}")
return result return result
except Exception as e: except Exception:
verbose_logger.error(traceback.format_exc()) verbose_logger.error(traceback.format_exc())
async def async_batch_get_cache( async def async_batch_get_cache(
@ -1981,7 +1978,7 @@ class DualCache(BaseCache):
if in_memory_result is not None: if in_memory_result is not None:
result = in_memory_result result = in_memory_result
if None in result and self.redis_cache is not None and local_only == False: if None in result and self.redis_cache is not None and local_only is False:
""" """
- for the none values in the result - for the none values in the result
- check the redis cache - check the redis cache
@ -2006,7 +2003,7 @@ class DualCache(BaseCache):
result[index] = value result[index] = value
return result return result
except Exception as e: except Exception:
verbose_logger.error(traceback.format_exc()) verbose_logger.error(traceback.format_exc())
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs): async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
@ -2017,7 +2014,7 @@ class DualCache(BaseCache):
if self.in_memory_cache is not None: if self.in_memory_cache is not None:
await self.in_memory_cache.async_set_cache(key, value, **kwargs) await self.in_memory_cache.async_set_cache(key, value, **kwargs)
if self.redis_cache is not None and local_only == False: if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache(key, value, **kwargs) await self.redis_cache.async_set_cache(key, value, **kwargs)
except Exception as e: except Exception as e:
verbose_logger.exception( verbose_logger.exception(
@ -2039,7 +2036,7 @@ class DualCache(BaseCache):
cache_list=cache_list, **kwargs cache_list=cache_list, **kwargs
) )
if self.redis_cache is not None and local_only == False: if self.redis_cache is not None and local_only is False:
await self.redis_cache.async_set_cache_pipeline( await self.redis_cache.async_set_cache_pipeline(
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
) )
@ -2459,7 +2456,7 @@ class Cache:
cached_response = json.loads( cached_response = json.loads(
cached_response # type: ignore cached_response # type: ignore
) # Convert string to dictionary ) # Convert string to dictionary
except: except Exception:
cached_response = ast.literal_eval(cached_response) # type: ignore cached_response = ast.literal_eval(cached_response) # type: ignore
return cached_response return cached_response
return cached_result return cached_result
@ -2492,7 +2489,7 @@ class Cache:
return self._get_cache_logic( return self._get_cache_logic(
cached_result=cached_result, max_age=max_age cached_result=cached_result, max_age=max_age
) )
except Exception as e: except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}") print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None return None
@ -2506,7 +2503,7 @@ class Cache:
if self.should_use_cache(*args, **kwargs) is not True: if self.should_use_cache(*args, **kwargs) is not True:
return return
messages = kwargs.get("messages", []) kwargs.get("messages", [])
if "cache_key" in kwargs: if "cache_key" in kwargs:
cache_key = kwargs["cache_key"] cache_key = kwargs["cache_key"]
else: else:
@ -2522,7 +2519,7 @@ class Cache:
return self._get_cache_logic( return self._get_cache_logic(
cached_result=cached_result, max_age=max_age cached_result=cached_result, max_age=max_age
) )
except Exception as e: except Exception:
print_verbose(f"An exception occurred: {traceback.format_exc()}") print_verbose(f"An exception occurred: {traceback.format_exc()}")
return None return None
@ -2701,7 +2698,7 @@ class DiskCache(BaseCache):
if original_cached_response: if original_cached_response:
try: try:
cached_response = json.loads(original_cached_response) # type: ignore cached_response = json.loads(original_cached_response) # type: ignore
except: except Exception:
cached_response = original_cached_response cached_response = original_cached_response
return cached_response return cached_response
return None return None
@ -2803,7 +2800,7 @@ def enable_cache(
if "cache" not in litellm._async_success_callback: if "cache" not in litellm._async_success_callback:
litellm._async_success_callback.append("cache") litellm._async_success_callback.append("cache")
if litellm.cache == None: if litellm.cache is None:
litellm.cache = Cache( litellm.cache = Cache(
type=type, type=type,
host=host, host=host,


@@ -57,7 +57,7 @@
 # config = yaml.safe_load(file)
 # else:
 # pass
-# except:
+# except Exception:
 # pass
 # ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')


@ -9,12 +9,12 @@ import asyncio
import contextvars import contextvars
import os import os
from functools import partial from functools import partial
from typing import Any, Coroutine, Dict, Literal, Optional, Union from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
import httpx import httpx
import litellm import litellm
from litellm import client, get_secret from litellm import client, get_secret_str
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
@ -39,7 +39,7 @@ async def afile_retrieve(
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
) -> Coroutine[Any, Any, FileObject]: ):
""" """
Async: Get file contents Async: Get file contents
@ -66,7 +66,7 @@ async def afile_retrieve(
if asyncio.iscoroutine(init_response): if asyncio.iscoroutine(init_response):
response = await init_response response = await init_response
else: else:
response = init_response # type: ignore response = init_response
return response return response
except Exception as e: except Exception as e:
@ -137,27 +137,26 @@ def file_retrieve(
organization=organization, organization=organization,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.retrieve_file( response = azure_files_instance.retrieve_file(
_is_async=_is_async, _is_async=_is_async,
@ -181,7 +180,7 @@ def file_retrieve(
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
), ),
) )
return response return cast(FileObject, response)
except Exception as e: except Exception as e:
raise e raise e
@ -222,7 +221,7 @@ async def afile_delete(
else: else:
response = init_response # type: ignore response = init_response # type: ignore
return response return cast(FileDeleted, response) # type: ignore
except Exception as e: except Exception as e:
raise e raise e
@ -248,7 +247,7 @@ def file_delete(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -288,27 +287,26 @@ def file_delete(
organization=organization, organization=organization,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.delete_file( response = azure_files_instance.delete_file(
_is_async=_is_async, _is_async=_is_async,
@ -332,7 +330,7 @@ def file_delete(
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
), ),
) )
return response return cast(FileDeleted, response)
except Exception as e: except Exception as e:
raise e raise e
@ -399,7 +397,7 @@ def file_list(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -441,27 +439,26 @@ def file_list(
organization=organization, organization=organization,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.list_files( response = azure_files_instance.list_files(
_is_async=_is_async, _is_async=_is_async,
@ -556,7 +553,7 @@ def create_file(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -603,27 +600,26 @@ def create_file(
create_file_data=_create_file_request, create_file_data=_create_file_request,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.create_file( response = azure_files_instance.create_file(
_is_async=_is_async, _is_async=_is_async,
@ -713,7 +709,7 @@ def file_content(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -761,27 +757,26 @@ def file_content(
organization=organization, organization=organization,
) )
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_files_instance.file_content( response = azure_files_instance.file_content(
_is_async=_is_async, _is_async=_is_async,


@ -25,7 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
OpenAIFineTuningAPI, OpenAIFineTuningAPI,
) )
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret_str
from litellm.types.llms.openai import Hyperparameters from litellm.types.llms.openai import Hyperparameters
from litellm.types.router import * from litellm.types.router import *
from litellm.utils import supports_httpx_timeout from litellm.utils import supports_httpx_timeout
@ -119,7 +119,7 @@ def create_fine_tuning_job(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -177,28 +177,27 @@ def create_fine_tuning_job(
) )
# Azure OpenAI # Azure OpenAI
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
create_fine_tuning_job_data = FineTuningJobCreate( create_fine_tuning_job_data = FineTuningJobCreate(
model=model, model=model,
@ -228,14 +227,14 @@ def create_fine_tuning_job(
vertex_ai_project = ( vertex_ai_project = (
optional_params.vertex_project optional_params.vertex_project
or litellm.vertex_project or litellm.vertex_project
or get_secret("VERTEXAI_PROJECT") or get_secret_str("VERTEXAI_PROJECT")
) )
vertex_ai_location = ( vertex_ai_location = (
optional_params.vertex_location optional_params.vertex_location
or litellm.vertex_location or litellm.vertex_location
or get_secret("VERTEXAI_LOCATION") or get_secret_str("VERTEXAI_LOCATION")
) )
vertex_credentials = optional_params.vertex_credentials or get_secret( vertex_credentials = optional_params.vertex_credentials or get_secret_str(
"VERTEXAI_CREDENTIALS" "VERTEXAI_CREDENTIALS"
) )
create_fine_tuning_job_data = FineTuningJobCreate( create_fine_tuning_job_data = FineTuningJobCreate(
@ -315,7 +314,7 @@ async def acancel_fine_tuning_job(
def cancel_fine_tuning_job( def cancel_fine_tuning_job(
fine_tuning_job_id: str, fine_tuning_job_id: str,
custom_llm_provider: Literal["openai"] = "openai", custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
@ -335,7 +334,7 @@ def cancel_fine_tuning_job(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -386,23 +385,22 @@ def cancel_fine_tuning_job(
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret_str("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job( response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
api_base=api_base, api_base=api_base,
@ -438,7 +436,7 @@ async def alist_fine_tuning_jobs(
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
) -> FineTuningJob: ):
""" """
Async: List your organization's fine-tuning jobs Async: List your organization's fine-tuning jobs
""" """
@ -473,7 +471,7 @@ async def alist_fine_tuning_jobs(
def list_fine_tuning_jobs( def list_fine_tuning_jobs(
after: Optional[str] = None, after: Optional[str] = None,
limit: Optional[int] = None, limit: Optional[int] = None,
custom_llm_provider: Literal["openai"] = "openai", custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
extra_headers: Optional[Dict[str, str]] = None, extra_headers: Optional[Dict[str, str]] = None,
extra_body: Optional[Dict[str, str]] = None, extra_body: Optional[Dict[str, str]] = None,
**kwargs, **kwargs,
@ -495,7 +493,7 @@ def list_fine_tuning_jobs(
if ( if (
timeout is not None timeout is not None
and isinstance(timeout, httpx.Timeout) and isinstance(timeout, httpx.Timeout)
and supports_httpx_timeout(custom_llm_provider) == False and supports_httpx_timeout(custom_llm_provider) is False
): ):
read_timeout = timeout.read or 600 read_timeout = timeout.read or 600
timeout = read_timeout # default 10 min timeout timeout = read_timeout # default 10 min timeout
@ -542,28 +540,27 @@ def list_fine_tuning_jobs(
) )
# Azure OpenAI # Azure OpenAI
elif custom_llm_provider == "azure": elif custom_llm_provider == "azure":
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
api_version = ( api_version = (
optional_params.api_version optional_params.api_version
or litellm.api_version or litellm.api_version
or get_secret("AZURE_API_VERSION") or get_secret_str("AZURE_API_VERSION")
) # type: ignore ) # type: ignore
api_key = ( api_key = (
optional_params.api_key optional_params.api_key
or litellm.api_key or litellm.api_key
or litellm.azure_key or litellm.azure_key
or get_secret("AZURE_OPENAI_API_KEY") or get_secret_str("AZURE_OPENAI_API_KEY")
or get_secret("AZURE_API_KEY") or get_secret_str("AZURE_API_KEY")
) # type: ignore ) # type: ignore
extra_body = optional_params.get("extra_body", {}) extra_body = optional_params.get("extra_body", {})
azure_ad_token: Optional[str] = None
if extra_body is not None: if extra_body is not None:
azure_ad_token = extra_body.pop("azure_ad_token", None) extra_body.pop("azure_ad_token", None)
else: else:
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore get_secret("AZURE_AD_TOKEN") # type: ignore
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs( response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
api_base=api_base, api_base=api_base,
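The repeated `== False` -> `is False` edits in this file address the boolean-comparison rule ruff enforces (pycodestyle E712), and `get_secret` becomes `get_secret_str` where a plain string is expected. A minimal sketch of the comparison fix, using a stand-in `supports_feature` helper rather than the real `supports_httpx_timeout`:

def supports_feature(provider: str) -> bool:
    # stand-in for a capability check such as supports_httpx_timeout()
    return provider == "openai"

timeout = 30.0

# before: `if supports_feature("azure") == False:`  (flagged as E712)
# after: identity check against the singleton, or simply `not supports_feature(...)`
if supports_feature("azure") is False and timeout is not None:
    read_timeout = timeout or 600
    print(f"falling back to a plain read timeout of {read_timeout}s")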

View file

@ -23,6 +23,9 @@ import litellm.types
from litellm._logging import verbose_logger, verbose_proxy_logger from litellm._logging import verbose_logger, verbose_proxy_logger
from litellm.caching import DualCache from litellm.caching import DualCache
from litellm.integrations.custom_batch_logger import CustomBatchLogger from litellm.integrations.custom_batch_logger import CustomBatchLogger
from litellm.litellm_core_utils.exception_mapping_utils import (
_add_key_name_and_team_to_alert,
)
from litellm.litellm_core_utils.litellm_logging import Logging from litellm.litellm_core_utils.litellm_logging import Logging
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
AsyncHTTPHandler, AsyncHTTPHandler,
@ -219,7 +222,7 @@ class SlackAlerting(CustomBatchLogger):
and "metadata" in kwargs["litellm_params"] and "metadata" in kwargs["litellm_params"]
): ):
_metadata: dict = kwargs["litellm_params"]["metadata"] _metadata: dict = kwargs["litellm_params"]["metadata"]
request_info = litellm.utils._add_key_name_and_team_to_alert( request_info = _add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata request_info=request_info, metadata=_metadata
) )
@ -281,7 +284,7 @@ class SlackAlerting(CustomBatchLogger):
return_val += 1 return_val += 1
return return_val return return_val
except Exception as e: except Exception:
return 0 return 0
async def send_daily_reports(self, router) -> bool: async def send_daily_reports(self, router) -> bool:
@ -455,7 +458,7 @@ class SlackAlerting(CustomBatchLogger):
try: try:
messages = str(messages) messages = str(messages)
messages = messages[:100] messages = messages[:100]
except: except Exception:
messages = "" messages = ""
if ( if (
@ -508,7 +511,7 @@ class SlackAlerting(CustomBatchLogger):
_metadata: dict = request_data["metadata"] _metadata: dict = request_data["metadata"]
_api_base = _metadata.get("api_base", "") _api_base = _metadata.get("api_base", "")
request_info = litellm.utils._add_key_name_and_team_to_alert( request_info = _add_key_name_and_team_to_alert(
request_info=request_info, metadata=_metadata request_info=request_info, metadata=_metadata
) )
@ -846,7 +849,7 @@ class SlackAlerting(CustomBatchLogger):
## MINOR OUTAGE ALERT SENT ## ## MINOR OUTAGE ALERT SENT ##
if ( if (
outage_value["minor_alert_sent"] == False outage_value["minor_alert_sent"] is False
and len(outage_value["alerts"]) and len(outage_value["alerts"])
>= self.alerting_args.minor_outage_alert_threshold >= self.alerting_args.minor_outage_alert_threshold
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
@ -871,7 +874,7 @@ class SlackAlerting(CustomBatchLogger):
## MAJOR OUTAGE ALERT SENT ## ## MAJOR OUTAGE ALERT SENT ##
elif ( elif (
outage_value["major_alert_sent"] == False outage_value["major_alert_sent"] is False
and len(outage_value["alerts"]) and len(outage_value["alerts"])
>= self.alerting_args.major_outage_alert_threshold >= self.alerting_args.major_outage_alert_threshold
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
@ -941,7 +944,7 @@ class SlackAlerting(CustomBatchLogger):
if provider is None: if provider is None:
try: try:
model, provider, _, _ = litellm.get_llm_provider(model=model) model, provider, _, _ = litellm.get_llm_provider(model=model)
except Exception as e: except Exception:
provider = "" provider = ""
api_base = litellm.get_api_base( api_base = litellm.get_api_base(
model=model, optional_params=deployment.litellm_params model=model, optional_params=deployment.litellm_params
@ -976,7 +979,7 @@ class SlackAlerting(CustomBatchLogger):
## MINOR OUTAGE ALERT SENT ## ## MINOR OUTAGE ALERT SENT ##
if ( if (
outage_value["minor_alert_sent"] == False outage_value["minor_alert_sent"] is False
and len(outage_value["alerts"]) and len(outage_value["alerts"])
>= self.alerting_args.minor_outage_alert_threshold >= self.alerting_args.minor_outage_alert_threshold
): ):
@ -998,7 +1001,7 @@ class SlackAlerting(CustomBatchLogger):
# set to true # set to true
outage_value["minor_alert_sent"] = True outage_value["minor_alert_sent"] = True
elif ( elif (
outage_value["major_alert_sent"] == False outage_value["major_alert_sent"] is False
and len(outage_value["alerts"]) and len(outage_value["alerts"])
>= self.alerting_args.major_outage_alert_threshold >= self.alerting_args.major_outage_alert_threshold
): ):
@ -1024,7 +1027,7 @@ class SlackAlerting(CustomBatchLogger):
await self.internal_usage_cache.async_set_cache( await self.internal_usage_cache.async_set_cache(
key=deployment_id, value=outage_value key=deployment_id, value=outage_value
) )
except Exception as e: except Exception:
pass pass
async def model_added_alert( async def model_added_alert(
@ -1177,7 +1180,6 @@ Model Info:
if user_row is not None: if user_row is not None:
recipient_email = user_row.user_email recipient_email = user_row.user_email
key_name = webhook_event.key_alias
key_token = webhook_event.token key_token = webhook_event.token
key_budget = webhook_event.max_budget key_budget = webhook_event.max_budget
base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000") base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
@ -1221,14 +1223,14 @@ Model Info:
extra=webhook_event.model_dump(), extra=webhook_event.model_dump(),
) )
payload = webhook_event.model_dump_json() webhook_event.model_dump_json()
email_event = { email_event = {
"to": recipient_email, "to": recipient_email,
"subject": f"LiteLLM: {event_name}", "subject": f"LiteLLM: {event_name}",
"html": email_html_content, "html": email_html_content,
} }
response = await send_email( await send_email(
receiver_email=email_event["to"], receiver_email=email_event["to"],
subject=email_event["subject"], subject=email_event["subject"],
html=email_event["html"], html=email_event["html"],
@ -1292,14 +1294,14 @@ Model Info:
The LiteLLM team <br /> The LiteLLM team <br />
""" """
payload = webhook_event.model_dump_json() webhook_event.model_dump_json()
email_event = { email_event = {
"to": recipient_email, "to": recipient_email,
"subject": f"LiteLLM: {event_name}", "subject": f"LiteLLM: {event_name}",
"html": email_html_content, "html": email_html_content,
} }
response = await send_email( await send_email(
receiver_email=email_event["to"], receiver_email=email_event["to"],
subject=email_event["subject"], subject=email_event["subject"],
html=email_event["html"], html=email_event["html"],
@ -1446,7 +1448,6 @@ Model Info:
response_s: timedelta = end_time - start_time response_s: timedelta = end_time - start_time
final_value = response_s final_value = response_s
total_tokens = 0
if isinstance(response_obj, litellm.ModelResponse) and ( if isinstance(response_obj, litellm.ModelResponse) and (
hasattr(response_obj, "usage") hasattr(response_obj, "usage")
@ -1505,7 +1506,7 @@ Model Info:
await self.region_outage_alerts( await self.region_outage_alerts(
exception=kwargs["exception"], deployment_id=model_id exception=kwargs["exception"], deployment_id=model_id
) )
except Exception as e: except Exception:
pass pass
async def _run_scheduler_helper(self, llm_router) -> bool: async def _run_scheduler_helper(self, llm_router) -> bool:
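Two cleanups repeat throughout the alerting code above: bare `except:` clauses become `except Exception:` (ruff E722), and unused `except Exception as e:` bindings lose the `as e` (F841). A small illustration of the first, mirroring the message-truncation block, with a hypothetical helper name:

def summarize_messages(messages: object) -> str:
    # hypothetical helper mirroring the `messages = str(messages)[:100]` block above
    try:
        return str(messages)[:100]
    except Exception:
        # unlike a bare `except:`, this still lets KeyboardInterrupt and
        # SystemExit propagate instead of silently swallowing them
        return ""


print(summarize_messages([{"role": "user", "content": "hi"}]))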

View file

@ -35,7 +35,7 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs): # type: ignore def json(self, **kwargs): # type: ignore
try: try:
return self.model_dump() # noqa return self.model_dump() # noqa
except: except Exception:
# if using pydantic v1 # if using pydantic v1
return self.dict() return self.dict()

View file

@ -1,8 +1,10 @@
#### What this does #### #### What this does ####
# On success + failure, log events to aispend.io # On success + failure, log events to aispend.io
import dotenv, os
import traceback
import datetime import datetime
import os
import traceback
import dotenv
model_cost = { model_cost = {
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
@ -118,8 +120,6 @@ class AISpendLogger:
for model in model_cost: for model in model_cost:
input_cost_sum += model_cost[model]["input_cost_per_token"] input_cost_sum += model_cost[model]["input_cost_per_token"]
output_cost_sum += model_cost[model]["output_cost_per_token"] output_cost_sum += model_cost[model]["output_cost_per_token"]
avg_input_cost = input_cost_sum / len(model_cost.keys())
avg_output_cost = output_cost_sum / len(model_cost.keys())
prompt_tokens_cost_usd_dollar = ( prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"] model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"] * response_obj["usage"]["prompt_tokens"]
@ -137,12 +137,6 @@ class AISpendLogger:
f"AISpend Logging - Enters logging function for model {model}" f"AISpend Logging - Enters logging function for model {model}"
) )
url = f"https://aispend.io/api/v1/accounts/{self.account_id}/data"
headers = {
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
}
response_timestamp = datetime.datetime.fromtimestamp( response_timestamp = datetime.datetime.fromtimestamp(
int(response_obj["created"]) int(response_obj["created"])
).strftime("%Y-%m-%d") ).strftime("%Y-%m-%d")
@ -168,6 +162,6 @@ class AISpendLogger:
] ]
print_verbose(f"AISpend Logging - final data object: {data}") print_verbose(f"AISpend Logging - final data object: {data}")
except: except Exception:
print_verbose(f"AISpend Logging Error - {traceback.format_exc()}") print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
pass pass
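The import reshuffle here (and in most files below) splits combined lines like `import dotenv, os` and groups standard-library imports ahead of third-party ones, the ordering ruff's import checks expect; the unused `avg_input_cost` / `avg_output_cost` locals are dropped under F841. Sketch of the import layout only:

# before: `import dotenv, os` followed by `import traceback`
# after: one import per line, standard library first, third party second
import datetime
import os
import traceback

import dotenv  # third-party block, separated by a blank line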

View file

@ -23,7 +23,7 @@ def set_arize_ai_attributes(span: Span, kwargs, response_obj):
) )
optional_params = kwargs.get("optional_params", {}) optional_params = kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {} # litellm_params = kwargs.get("litellm_params", {}) or {}
############################################# #############################################
############ LLM CALL METADATA ############## ############ LLM CALL METADATA ##############

View file

@ -1,5 +1,6 @@
import datetime import datetime
class AthinaLogger: class AthinaLogger:
def __init__(self): def __init__(self):
import os import os
@ -23,17 +24,20 @@ class AthinaLogger:
] ]
def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose): def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
import requests # type: ignore
import json import json
import traceback import traceback
import requests # type: ignore
try: try:
is_stream = kwargs.get("stream", False) is_stream = kwargs.get("stream", False)
if is_stream: if is_stream:
if "complete_streaming_response" in kwargs: if "complete_streaming_response" in kwargs:
# Log the completion response in streaming mode # Log the completion response in streaming mode
completion_response = kwargs["complete_streaming_response"] completion_response = kwargs["complete_streaming_response"]
response_json = completion_response.model_dump() if completion_response else {} response_json = (
completion_response.model_dump() if completion_response else {}
)
else: else:
# Skip logging if the completion response is not available # Skip logging if the completion response is not available
return return
@ -52,8 +56,8 @@ class AthinaLogger:
} }
if ( if (
type(end_time) == datetime.datetime type(end_time) is datetime.datetime
and type(start_time) == datetime.datetime and type(start_time) is datetime.datetime
): ):
data["response_time"] = int( data["response_time"] = int(
(end_time - start_time).total_seconds() * 1000 (end_time - start_time).total_seconds() * 1000
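`type(end_time) == datetime.datetime` becomes an identity comparison because ruff's E721 discourages comparing types with `==`. A minimal sketch; `isinstance` is the usual alternative when subclasses should also pass:

import datetime

start_time = datetime.datetime.now()
end_time = start_time + datetime.timedelta(milliseconds=42)

# exact-type check, as written above
if type(end_time) is datetime.datetime and type(start_time) is datetime.datetime:
    response_time_ms = int((end_time - start_time).total_seconds() * 1000)
    print(response_time_ms)

# subclass-tolerant equivalent for new code
if isinstance(end_time, datetime.datetime) and isinstance(start_time, datetime.datetime):
    pass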

View file

@ -1,10 +1,11 @@
#### What this does #### #### What this does ####
# On success + failure, log events to aispend.io # On success + failure, log events to aispend.io
import dotenv, os
import requests # type: ignore
import traceback
import datetime import datetime
import os
import traceback
import dotenv
import requests # type: ignore
model_cost = { model_cost = {
"gpt-3.5-turbo": { "gpt-3.5-turbo": {
@ -92,91 +93,12 @@ class BerriSpendLogger:
self.account_id = os.getenv("BERRISPEND_ACCOUNT_ID") self.account_id = os.getenv("BERRISPEND_ACCOUNT_ID")
def price_calculator(self, model, response_obj, start_time, end_time): def price_calculator(self, model, response_obj, start_time, end_time):
# try and find if the model is in the model_cost map return
# else default to the average of the costs
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = 0
if model in model_cost:
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
)
completion_tokens_cost_usd_dollar = (
model_cost[model]["output_cost_per_token"]
* response_obj["usage"]["completion_tokens"]
)
elif "replicate" in model:
# replicate models are charged based on time
# llama 2 runs on an nvidia a100 which costs $0.0032 per second - https://replicate.com/replicate/llama-2-70b-chat
model_run_time = end_time - start_time # assuming time in seconds
cost_usd_dollar = model_run_time * 0.0032
prompt_tokens_cost_usd_dollar = cost_usd_dollar / 2
completion_tokens_cost_usd_dollar = cost_usd_dollar / 2
else:
# calculate average input cost
input_cost_sum = 0
output_cost_sum = 0
for model in model_cost:
input_cost_sum += model_cost[model]["input_cost_per_token"]
output_cost_sum += model_cost[model]["output_cost_per_token"]
avg_input_cost = input_cost_sum / len(model_cost.keys())
avg_output_cost = output_cost_sum / len(model_cost.keys())
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
)
completion_tokens_cost_usd_dollar = (
model_cost[model]["output_cost_per_token"]
* response_obj["usage"]["completion_tokens"]
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
def log_event( def log_event(
self, model, messages, response_obj, start_time, end_time, print_verbose self, model, messages, response_obj, start_time, end_time, print_verbose
): ):
# Method definition """
try: This integration is not implemented yet.
print_verbose( """
f"BerriSpend Logging - Enters logging function for model {model}" return
)
url = f"https://berrispend.berri.ai/spend"
headers = {"Content-Type": "application/json"}
(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = self.price_calculator(model, response_obj, start_time, end_time)
total_cost = (
prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
)
response_time = (end_time - start_time).total_seconds()
if "response" in response_obj:
data = [
{
"response_time": response_time,
"model_id": response_obj["model"],
"total_cost": total_cost,
"messages": messages,
"response": response_obj["choices"][0]["message"]["content"],
"account_id": self.account_id,
}
]
elif "error" in response_obj:
data = [
{
"response_time": response_time,
"model_id": response_obj["model"],
"total_cost": total_cost,
"messages": messages,
"error": response_obj["error"],
"account_id": self.account_id,
}
]
print_verbose(f"BerriSpend Logging - final data object: {data}")
response = requests.post(url, headers=headers, json=data)
except:
print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
pass

View file

@ -136,27 +136,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id project_id = self.default_project_id
prompt = {"messages": kwargs.get("messages")} prompt = {"messages": kwargs.get("messages")}
output = None
if response_obj is not None and ( if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse) or isinstance(response_obj, litellm.EmbeddingResponse)
): ):
input = prompt
output = None output = None
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse response_obj, litellm.ModelResponse
): ):
input = prompt
output = response_obj["choices"][0]["message"].json() output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse response_obj, litellm.TextCompletionResponse
): ):
input = prompt
output = response_obj.choices[0].text output = response_obj.choices[0].text
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse response_obj, litellm.ImageResponse
): ):
input = prompt
output = response_obj["data"] output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
@ -169,7 +165,7 @@ class BraintrustLogger(CustomLogger):
metadata = copy.deepcopy( metadata = copy.deepcopy(
metadata metadata
) # Avoid modifying the original metadata ) # Avoid modifying the original metadata
except: except Exception:
new_metadata = {} new_metadata = {}
for key, value in metadata.items(): for key, value in metadata.items():
if ( if (
@ -210,16 +206,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost clean_metadata["litellm_response_cost"] = cost
metrics: Optional[dict] = None metrics: Optional[dict] = None
if ( usage_obj = getattr(response_obj, "usage", None)
response_obj is not None if usage_obj and isinstance(usage_obj, litellm.Usage):
and hasattr(response_obj, "usage") litellm.utils.get_logging_id(start_time, response_obj)
and isinstance(response_obj.usage, litellm.Usage)
):
generation_id = litellm.utils.get_logging_id(start_time, response_obj)
metrics = { metrics = {
"prompt_tokens": response_obj.usage.prompt_tokens, "prompt_tokens": usage_obj.prompt_tokens,
"completion_tokens": response_obj.usage.completion_tokens, "completion_tokens": usage_obj.completion_tokens,
"total_tokens": response_obj.usage.total_tokens, "total_tokens": usage_obj.total_tokens,
"total_cost": cost, "total_cost": cost,
} }
@ -255,27 +248,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id project_id = self.default_project_id
prompt = {"messages": kwargs.get("messages")} prompt = {"messages": kwargs.get("messages")}
output = None
if response_obj is not None and ( if response_obj is not None and (
kwargs.get("call_type", None) == "embedding" kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse) or isinstance(response_obj, litellm.EmbeddingResponse)
): ):
input = prompt
output = None output = None
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse response_obj, litellm.ModelResponse
): ):
input = prompt
output = response_obj["choices"][0]["message"].json() output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse response_obj, litellm.TextCompletionResponse
): ):
input = prompt
output = response_obj.choices[0].text output = response_obj.choices[0].text
elif response_obj is not None and isinstance( elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse response_obj, litellm.ImageResponse
): ):
input = prompt
output = response_obj["data"] output = response_obj["data"]
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
@ -331,16 +320,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost clean_metadata["litellm_response_cost"] = cost
metrics: Optional[dict] = None metrics: Optional[dict] = None
if ( usage_obj = getattr(response_obj, "usage", None)
response_obj is not None if usage_obj and isinstance(usage_obj, litellm.Usage):
and hasattr(response_obj, "usage") litellm.utils.get_logging_id(start_time, response_obj)
and isinstance(response_obj.usage, litellm.Usage)
):
generation_id = litellm.utils.get_logging_id(start_time, response_obj)
metrics = { metrics = {
"prompt_tokens": response_obj.usage.prompt_tokens, "prompt_tokens": usage_obj.prompt_tokens,
"completion_tokens": response_obj.usage.completion_tokens, "completion_tokens": usage_obj.completion_tokens,
"total_tokens": response_obj.usage.total_tokens, "total_tokens": usage_obj.total_tokens,
"total_cost": cost, "total_cost": cost,
} }
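Both success paths now read the usage object once with `getattr(response_obj, "usage", None)` instead of chaining `hasattr` with a direct attribute access, which reads more cleanly and keeps the type checker satisfied. A rough sketch with stand-in classes, not the real litellm types:

from dataclasses import dataclass
from typing import Optional


@dataclass
class Usage:  # stand-in for litellm.Usage
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int


@dataclass
class ModelResponse:  # stand-in for litellm.ModelResponse
    usage: Optional[Usage] = None


def build_metrics(response_obj: ModelResponse, cost: float) -> Optional[dict]:
    usage_obj = getattr(response_obj, "usage", None)  # one lookup, None when absent
    if usage_obj and isinstance(usage_obj, Usage):
        return {
            "prompt_tokens": usage_obj.prompt_tokens,
            "completion_tokens": usage_obj.completion_tokens,
            "total_tokens": usage_obj.total_tokens,
            "total_cost": cost,
        }
    return None


print(build_metrics(ModelResponse(usage=Usage(10, 5, 15)), cost=0.0002))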

View file

@ -2,25 +2,24 @@
#### What this does #### #### What this does ####
# On success, logs events to Promptlayer # On success, logs events to Promptlayer
import dotenv, os import datetime
import json
from litellm.proxy._types import UserAPIKeyAuth import os
from litellm.caching import DualCache
from typing import Literal, Union
import traceback import traceback
from typing import Literal, Optional, Union
import dotenv
import requests
import litellm
from litellm._logging import verbose_logger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import StandardLoggingPayload
#### What this does #### #### What this does ####
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os
import requests
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger
def create_client(): def create_client():
try: try:
@ -260,18 +259,12 @@ class ClickhouseLogger:
f"ClickhouseLogger Logging - Enters logging function for model {kwargs}" f"ClickhouseLogger Logging - Enters logging function for model {kwargs}"
) )
# follows the same params as langfuse.py # follows the same params as langfuse.py
from litellm.proxy.utils import get_logging_payload
payload = get_logging_payload( payload: Optional[StandardLoggingPayload] = kwargs.get(
kwargs=kwargs, "standard_logging_object"
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
) )
metadata = payload.get("metadata", "") or "" if payload is None:
request_tags = payload.get("request_tags", "") or "" return
payload["metadata"] = str(metadata)
payload["request_tags"] = str(request_tags)
# Build the initial payload # Build the initial payload
verbose_logger.debug(f"\nClickhouse Logger - Logging payload = {payload}") verbose_logger.debug(f"\nClickhouse Logger - Logging payload = {payload}")
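Rather than rebuilding the payload with `get_logging_payload(...)`, the ClickHouse logger now pulls the pre-built `standard_logging_object` out of `kwargs` and returns early when it is missing. Sketch of the guard, with the payload type simplified to a plain dict:

from typing import Optional


def log_event_sketch(kwargs: dict) -> Optional[dict]:
    # assumes the logging pipeline attached this object upstream
    payload: Optional[dict] = kwargs.get("standard_logging_object")
    if payload is None:
        return None  # nothing to log, and no hand-rolled payload reconstruction
    return payload


print(log_event_sketch({"standard_logging_object": {"model": "gpt-3.5-turbo"}}))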

View file

@ -12,7 +12,12 @@ from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import ChatCompletionRequest from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse from litellm.types.utils import (
AdapterCompletionStreamWrapper,
EmbeddingResponse,
ImageResponse,
ModelResponse,
)
class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class
@ -140,8 +145,8 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
self, self,
data: dict, data: dict,
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
response, response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
): ) -> Any:
pass pass
async def async_logging_hook( async def async_logging_hook(
@ -188,7 +193,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs, kwargs,
) )
print_verbose(f"Custom Logger - model call details: {kwargs}") print_verbose(f"Custom Logger - model call details: {kwargs}")
except: except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
async def async_log_input_event( async def async_log_input_event(
@ -202,7 +207,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs, kwargs,
) )
print_verbose(f"Custom Logger - model call details: {kwargs}") print_verbose(f"Custom Logger - model call details: {kwargs}")
except: except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
def log_event( def log_event(
@ -217,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time, start_time,
end_time, end_time,
) )
except: except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass
@ -233,6 +238,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time, start_time,
end_time, end_time,
) )
except: except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}") print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass pass

View file

@ -54,7 +54,7 @@ class DataDogLogger(CustomBatchLogger):
`DD_SITE` - your datadog site, example = `"us5.datadoghq.com"` `DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`
""" """
try: try:
verbose_logger.debug(f"Datadog: in init datadog logger") verbose_logger.debug("Datadog: in init datadog logger")
# check if the correct env variables are set # check if the correct env variables are set
if os.getenv("DD_API_KEY", None) is None: if os.getenv("DD_API_KEY", None) is None:
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>") raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")
@ -245,12 +245,12 @@ class DataDogLogger(CustomBatchLogger):
usage = dict(usage) usage = dict(usage)
try: try:
response_time = (end_time - start_time).total_seconds() * 1000 response_time = (end_time - start_time).total_seconds() * 1000
except: except Exception:
response_time = None response_time = None
try: try:
response_obj = dict(response_obj) response_obj = dict(response_obj)
except: except Exception:
response_obj = response_obj response_obj = response_obj
# Clean Metadata before logging - never log raw metadata # Clean Metadata before logging - never log raw metadata

View file

@ -7,7 +7,7 @@ def make_json_serializable(payload):
elif not isinstance(value, (str, int, float, bool, type(None))): elif not isinstance(value, (str, int, float, bool, type(None))):
# everything else becomes a string # everything else becomes a string
payload[key] = str(value) payload[key] = str(value)
except: except Exception:
# non blocking if it can't cast to a str # non blocking if it can't cast to a str
pass pass
return payload return payload

View file

@ -1,12 +1,16 @@
#### What this does #### #### What this does ####
# On success + failure, log events to Supabase # On success + failure, log events to Supabase
import dotenv, os import datetime
import requests # type: ignore import os
import traceback import traceback
import datetime, subprocess, sys import uuid
import litellm, uuid from typing import Any
from litellm._logging import print_verbose
import dotenv
import requests # type: ignore
import litellm
class DyanmoDBLogger: class DyanmoDBLogger:
@ -16,7 +20,7 @@ class DyanmoDBLogger:
# Instance variables # Instance variables
import boto3 import boto3
self.dynamodb = boto3.resource( self.dynamodb: Any = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"] "dynamodb", region_name=os.environ["AWS_REGION_NAME"]
) )
if litellm.dynamodb_table_name is None: if litellm.dynamodb_table_name is None:
@ -67,7 +71,7 @@ class DyanmoDBLogger:
for key, value in payload.items(): for key, value in payload.items():
try: try:
payload[key] = str(value) payload[key] = str(value)
except: except Exception:
# non blocking if it can't cast to a str # non blocking if it can't cast to a str
pass pass
@ -84,6 +88,6 @@ class DyanmoDBLogger:
f"DynamoDB Layer Logging - final response object: {response_obj}" f"DynamoDB Layer Logging - final response object: {response_obj}"
) )
return response return response
except: except Exception:
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}") print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass pass
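Annotating the boto3 resource as `Any` quiets the type checker on a dynamically generated client object. A minimal sketch (boto3 must be installed; the class name and region here are illustrative):

from typing import Any


class DynamoLoggerSketch:
    def __init__(self) -> None:
        import boto3  # imported lazily, as in the hunk above

        # `Any` stops pyright from guessing at boto3's dynamic resource type
        self.dynamodb: Any = boto3.resource(
            "dynamodb", region_name="us-west-2"
        )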

View file

@ -9,6 +9,7 @@ import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client, get_async_httpx_client,
httpxSpecialProvider, httpxSpecialProvider,
) )
@ -41,7 +42,7 @@ class GalileoObserve(CustomLogger):
self.batch_size = 1 self.batch_size = 1
self.base_url = os.getenv("GALILEO_BASE_URL", None) self.base_url = os.getenv("GALILEO_BASE_URL", None)
self.project_id = os.getenv("GALILEO_PROJECT_ID", None) self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
self.headers = None self.headers: Optional[Dict[str, str]] = None
self.async_httpx_handler = get_async_httpx_client( self.async_httpx_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback llm_provider=httpxSpecialProvider.LoggingCallback
) )
@ -54,7 +55,7 @@ class GalileoObserve(CustomLogger):
"accept": "application/json", "accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded", "Content-Type": "application/x-www-form-urlencoded",
} }
galileo_login_response = self.async_httpx_handler.post( galileo_login_response = litellm.module_level_client.post(
url=f"{self.base_url}/login", url=f"{self.base_url}/login",
headers=headers, headers=headers,
data={ data={
@ -94,13 +95,9 @@ class GalileoObserve(CustomLogger):
return output return output
async def async_log_success_event( async def async_log_success_event(
self, self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
kwargs,
start_time,
end_time,
response_obj,
): ):
verbose_logger.debug(f"On Async Success") verbose_logger.debug("On Async Success")
_latency_ms = int((end_time - start_time).total_seconds() * 1000) _latency_ms = int((end_time - start_time).total_seconds() * 1000)
_call_type = kwargs.get("call_type", "litellm") _call_type = kwargs.get("call_type", "litellm")
@ -116,26 +113,27 @@ class GalileoObserve(CustomLogger):
response_obj=response_obj, kwargs=kwargs response_obj=response_obj, kwargs=kwargs
) )
request_record = LLMResponse( if output_text is not None:
latency_ms=_latency_ms, request_record = LLMResponse(
status_code=200, latency_ms=_latency_ms,
input_text=input_text, status_code=200,
output_text=output_text, input_text=input_text,
node_type=_call_type, output_text=output_text,
model=kwargs.get("model", "-"), node_type=_call_type,
num_input_tokens=num_input_tokens, model=kwargs.get("model", "-"),
num_output_tokens=num_output_tokens, num_input_tokens=num_input_tokens,
created_at=start_time.strftime( num_output_tokens=num_output_tokens,
"%Y-%m-%dT%H:%M:%S" created_at=start_time.strftime(
), # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format "%Y-%m-%dT%H:%M:%S"
) ), # timestamp str constructed in "%Y-%m-%dT%H:%M:%S" format
)
# dump to dict # dump to dict
request_dict = request_record.model_dump() request_dict = request_record.model_dump()
self.in_memory_records.append(request_dict) self.in_memory_records.append(request_dict)
if len(self.in_memory_records) >= self.batch_size: if len(self.in_memory_records) >= self.batch_size:
await self.flush_in_memory_records() await self.flush_in_memory_records()
async def flush_in_memory_records(self): async def flush_in_memory_records(self):
verbose_logger.debug("flushing in memory records") verbose_logger.debug("flushing in memory records")
@ -159,4 +157,4 @@ class GalileoObserve(CustomLogger):
) )
async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time): async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
verbose_logger.debug(f"On Async Failure") verbose_logger.debug("On Async Failure")
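`f"On Async Success"` and `f"On Async Failure"` lose their `f` prefixes because the strings interpolate nothing (ruff F541); the same change shows up in the Datadog, Langfuse, and Logfire integrations. Tiny illustration with the standard-library logger:

import logging

logger = logging.getLogger("example")

# F541: an f-string with no placeholders is just a plain string
logger.debug("On Async Success")

# keep the prefix only when something is actually interpolated
latency_ms = 123
logger.debug(f"On Async Success, latency={latency_ms}ms")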

View file

@ -56,8 +56,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj, response_obj,
) )
start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S") start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S") end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers() headers = await self.construct_request_headers()
logging_payload: Optional[StandardLoggingPayload] = kwargs.get( logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
@ -103,8 +103,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj, response_obj,
) )
start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S") start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S") end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers() headers = await self.construct_request_headers()
logging_payload: Optional[StandardLoggingPayload] = kwargs.get( logging_payload: Optional[StandardLoggingPayload] = kwargs.get(

View file

@ -1,8 +1,9 @@
import requests # type: ignore
import json import json
import traceback import traceback
from datetime import datetime, timezone from datetime import datetime, timezone
import requests # type: ignore
class GreenscaleLogger: class GreenscaleLogger:
def __init__(self): def __init__(self):
@ -29,7 +30,7 @@ class GreenscaleLogger:
"%Y-%m-%dT%H:%M:%SZ" "%Y-%m-%dT%H:%M:%SZ"
) )
if type(end_time) == datetime and type(start_time) == datetime: if type(end_time) is datetime and type(start_time) is datetime:
data["invocationLatency"] = int( data["invocationLatency"] = int(
(end_time - start_time).total_seconds() * 1000 (end_time - start_time).total_seconds() * 1000
) )
@ -50,6 +51,9 @@ class GreenscaleLogger:
data["tags"] = tags data["tags"] = tags
if self.greenscale_logging_url is None:
raise Exception("Greenscale Logger Error - No logging URL found")
response = requests.post( response = requests.post(
self.greenscale_logging_url, self.greenscale_logging_url,
headers=self.headers, headers=self.headers,

View file

@ -1,15 +1,28 @@
#### What this does #### #### What this does ####
# On success, logs events to Helicone # On success, logs events to Helicone
import dotenv, os import os
import requests # type: ignore
import litellm
import traceback import traceback
import dotenv
import requests # type: ignore
import litellm
from litellm._logging import verbose_logger from litellm._logging import verbose_logger
class HeliconeLogger: class HeliconeLogger:
# Class variables or attributes # Class variables or attributes
helicone_model_list = ["gpt", "claude", "command-r", "command-r-plus", "command-light", "command-medium", "command-medium-beta", "command-xlarge-nightly", "command-nightly"] helicone_model_list = [
"gpt",
"claude",
"command-r",
"command-r-plus",
"command-light",
"command-medium",
"command-medium-beta",
"command-xlarge-nightly",
"command-nightly",
]
def __init__(self): def __init__(self):
# Instance variables # Instance variables
@ -17,7 +30,7 @@ class HeliconeLogger:
self.key = os.getenv("HELICONE_API_KEY") self.key = os.getenv("HELICONE_API_KEY")
def claude_mapping(self, model, messages, response_obj): def claude_mapping(self, model, messages, response_obj):
from anthropic import HUMAN_PROMPT, AI_PROMPT from anthropic import AI_PROMPT, HUMAN_PROMPT
prompt = f"{HUMAN_PROMPT}" prompt = f"{HUMAN_PROMPT}"
for message in messages: for message in messages:
@ -29,7 +42,6 @@ class HeliconeLogger:
else: else:
prompt += f"{HUMAN_PROMPT}{message['content']}" prompt += f"{HUMAN_PROMPT}{message['content']}"
prompt += f"{AI_PROMPT}" prompt += f"{AI_PROMPT}"
claude_provider_request = {"model": model, "prompt": prompt}
choice = response_obj["choices"][0] choice = response_obj["choices"][0]
message = choice["message"] message = choice["message"]
@ -37,12 +49,14 @@ class HeliconeLogger:
content = [] content = []
if "tool_calls" in message and message["tool_calls"]: if "tool_calls" in message and message["tool_calls"]:
for tool_call in message["tool_calls"]: for tool_call in message["tool_calls"]:
content.append({ content.append(
"type": "tool_use", {
"id": tool_call["id"], "type": "tool_use",
"name": tool_call["function"]["name"], "id": tool_call["id"],
"input": tool_call["function"]["arguments"] "name": tool_call["function"]["name"],
}) "input": tool_call["function"]["arguments"],
}
)
elif "content" in message and message["content"]: elif "content" in message and message["content"]:
content = [{"type": "text", "text": message["content"]}] content = [{"type": "text", "text": message["content"]}]
@ -56,8 +70,8 @@ class HeliconeLogger:
"stop_sequence": None, "stop_sequence": None,
"usage": { "usage": {
"input_tokens": response_obj["usage"]["prompt_tokens"], "input_tokens": response_obj["usage"]["prompt_tokens"],
"output_tokens": response_obj["usage"]["completion_tokens"] "output_tokens": response_obj["usage"]["completion_tokens"],
} },
} }
return claude_response_obj return claude_response_obj
@ -99,10 +113,8 @@ class HeliconeLogger:
f"Helicone Logging - Enters logging function for model {model}" f"Helicone Logging - Enters logging function for model {model}"
) )
litellm_params = kwargs.get("litellm_params", {}) litellm_params = kwargs.get("litellm_params", {})
litellm_call_id = kwargs.get("litellm_call_id", None) kwargs.get("litellm_call_id", None)
metadata = ( metadata = litellm_params.get("metadata", {}) or {}
litellm_params.get("metadata", {}) or {}
)
metadata = self.add_metadata_from_header(litellm_params, metadata) metadata = self.add_metadata_from_header(litellm_params, metadata)
model = ( model = (
model model
@ -175,6 +187,6 @@ class HeliconeLogger:
f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}" f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
) )
print_verbose(f"Helicone Logging - Error {response.text}") print_verbose(f"Helicone Logging - Error {response.text}")
except: except Exception:
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}") print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -11,7 +11,7 @@ import dotenv
import httpx import httpx
import litellm import litellm
from litellm import verbose_logger from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import ( from litellm.llms.custom_httpx.http_handler import (
HTTPHandler, HTTPHandler,
@ -65,8 +65,8 @@ class LagoLogger(CustomLogger):
raise Exception("Missing keys={} in environment.".format(missing_keys)) raise Exception("Missing keys={} in environment.".format(missing_keys))
def _common_logic(self, kwargs: dict, response_obj) -> dict: def _common_logic(self, kwargs: dict, response_obj) -> dict:
call_id = response_obj.get("id", kwargs.get("litellm_call_id")) response_obj.get("id", kwargs.get("litellm_call_id"))
dt = get_utc_datetime().isoformat() get_utc_datetime().isoformat()
cost = kwargs.get("response_cost", None) cost = kwargs.get("response_cost", None)
model = kwargs.get("model") model = kwargs.get("model")
usage = {} usage = {}
@ -86,7 +86,7 @@ class LagoLogger(CustomLogger):
end_user_id = proxy_server_request.get("body", {}).get("user", None) end_user_id = proxy_server_request.get("body", {}).get("user", None)
user_id = litellm_params["metadata"].get("user_api_key_user_id", None) user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
team_id = litellm_params["metadata"].get("user_api_key_team_id", None) team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
org_id = litellm_params["metadata"].get("user_api_key_org_id", None) litellm_params["metadata"].get("user_api_key_org_id", None)
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id" charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
external_customer_id: Optional[str] = None external_customer_id: Optional[str] = None
@ -158,8 +158,9 @@ class LagoLogger(CustomLogger):
response.raise_for_status() response.raise_for_status()
except Exception as e: except Exception as e:
if hasattr(response, "text"): error_response = getattr(e, "response", None)
litellm.print_verbose(f"\nError Message: {response.text}") if error_response is not None and hasattr(error_response, "text"):
verbose_logger.debug(f"\nError Message: {error_response.text}")
raise e raise e
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
@ -199,5 +200,5 @@ class LagoLogger(CustomLogger):
verbose_logger.debug(f"Logged Lago Object: {response.text}") verbose_logger.debug(f"Logged Lago Object: {response.text}")
except Exception as e: except Exception as e:
if response is not None and hasattr(response, "text"): if response is not None and hasattr(response, "text"):
litellm.print_verbose(f"\nError Message: {response.text}") verbose_logger.debug(f"\nError Message: {response.text}")
raise e raise e
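The Lago error paths stop referencing a `response` variable that may never have been bound inside the `except` block; the failure branch instead reads the response off the raised exception, which is where httpx attaches it for HTTP status errors. A hedged sketch of the pattern (the URL and payload are placeholders):

import httpx


def post_usage_event(payload: dict) -> None:
    try:
        response = httpx.post("https://lago.example.invalid/api/v1/events", json=payload)
        response.raise_for_status()
    except Exception as e:
        # httpx.HTTPStatusError carries `.response`; connection errors do not,
        # so guard with getattr before touching `.text`
        error_response = getattr(e, "response", None)
        if error_response is not None and hasattr(error_response, "text"):
            print(f"Error Message: {error_response.text}")
        raise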

View file

@ -67,7 +67,7 @@ class LangFuseLogger:
try: try:
project_id = self.Langfuse.client.projects.get().data[0].id project_id = self.Langfuse.client.projects.get().data[0].id
os.environ["LANGFUSE_PROJECT_ID"] = project_id os.environ["LANGFUSE_PROJECT_ID"] = project_id
except: except Exception:
project_id = None project_id = None
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None: if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
@ -184,7 +184,7 @@ class LangFuseLogger:
if not isinstance(value, (str, int, bool, float)): if not isinstance(value, (str, int, bool, float)):
try: try:
optional_params[param] = str(value) optional_params[param] = str(value)
except: except Exception:
# if casting value to str fails don't block logging # if casting value to str fails don't block logging
pass pass
@ -275,7 +275,7 @@ class LangFuseLogger:
print_verbose( print_verbose(
f"Langfuse Layer Logging - final response object: {response_obj}" f"Langfuse Layer Logging - final response object: {response_obj}"
) )
verbose_logger.info(f"Langfuse Layer Logging - logging success") verbose_logger.info("Langfuse Layer Logging - logging success")
return {"trace_id": trace_id, "generation_id": generation_id} return {"trace_id": trace_id, "generation_id": generation_id}
except Exception as e: except Exception as e:
@ -492,7 +492,7 @@ class LangFuseLogger:
output if not mask_output else "redacted-by-litellm" output if not mask_output else "redacted-by-litellm"
) )
if debug == True or (isinstance(debug, str) and debug.lower() == "true"): if debug is True or (isinstance(debug, str) and debug.lower() == "true"):
if "metadata" in trace_params: if "metadata" in trace_params:
# log the raw_metadata in the trace # log the raw_metadata in the trace
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
@ -535,8 +535,8 @@ class LangFuseLogger:
proxy_server_request = litellm_params.get("proxy_server_request", None) proxy_server_request = litellm_params.get("proxy_server_request", None)
if proxy_server_request: if proxy_server_request:
method = proxy_server_request.get("method", None) proxy_server_request.get("method", None)
url = proxy_server_request.get("url", None) proxy_server_request.get("url", None)
headers = proxy_server_request.get("headers", None) headers = proxy_server_request.get("headers", None)
clean_headers = {} clean_headers = {}
if headers: if headers:
@ -625,7 +625,7 @@ class LangFuseLogger:
generation_client = trace.generation(**generation_params) generation_client = trace.generation(**generation_params)
return generation_client.trace_id, generation_id return generation_client.trace_id, generation_id
except Exception as e: except Exception:
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}") verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
return None, None return None, None

View file

@ -404,7 +404,7 @@ class LangsmithLogger(CustomBatchLogger):
verbose_logger.exception( verbose_logger.exception(
f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}" f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
) )
except Exception as e: except Exception:
verbose_logger.exception( verbose_logger.exception(
f"Langsmith Layer Error - {traceback.format_exc()}" f"Langsmith Layer Error - {traceback.format_exc()}"
) )

View file

@ -1,6 +1,10 @@
import requests, traceback, json, os import json
import os
import traceback
import types import types
import requests
class LiteDebugger: class LiteDebugger:
user_email = None user_email = None
@ -17,23 +21,17 @@ class LiteDebugger:
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL") email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
) )
if ( if (
self.user_email == None self.user_email is None
): # if users are trying to use_client=True but token not set ): # if users are trying to use_client=True but token not set
raise ValueError( raise ValueError(
"litellm.use_client = True but no token or email passed. Please set it in litellm.token" "litellm.use_client = True but no token or email passed. Please set it in litellm.token"
) )
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
try: if self.user_email is None:
print(
f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
)
except:
print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
if self.user_email == None:
raise ValueError( raise ValueError(
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>" "[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
) )
except Exception as e: except Exception:
raise ValueError( raise ValueError(
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>" "[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
) )
@ -49,123 +47,18 @@ class LiteDebugger:
litellm_params, litellm_params,
optional_params, optional_params,
): ):
print_verbose( """
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}" This integration is not implemented yet.
) """
try: return
print_verbose(
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
)
def remove_key_value(dictionary, key):
new_dict = dictionary.copy() # Create a copy of the original dictionary
new_dict.pop(key) # Remove the specified key-value pair from the copy
return new_dict
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
if call_type == "embedding":
for (
message
) in (
messages
): # assuming the input is a list as required by the embedding function
litellm_data_obj = {
"model": model,
"messages": [{"role": "user", "content": message}],
"end_user": end_user,
"status": "initiated",
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
"litellm_params": updated_litellm_params,
"optional_params": optional_params,
}
print_verbose(
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: embedding api response - {response.text}")
elif call_type == "completion":
litellm_data_obj = {
"model": model,
"messages": messages
if isinstance(messages, list)
else [{"role": "user", "content": messages}],
"end_user": end_user,
"status": "initiated",
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
"litellm_params": updated_litellm_params,
"optional_params": optional_params,
}
print_verbose(
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(
f"LiteDebugger: completion api response - {response.text}"
)
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
pass
def post_call_log_event( def post_call_log_event(
self, original_response, litellm_call_id, print_verbose, call_type, stream self, original_response, litellm_call_id, print_verbose, call_type, stream
): ):
print_verbose( """
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}" This integration is not implemented yet.
) """
try: return
if call_type == "embedding":
litellm_data_obj = {
"status": "received",
"additional_details": {
"original_response": str(
original_response["data"][0]["embedding"][:5]
)
}, # don't store the entire vector
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
elif call_type == "completion" and not stream:
litellm_data_obj = {
"status": "received",
"additional_details": {"original_response": original_response},
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
elif call_type == "completion" and stream:
litellm_data_obj = {
"status": "received",
"additional_details": {
"original_response": "Streamed response"
if isinstance(original_response, types.GeneratorType)
else original_response
},
"litellm_call_id": litellm_call_id,
"user_email": self.user_email,
}
print_verbose(f"litedebugger post-call data object - {litellm_data_obj}")
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: api response - {response.text}")
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
def log_event( def log_event(
self, self,
@ -178,85 +71,7 @@ class LiteDebugger:
call_type, call_type,
stream=False, stream=False,
): ):
print_verbose( """
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}" This integration is not implemented yet.
) """
try: return
print_verbose(
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
)
total_cost = 0 # [TODO] implement cost tracking
response_time = (end_time - start_time).total_seconds()
if call_type == "completion" and stream == False:
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,
"response": response_obj["choices"][0]["message"]["content"],
"litellm_call_id": litellm_call_id,
"status": "success",
}
print_verbose(
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif call_type == "embedding":
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,
"response": str(response_obj["data"][0]["embedding"][:5]),
"litellm_call_id": litellm_call_id,
"status": "success",
}
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif call_type == "completion" and stream == True:
if len(response_obj["content"]) > 0: # don't log the empty strings
litellm_data_obj = {
"response_time": response_time,
"total_cost": total_cost,
"response": response_obj["content"],
"litellm_call_id": litellm_call_id,
"status": "success",
}
print_verbose(
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
elif "error" in response_obj:
if "Unable to map your input to a model." in response_obj["error"]:
total_cost = 0
litellm_data_obj = {
"response_time": response_time,
"model": response_obj["model"],
"total_cost": total_cost,
"error": response_obj["error"],
"end_user": end_user,
"litellm_call_id": litellm_call_id,
"status": "failure",
"user_email": self.user_email,
}
print_verbose(
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
)
response = requests.post(
url=self.api_url,
headers={"content-type": "application/json"},
data=json.dumps(litellm_data_obj),
)
print_verbose(f"LiteDebugger: api response - {response.text}")
except:
print_verbose(
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
)
pass
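The LiteDebugger methods are gutted rather than deleted, keeping the import surface intact while each body becomes a docstring plus an early return. Shape of the resulting stub, with an illustrative class name:

class LiteDebuggerStub:
    """Illustrative stand-in; the real class keeps its original name and signatures."""

    def log_event(self, *args, **kwargs):
        """
        This integration is not implemented yet.
        """
        return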

View file

@ -27,7 +27,7 @@ class LogfireLogger:
# Class variables or attributes # Class variables or attributes
def __init__(self): def __init__(self):
try: try:
verbose_logger.debug(f"in init logfire logger") verbose_logger.debug("in init logfire logger")
import logfire import logfire
# only setting up logfire if we are sending to logfire # only setting up logfire if we are sending to logfire
@ -116,7 +116,7 @@ class LogfireLogger:
id = response_obj.get("id", str(uuid.uuid4())) id = response_obj.get("id", str(uuid.uuid4()))
try: try:
response_time = (end_time - start_time).total_seconds() response_time = (end_time - start_time).total_seconds()
except: except Exception:
response_time = None response_time = None
# Clean Metadata before logging - never log raw metadata # Clean Metadata before logging - never log raw metadata

View file

@ -1,8 +1,8 @@
#### What this does #### #### What this does ####
# On success + failure, log events to lunary.ai # On success + failure, log events to lunary.ai
from datetime import datetime, timezone
import traceback
import importlib import importlib
import traceback
from datetime import datetime, timezone
import packaging import packaging
@ -74,9 +74,9 @@ class LunaryLogger:
try: try:
import lunary import lunary
version = importlib.metadata.version("lunary") version = importlib.metadata.version("lunary") # type: ignore
# if version < 0.1.43 then raise ImportError # if version < 0.1.43 then raise ImportError
if packaging.version.Version(version) < packaging.version.Version("0.1.43"): if packaging.version.Version(version) < packaging.version.Version("0.1.43"): # type: ignore
print( # noqa print( # noqa
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'" "Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
) )
@ -97,7 +97,7 @@ class LunaryLogger:
run_id, run_id,
model, model,
print_verbose, print_verbose,
extra=None, extra={},
input=None, input=None,
user_id=None, user_id=None,
response_obj=None, response_obj=None,
@ -128,7 +128,7 @@ class LunaryLogger:
if not isinstance(value, (str, int, bool, float)) and param != "tools": if not isinstance(value, (str, int, bool, float)) and param != "tools":
try: try:
extra[param] = str(value) extra[param] = str(value)
except: except Exception:
pass pass
if response_obj: if response_obj:
@ -175,6 +175,6 @@ class LunaryLogger:
token_usage=usage, token_usage=usage,
) )
except: except Exception:
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}") print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -98,7 +98,7 @@ class OpenTelemetry(CustomLogger):
import logging import logging
logging.basicConfig(level=logging.DEBUG) logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__) logging.getLogger(__name__)
# Enable OpenTelemetry logging # Enable OpenTelemetry logging
otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export") otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export")
@ -520,7 +520,7 @@ class OpenTelemetry(CustomLogger):
def set_raw_request_attributes(self, span: Span, kwargs, response_obj): def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
from litellm.proxy._types import SpanAttributes from litellm.proxy._types import SpanAttributes
optional_params = kwargs.get("optional_params", {}) kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {} litellm_params = kwargs.get("litellm_params", {}) or {}
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown") custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
@ -769,6 +769,6 @@ class OpenTelemetry(CustomLogger):
management_endpoint_span.set_attribute(f"request.{key}", value) management_endpoint_span.set_attribute(f"request.{key}", value)
_exception = logging_payload.exception _exception = logging_payload.exception
management_endpoint_span.set_attribute(f"exception", str(_exception)) management_endpoint_span.set_attribute("exception", str(_exception))
management_endpoint_span.set_status(Status(StatusCode.ERROR)) management_endpoint_span.set_status(Status(StatusCode.ERROR))
management_endpoint_span.end(end_time=_end_time_ns) management_endpoint_span.end(end_time=_end_time_ns)
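
The edits above that turn `logger = logging.getLogger(__name__)` and `optional_params = kwargs.get("optional_params", {})` into bare calls remove assignments to names that are never read afterwards, which ruff reports as F841 (unused local variable). A rough illustration under that assumption, not code from the repo:

```
import logging


def enable_otel_debug_logging() -> None:
    # Before: logger = logging.getLogger(__name__)  -> F841, "logger" is never read.
    # The commit keeps the call and only drops the unused name, the minimal fix ruff asks for.
    logging.getLogger(__name__)
    logging.basicConfig(level=logging.DEBUG)
```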

View file

@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger):
user_api_team_alias = standard_logging_payload["metadata"][ user_api_team_alias = standard_logging_payload["metadata"][
"user_api_key_team_alias" "user_api_key_team_alias"
] ]
exception = kwargs.get("exception", None) kwargs.get("exception", None)
try: try:
self.litellm_llm_api_failed_requests_metric.labels( self.litellm_llm_api_failed_requests_metric.labels(
@ -679,7 +679,7 @@ class PrometheusLogger(CustomLogger):
).inc() ).inc()
pass pass
except: except Exception:
pass pass
def set_llm_deployment_success_metrics( def set_llm_deployment_success_metrics(
@ -800,7 +800,7 @@ class PrometheusLogger(CustomLogger):
if ( if (
request_kwargs.get("stream", None) is not None request_kwargs.get("stream", None) is not None
and request_kwargs["stream"] == True and request_kwargs["stream"] is True
): ):
# only log ttft for streaming request # only log ttft for streaming request
time_to_first_token_response_time = ( time_to_first_token_response_time = (
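
The `== True` → `is True` rewrites in this hunk (and in many provider modules further down) correspond to ruff E712, which flags equality comparisons against the `True`/`False` literals. A small sketch of the two styles, illustrative only:

```
stream: bool = True

if stream == True:  # noqa: E712 - the comparison style this commit removes
    print("streaming request")

if stream is True:  # the identity check the commit standardizes on
    print("streaming request")
```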

View file

@@ -3,11 +3,17 @@
 # On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)

-import dotenv, os
-import requests  # type: ignore
+import datetime
+import os
+import subprocess
+import sys
 import traceback
-import datetime, subprocess, sys
-import litellm, uuid
+import uuid
+
+import dotenv
+import requests  # type: ignore
+
+import litellm

 from litellm._logging import print_verbose, verbose_logger
 from litellm.types.services import ServiceLoggerPayload, ServiceTypes
@ -23,7 +29,7 @@ class PrometheusServicesLogger:
): ):
try: try:
try: try:
from prometheus_client import Counter, Histogram, REGISTRY from prometheus_client import REGISTRY, Counter, Histogram
except ImportError: except ImportError:
raise Exception( raise Exception(
"Missing prometheus_client. Run `pip install prometheus-client`" "Missing prometheus_client. Run `pip install prometheus-client`"
@ -33,7 +39,7 @@ class PrometheusServicesLogger:
self.Counter = Counter self.Counter = Counter
self.REGISTRY = REGISTRY self.REGISTRY = REGISTRY
verbose_logger.debug(f"in init prometheus services metrics") verbose_logger.debug("in init prometheus services metrics")
self.services = [item.value for item in ServiceTypes] self.services = [item.value for item in ServiceTypes]
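
Two recurring cleanups meet in this file: imports are regrouped (standard library, then third-party, then first-party), and `verbose_logger.debug(f"in init prometheus services metrics")` loses its `f` prefix because the string has no placeholders, which ruff reports as F541. A short illustration of the latter, not taken from the repo:

```
service = "prometheus"

print(f"in init services metrics")  # noqa: F541 - the f-prefix does nothing here
print("in init services metrics")  # the form this commit switches to
print(f"in init {service} services metrics")  # f-strings stay when they interpolate
```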

View file

@@ -1,9 +1,11 @@
 #### What this does ####
 # On success, logs events to Promptlayer
-import dotenv, os
+import os
+import traceback
+
+import dotenv
 import requests  # type: ignore
 from pydantic import BaseModel
-import traceback


 class PromptLayerLogger:
@ -84,6 +86,6 @@ class PromptLayerLogger:
f"Prompt Layer Logging: success - metadata post response object: {response.text}" f"Prompt Layer Logging: success - metadata post response object: {response.text}"
) )
except: except Exception:
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}") print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
pass pass

View file

@@ -1,10 +1,15 @@
 #### What this does ####
 # On success + failure, log events to Supabase
-import dotenv, os
-import requests  # type: ignore
+import datetime
+import os
+import subprocess
+import sys
 import traceback
-import datetime, subprocess, sys
+
+import dotenv
+import requests  # type: ignore

 import litellm
@@ -21,7 +26,12 @@ class Supabase:
         except ImportError:
             subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"])
             import supabase
-        self.supabase_client = supabase.create_client(
+
+        if self.supabase_url is None or self.supabase_key is None:
+            raise ValueError(
+                "LiteLLM Error, trying to use Supabase but url or key not passed. Create a table and set `litellm.supabase_url=<your-url>` and `litellm.supabase_key=<your-key>`"
+            )
+        self.supabase_client = supabase.create_client(  # type: ignore
             self.supabase_url, self.supabase_key
         )
@ -45,7 +55,7 @@ class Supabase:
.execute() .execute()
) )
print_verbose(f"data: {data}") print_verbose(f"data: {data}")
except: except Exception:
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}") print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
pass pass
@ -109,6 +119,6 @@ class Supabase:
.execute() .execute()
) )
except: except Exception:
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}") print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
pass pass

View file

@ -167,18 +167,17 @@ try:
trace = self.results_to_trace_tree(request, response, results, time_elapsed) trace = self.results_to_trace_tree(request, response, results, time_elapsed)
return trace return trace
except: except Exception:
imported_openAIResponse = False imported_openAIResponse = False
#### What this does #### #### What this does ####
# On success, logs events to Langfuse # On success, logs events to Langfuse
import os import os
import requests import traceback
import requests
from datetime import datetime from datetime import datetime
import traceback import requests
class WeightsBiasesLogger: class WeightsBiasesLogger:
@ -186,11 +185,11 @@ class WeightsBiasesLogger:
def __init__(self): def __init__(self):
try: try:
import wandb import wandb
except: except Exception:
raise Exception( raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m" "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
) )
if imported_openAIResponse == False: if imported_openAIResponse is False:
raise Exception( raise Exception(
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m" "\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
) )
@@ -209,13 +208,14 @@ class WeightsBiasesLogger:
                 kwargs, response_obj, (end_time - start_time).total_seconds()
             )
-            if trace is not None:
+            if trace is not None and run is not None:
                 run.log({"trace": trace})
-            run.finish()
-            print_verbose(
-                f"W&B Logging Logging - final response object: {response_obj}"
-            )
-        except:
+            if run is not None:
+                run.finish()
+                print_verbose(
+                    f"W&B Logging Logging - final response object: {response_obj}"
+                )
+        except Exception:
             print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
             pass

View file

@ -62,7 +62,7 @@ def get_error_message(error_obj) -> Optional[str]:
# If all else fails, return None # If all else fails, return None
return None return None
except Exception as e: except Exception:
return None return None
@ -910,7 +910,7 @@ def exception_type( # type: ignore
): ):
exception_mapping_worked = True exception_mapping_worked = True
raise BadRequestError( raise BadRequestError(
message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints", message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
model=model, model=model,
llm_provider="sagemaker", llm_provider="sagemaker",
response=original_exception.response, response=original_exception.response,
@ -1122,7 +1122,7 @@ def exception_type( # type: ignore
# 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate. # 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
exception_mapping_worked = True exception_mapping_worked = True
raise BadRequestError( raise BadRequestError(
message=f"GeminiException - Invalid api key", message="GeminiException - Invalid api key",
model=model, model=model,
llm_provider="palm", llm_provider="palm",
response=original_exception.response, response=original_exception.response,
@ -2067,12 +2067,34 @@ def exception_logging(
logger_fn( logger_fn(
model_call_details model_call_details
) # Expectation: any logger function passed in by the user should accept a dict object ) # Expectation: any logger function passed in by the user should accept a dict object
except Exception as e: except Exception:
verbose_logger.debug( verbose_logger.debug(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
) )
except Exception as e: except Exception:
verbose_logger.debug( verbose_logger.debug(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}" f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
) )
pass pass
def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
"""
Internal helper function for litellm proxy
Add the Key Name + Team Name to the error
Only gets added if the metadata contains the user_api_key_alias and user_api_key_team_alias
[Non-Blocking helper function]
"""
try:
_api_key_name = metadata.get("user_api_key_alias", None)
_user_api_key_team_alias = metadata.get("user_api_key_team_alias", None)
if _api_key_name is not None:
request_info = (
f"\n\nKey Name: `{_api_key_name}`\nTeam: `{_user_api_key_team_alias}`"
+ request_info
)
return request_info
except Exception:
return request_info

View file

@ -476,7 +476,7 @@ def get_llm_provider(
elif model == "*": elif model == "*":
custom_llm_provider = "openai" custom_llm_provider = "openai"
if custom_llm_provider is None or custom_llm_provider == "": if custom_llm_provider is None or custom_llm_provider == "":
if litellm.suppress_debug_info == False: if litellm.suppress_debug_info is False:
print() # noqa print() # noqa
print( # noqa print( # noqa
"\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa "\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa

View file

@@ -52,18 +52,8 @@ from litellm.types.utils import (
 )
 from litellm.utils import (
     _get_base_model_from_metadata,
-    add_breadcrumb,
-    capture_exception,
-    customLogger,
-    liteDebuggerClient,
-    logfireLogger,
-    lunaryLogger,
     print_verbose,
-    prometheusLogger,
    prompt_token_calculator,
-    promptLayerLogger,
-    supabaseClient,
-    weightsBiasesLogger,
 )
from ..integrations.aispend import AISpendLogger from ..integrations.aispend import AISpendLogger
@ -71,7 +61,6 @@ from ..integrations.athina import AthinaLogger
from ..integrations.berrispend import BerriSpendLogger from ..integrations.berrispend import BerriSpendLogger
from ..integrations.braintrust_logging import BraintrustLogger from ..integrations.braintrust_logging import BraintrustLogger
from ..integrations.clickhouse import ClickhouseLogger from ..integrations.clickhouse import ClickhouseLogger
from ..integrations.custom_logger import CustomLogger
from ..integrations.datadog.datadog import DataDogLogger from ..integrations.datadog.datadog import DataDogLogger
from ..integrations.dynamodb import DyanmoDBLogger from ..integrations.dynamodb import DyanmoDBLogger
from ..integrations.galileo import GalileoObserve from ..integrations.galileo import GalileoObserve
@ -423,7 +412,7 @@ class Logging:
elif callback == "sentry" and add_breadcrumb: elif callback == "sentry" and add_breadcrumb:
try: try:
details_to_log = copy.deepcopy(self.model_call_details) details_to_log = copy.deepcopy(self.model_call_details)
except: except Exception:
details_to_log = self.model_call_details details_to_log = self.model_call_details
if litellm.turn_off_message_logging: if litellm.turn_off_message_logging:
# make a copy of the _model_Call_details and log it # make a copy of the _model_Call_details and log it
@ -528,7 +517,7 @@ class Logging:
verbose_logger.debug("reaches sentry breadcrumbing") verbose_logger.debug("reaches sentry breadcrumbing")
try: try:
details_to_log = copy.deepcopy(self.model_call_details) details_to_log = copy.deepcopy(self.model_call_details)
except: except Exception:
details_to_log = self.model_call_details details_to_log = self.model_call_details
if litellm.turn_off_message_logging: if litellm.turn_off_message_logging:
# make a copy of the _model_Call_details and log it # make a copy of the _model_Call_details and log it
@ -1326,7 +1315,7 @@ class Logging:
and customLogger is not None and customLogger is not None
): # custom logger functions ): # custom logger functions
print_verbose( print_verbose(
f"success callbacks: Running Custom Callback Function" "success callbacks: Running Custom Callback Function"
) )
customLogger.log_event( customLogger.log_event(
kwargs=self.model_call_details, kwargs=self.model_call_details,
@ -1400,7 +1389,7 @@ class Logging:
self.model_call_details["response_cost"] = 0.0 self.model_call_details["response_cost"] = 0.0
else: else:
# check if base_model set on azure # check if base_model set on azure
base_model = _get_base_model_from_metadata( _get_base_model_from_metadata(
model_call_details=self.model_call_details model_call_details=self.model_call_details
) )
# base_model defaults to None if not set on model_info # base_model defaults to None if not set on model_info
@ -1483,7 +1472,7 @@ class Logging:
for callback in callbacks: for callback in callbacks:
# check if callback can run for this request # check if callback can run for this request
litellm_params = self.model_call_details.get("litellm_params", {}) litellm_params = self.model_call_details.get("litellm_params", {})
if litellm_params.get("no-log", False) == True: if litellm_params.get("no-log", False) is True:
# proxy cost tracking cal backs should run # proxy cost tracking cal backs should run
if not ( if not (
isinstance(callback, CustomLogger) isinstance(callback, CustomLogger)
@ -1492,7 +1481,7 @@ class Logging:
print_verbose("no-log request, skipping logging") print_verbose("no-log request, skipping logging")
continue continue
try: try:
if kwargs.get("no-log", False) == True: if kwargs.get("no-log", False) is True:
print_verbose("no-log request, skipping logging") print_verbose("no-log request, skipping logging")
continue continue
if ( if (
@ -1641,7 +1630,7 @@ class Logging:
end_time=end_time, end_time=end_time,
print_verbose=print_verbose, print_verbose=print_verbose,
) )
except Exception as e: except Exception:
verbose_logger.error( verbose_logger.error(
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}" f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
) )
@ -2433,7 +2422,7 @@ def get_standard_logging_object_payload(
call_type = kwargs.get("call_type") call_type = kwargs.get("call_type")
cache_hit = kwargs.get("cache_hit", False) cache_hit = kwargs.get("cache_hit", False)
usage = response_obj.get("usage", None) or {} usage = response_obj.get("usage", None) or {}
if type(usage) == litellm.Usage: if type(usage) is litellm.Usage:
usage = dict(usage) usage = dict(usage)
id = response_obj.get("id", kwargs.get("litellm_call_id")) id = response_obj.get("id", kwargs.get("litellm_call_id"))
@@ -2656,3 +2645,11 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
         litellm_params["metadata"] = metadata

     return litellm_params
+
+
+# integration helper function
+def modify_integration(integration_name, integration_params):
+    global supabaseClient
+    if integration_name == "supabase":
+        if "table_name" in integration_params:
+            Supabase.supabase_table_name = integration_params["table_name"]
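
Elsewhere in this file, `type(usage) == litellm.Usage` becomes `type(usage) is litellm.Usage`; pycodestyle/ruff E721 prefers identity checks (or `isinstance`) when comparing types. A quick sketch of the options, illustrative only:

```
class Usage(dict):
    """Stand-in for litellm.Usage, for illustration only."""


usage = Usage()

print(type(usage) is Usage)      # exact-type identity check, what the commit uses
print(isinstance(usage, Usage))  # broader check that also accepts subclasses
```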

View file

@ -45,7 +45,7 @@ def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
model_info = litellm.get_model_info( model_info = litellm.get_model_info(
model=model, custom_llm_provider=custom_llm_provider model=model, custom_llm_provider=custom_llm_provider
) )
except: except Exception:
continue continue
if model_info.get("mode") != "chat": if model_info.get("mode") != "chat":
continue continue

View file

@ -123,7 +123,7 @@ def completion(
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
@ -165,7 +165,7 @@ def completion(
) )
if response.status_code != 200: if response.status_code != 200:
raise AI21Error(status_code=response.status_code, message=response.text) raise AI21Error(status_code=response.status_code, message=response.text)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
@ -191,7 +191,7 @@ def completion(
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore model_response.choices = choices_list # type: ignore
except Exception as e: except Exception:
raise AI21Error( raise AI21Error(
message=traceback.format_exc(), status_code=response.status_code message=traceback.format_exc(), status_code=response.status_code
) )

View file

@ -151,7 +151,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None, client=None,
aget_assistants=None, aget_assistants=None,
): ):
if aget_assistants is not None and aget_assistants == True: if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants( return self.async_get_assistants(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -260,7 +260,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None, client=None,
a_add_message: Optional[bool] = None, a_add_message: Optional[bool] = None,
): ):
if a_add_message is not None and a_add_message == True: if a_add_message is not None and a_add_message is True:
return self.a_add_message( return self.a_add_message(
thread_id=thread_id, thread_id=thread_id,
message_data=message_data, message_data=message_data,
@ -365,7 +365,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None, client=None,
aget_messages=None, aget_messages=None,
): ):
if aget_messages is not None and aget_messages == True: if aget_messages is not None and aget_messages is True:
return self.async_get_messages( return self.async_get_messages(
thread_id=thread_id, thread_id=thread_id,
api_key=api_key, api_key=api_key,
@ -483,7 +483,7 @@ class AzureAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message]) openai_api.create_thread(messages=[message])
``` ```
""" """
if acreate_thread is not None and acreate_thread == True: if acreate_thread is not None and acreate_thread is True:
return self.async_create_thread( return self.async_create_thread(
metadata=metadata, metadata=metadata,
api_key=api_key, api_key=api_key,
@ -586,7 +586,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None, client=None,
aget_thread=None, aget_thread=None,
): ):
if aget_thread is not None and aget_thread == True: if aget_thread is not None and aget_thread is True:
return self.async_get_thread( return self.async_get_thread(
thread_id=thread_id, thread_id=thread_id,
api_key=api_key, api_key=api_key,
@ -774,8 +774,8 @@ class AzureAssistantsAPI(BaseLLM):
arun_thread=None, arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None, event_handler: Optional[AssistantEventHandler] = None,
): ):
if arun_thread is not None and arun_thread == True: if arun_thread is not None and arun_thread is True:
if stream is not None and stream == True: if stream is not None and stream is True:
azure_client = self.async_get_azure_client( azure_client = self.async_get_azure_client(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -823,7 +823,7 @@ class AzureAssistantsAPI(BaseLLM):
client=client, client=client,
) )
if stream is not None and stream == True: if stream is not None and stream is True:
return self.run_thread_stream( return self.run_thread_stream(
client=openai_client, client=openai_client,
thread_id=thread_id, thread_id=thread_id,
@ -887,7 +887,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None, client=None,
async_create_assistants=None, async_create_assistants=None,
): ):
if async_create_assistants is not None and async_create_assistants == True: if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants( return self.async_create_assistants(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -950,7 +950,7 @@ class AzureAssistantsAPI(BaseLLM):
async_delete_assistants: Optional[bool] = None, async_delete_assistants: Optional[bool] = None,
client=None, client=None,
): ):
if async_delete_assistants is not None and async_delete_assistants == True: if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant( return self.async_delete_assistant(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,

View file

@ -317,7 +317,7 @@ class AzureOpenAIAssistantsAPIConfig:
if "file_id" in item: if "file_id" in item:
file_ids.append(item["file_id"]) file_ids.append(item["file_id"])
else: else:
if litellm.drop_params == True: if litellm.drop_params is True:
pass pass
else: else:
raise litellm.utils.UnsupportedParamsError( raise litellm.utils.UnsupportedParamsError(
@ -580,7 +580,7 @@ class AzureChatCompletion(BaseLLM):
try: try:
if model is None or messages is None: if model is None or messages is None:
raise AzureOpenAIError( raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages" status_code=422, message="Missing model or messages"
) )
max_retries = optional_params.pop("max_retries", 2) max_retries = optional_params.pop("max_retries", 2)
@@ -1240,12 +1240,6 @@ class AzureChatCompletion(BaseLLM):
         )
         while response.json()["status"] not in ["succeeded", "failed"]:
             if time.time() - start_time > timeout_secs:
-                timeout_msg = {
-                    "error": {
-                        "code": "Timeout",
-                        "message": "Operation polling timed out.",
-                    }
-                }
                 raise AzureOpenAIError(
                     status_code=408, message="Operation polling timed out."
@ -1493,7 +1487,6 @@ class AzureChatCompletion(BaseLLM):
client=None, client=None,
aimg_generation=None, aimg_generation=None,
): ):
exception_mapping_worked = False
try: try:
if model and len(model) > 0: if model and len(model) > 0:
model = model model = model
@ -1534,7 +1527,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token) azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token azure_client_params["azure_ad_token"] = azure_ad_token
if aimg_generation == True: if aimg_generation is True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response return response

View file

@ -1263,7 +1263,6 @@ class OpenAIChatCompletion(BaseLLM):
client=None, client=None,
aimg_generation=None, aimg_generation=None,
): ):
exception_mapping_worked = False
data = {} data = {}
try: try:
model = model model = model
@ -1272,7 +1271,7 @@ class OpenAIChatCompletion(BaseLLM):
if not isinstance(max_retries, int): if not isinstance(max_retries, int):
raise OpenAIError(status_code=422, message="max retries must be an int") raise OpenAIError(status_code=422, message="max retries must be an int")
if aimg_generation == True: if aimg_generation is True:
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
return response return response
@ -1311,7 +1310,6 @@ class OpenAIChatCompletion(BaseLLM):
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e: except OpenAIError as e:
exception_mapping_worked = True
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(
input=prompt, input=prompt,
@ -1543,7 +1541,7 @@ class OpenAITextCompletion(BaseLLM):
if ( if (
len(messages) > 0 len(messages) > 0
and "content" in messages[0] and "content" in messages[0]
and type(messages[0]["content"]) == list and isinstance(messages[0]["content"], list)
): ):
prompt = messages[0]["content"] prompt = messages[0]["content"]
else: else:
@ -2413,7 +2411,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None, client=None,
aget_assistants=None, aget_assistants=None,
): ):
if aget_assistants is not None and aget_assistants == True: if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants( return self.async_get_assistants(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -2470,7 +2468,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None, client=None,
async_create_assistants=None, async_create_assistants=None,
): ):
if async_create_assistants is not None and async_create_assistants == True: if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants( return self.async_create_assistants(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -2527,7 +2525,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None, client=None,
async_delete_assistants=None, async_delete_assistants=None,
): ):
if async_delete_assistants is not None and async_delete_assistants == True: if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant( return self.async_delete_assistant(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -2629,7 +2627,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None, client=None,
a_add_message: Optional[bool] = None, a_add_message: Optional[bool] = None,
): ):
if a_add_message is not None and a_add_message == True: if a_add_message is not None and a_add_message is True:
return self.a_add_message( return self.a_add_message(
thread_id=thread_id, thread_id=thread_id,
message_data=message_data, message_data=message_data,
@ -2727,7 +2725,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None, client=None,
aget_messages=None, aget_messages=None,
): ):
if aget_messages is not None and aget_messages == True: if aget_messages is not None and aget_messages is True:
return self.async_get_messages( return self.async_get_messages(
thread_id=thread_id, thread_id=thread_id,
api_key=api_key, api_key=api_key,
@ -2838,7 +2836,7 @@ class OpenAIAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message]) openai_api.create_thread(messages=[message])
``` ```
""" """
if acreate_thread is not None and acreate_thread == True: if acreate_thread is not None and acreate_thread is True:
return self.async_create_thread( return self.async_create_thread(
metadata=metadata, metadata=metadata,
api_key=api_key, api_key=api_key,
@ -2934,7 +2932,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=None, client=None,
aget_thread=None, aget_thread=None,
): ):
if aget_thread is not None and aget_thread == True: if aget_thread is not None and aget_thread is True:
return self.async_get_thread( return self.async_get_thread(
thread_id=thread_id, thread_id=thread_id,
api_key=api_key, api_key=api_key,
@ -3117,8 +3115,8 @@ class OpenAIAssistantsAPI(BaseLLM):
arun_thread=None, arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None, event_handler: Optional[AssistantEventHandler] = None,
): ):
if arun_thread is not None and arun_thread == True: if arun_thread is not None and arun_thread is True:
if stream is not None and stream == True: if stream is not None and stream is True:
_client = self.async_get_openai_client( _client = self.async_get_openai_client(
api_key=api_key, api_key=api_key,
api_base=api_base, api_base=api_base,
@ -3163,7 +3161,7 @@ class OpenAIAssistantsAPI(BaseLLM):
client=client, client=client,
) )
if stream is not None and stream == True: if stream is not None and stream is True:
return self.run_thread_stream( return self.run_thread_stream(
client=openai_client, client=openai_client,
thread_id=thread_id, thread_id=thread_id,

View file

@ -191,7 +191,7 @@ def completion(
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
default_max_tokens_to_sample=None, default_max_tokens_to_sample=None,
@ -246,7 +246,7 @@ def completion(
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False, stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
@ -279,7 +279,7 @@ def completion(
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore model_response.choices = choices_list # type: ignore
except: except Exception:
raise AlephAlphaError( raise AlephAlphaError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
status_code=response.status_code, status_code=response.status_code,

View file

@ -607,7 +607,6 @@ class ModelResponseIterator:
def _handle_usage( def _handle_usage(
self, anthropic_usage_chunk: Union[dict, UsageDelta] self, anthropic_usage_chunk: Union[dict, UsageDelta]
) -> AnthropicChatCompletionUsageBlock: ) -> AnthropicChatCompletionUsageBlock:
special_fields = ["input_tokens", "output_tokens"]
usage_block = AnthropicChatCompletionUsageBlock( usage_block = AnthropicChatCompletionUsageBlock(
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0), prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
@ -683,7 +682,7 @@ class ModelResponseIterator:
"index": self.tool_index, "index": self.tool_index,
} }
elif type_chunk == "content_block_stop": elif type_chunk == "content_block_stop":
content_block_stop = ContentBlockStop(**chunk) # type: ignore ContentBlockStop(**chunk) # type: ignore
# check if tool call content block # check if tool call content block
is_empty = self.check_empty_tool_call_args() is_empty = self.check_empty_tool_call_args()
if is_empty: if is_empty:

View file

@ -114,7 +114,7 @@ class AnthropicTextCompletion(BaseLLM):
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
except: except Exception:
raise AnthropicError( raise AnthropicError(
message=response.text, status_code=response.status_code message=response.text, status_code=response.status_code
) )
@ -229,7 +229,7 @@ class AnthropicTextCompletion(BaseLLM):
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
headers={}, headers={},
@ -276,8 +276,8 @@ class AnthropicTextCompletion(BaseLLM):
) )
## COMPLETION CALL ## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
if acompletion == True: if acompletion is True:
return self.async_streaming( return self.async_streaming(
model=model, model=model,
api_base=api_base, api_base=api_base,
@ -309,7 +309,7 @@ class AnthropicTextCompletion(BaseLLM):
logging_obj=logging_obj, logging_obj=logging_obj,
) )
return stream_response return stream_response
elif acompletion == True: elif acompletion is True:
return self.async_completion( return self.async_completion(
model=model, model=model,
model_response=model_response, model_response=model_response,

View file

@ -233,7 +233,7 @@ class AzureTextCompletion(BaseLLM):
client=client, client=client,
logging_obj=logging_obj, logging_obj=logging_obj,
) )
elif "stream" in optional_params and optional_params["stream"] == True: elif "stream" in optional_params and optional_params["stream"] is True:
return self.streaming( return self.streaming(
logging_obj=logging_obj, logging_obj=logging_obj,
api_base=api_base, api_base=api_base,

View file

@ -36,7 +36,7 @@ def completion(
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
@ -59,7 +59,7 @@ def completion(
"parameters": optional_params, "parameters": optional_params,
"stream": ( "stream": (
True True
if "stream" in optional_params and optional_params["stream"] == True if "stream" in optional_params and optional_params["stream"] is True
else False else False
), ),
} }
@ -77,12 +77,12 @@ def completion(
data=json.dumps(data), data=json.dumps(data),
stream=( stream=(
True True
if "stream" in optional_params and optional_params["stream"] == True if "stream" in optional_params and optional_params["stream"] is True
else False else False
), ),
) )
if "text/event-stream" in response.headers["Content-Type"] or ( if "text/event-stream" in response.headers["Content-Type"] or (
"stream" in optional_params and optional_params["stream"] == True "stream" in optional_params and optional_params["stream"] is True
): ):
return response.iter_lines() return response.iter_lines()
else: else:

View file

@ -183,7 +183,7 @@ class BedrockConverseLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text) raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.") raise BedrockError(status_code=408, message="Timeout error occurred.")
return litellm.AmazonConverseConfig()._transform_response( return litellm.AmazonConverseConfig()._transform_response(

View file

@@ -251,9 +251,7 @@ class AmazonConverseConfig:
         supported_converse_params = AmazonConverseConfig.__annotations__.keys()
         supported_tool_call_params = ["tools", "tool_choice"]
         supported_guardrail_params = ["guardrailConfig"]
-        json_mode: Optional[bool] = inference_params.pop(
-            "json_mode", None
-        )  # used for handling json_schema
+        inference_params.pop("json_mode", None)  # used for handling json_schema

         ## TRANSFORMATION ##
         bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(

View file

@ -234,7 +234,7 @@ async def make_call(
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text) raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.") raise BedrockError(status_code=408, message="Timeout error occurred.")
except Exception as e: except Exception as e:
raise BedrockError(status_code=500, message=str(e)) raise BedrockError(status_code=500, message=str(e))
@ -335,7 +335,7 @@ class BedrockLLM(BaseAWSLLM):
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
except: except Exception:
raise BedrockError(message=response.text, status_code=422) raise BedrockError(message=response.text, status_code=422)
outputText: Optional[str] = None outputText: Optional[str] = None
@ -394,12 +394,12 @@ class BedrockLLM(BaseAWSLLM):
outputText # allow user to access raw anthropic tool calling response outputText # allow user to access raw anthropic tool calling response
) )
if ( if (
_is_function_call == True _is_function_call is True
and stream is not None and stream is not None
and stream == True and stream is True
): ):
print_verbose( print_verbose(
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK" "INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
) )
# return an iterator # return an iterator
streaming_model_response = ModelResponse(stream=True) streaming_model_response = ModelResponse(stream=True)
@ -440,7 +440,7 @@ class BedrockLLM(BaseAWSLLM):
model_response=streaming_model_response model_response=streaming_model_response
) )
print_verbose( print_verbose(
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object" "Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
) )
return litellm.CustomStreamWrapper( return litellm.CustomStreamWrapper(
completion_stream=completion_stream, completion_stream=completion_stream,
@ -597,7 +597,7 @@ class BedrockLLM(BaseAWSLLM):
from botocore.auth import SigV4Auth from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials from botocore.credentials import Credentials
except ImportError as e: except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## SETUP ## ## SETUP ##
@ -700,7 +700,7 @@ class BedrockLLM(BaseAWSLLM):
k not in inference_params k not in inference_params
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in ): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
inference_params[k] = v inference_params[k] = v
if stream == True: if stream is True:
inference_params["stream"] = ( inference_params["stream"] = (
True # cohere requires stream = True in inference params True # cohere requires stream = True in inference params
) )
@ -845,7 +845,7 @@ class BedrockLLM(BaseAWSLLM):
if acompletion: if acompletion:
if isinstance(client, HTTPHandler): if isinstance(client, HTTPHandler):
client = None client = None
if stream == True and provider != "ai21": if stream is True and provider != "ai21":
return self.async_streaming( return self.async_streaming(
model=model, model=model,
messages=messages, messages=messages,
@ -891,7 +891,7 @@ class BedrockLLM(BaseAWSLLM):
self.client = _get_httpx_client(_params) # type: ignore self.client = _get_httpx_client(_params) # type: ignore
else: else:
self.client = client self.client = client
if (stream is not None and stream == True) and provider != "ai21": if (stream is not None and stream is True) and provider != "ai21":
response = self.client.post( response = self.client.post(
url=proxy_endpoint_url, url=proxy_endpoint_url,
headers=prepped.headers, # type: ignore headers=prepped.headers, # type: ignore
@ -929,7 +929,7 @@ class BedrockLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text) raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.") raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response( return self.process_response(
@ -980,7 +980,7 @@ class BedrockLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text) raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.") raise BedrockError(status_code=408, message="Timeout error occurred.")
return self.process_response( return self.process_response(

View file

@ -260,7 +260,7 @@ class AmazonAnthropicConfig:
optional_params["top_p"] = value optional_params["top_p"] = value
if param == "stop": if param == "stop":
optional_params["stop_sequences"] = value optional_params["stop_sequences"] = value
if param == "stream" and value == True: if param == "stream" and value is True:
optional_params["stream"] = value optional_params["stream"] = value
return optional_params return optional_params

View file

@ -39,7 +39,7 @@ class BedrockEmbedding(BaseAWSLLM):
) -> Tuple[Any, str]: ) -> Tuple[Any, str]:
try: try:
from botocore.credentials import Credentials from botocore.credentials import Credentials
except ImportError as e: except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ## ## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them

View file

@ -130,7 +130,7 @@ def process_response(
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore model_response.choices = choices_list # type: ignore
except Exception as e: except Exception:
raise ClarifaiError( raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model message=traceback.format_exc(), status_code=response.status_code, url=model
) )
@ -219,7 +219,7 @@ async def async_completion(
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore model_response.choices = choices_list # type: ignore
except Exception as e: except Exception:
raise ClarifaiError( raise ClarifaiError(
message=traceback.format_exc(), status_code=response.status_code, url=model message=traceback.format_exc(), status_code=response.status_code, url=model
) )
@@ -251,9 +251,9 @@
     encoding,
     api_key,
     logging_obj,
+    optional_params: dict,
     custom_prompt_dict={},
     acompletion=False,
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
@@ -268,20 +268,12 @@ def completion(
             optional_params[k] = v
     custom_llm_provider, orig_model_name = get_prompt_model_name(model)
-    if custom_llm_provider == "anthropic":
-        prompt = prompt_factory(
-            model=orig_model_name,
-            messages=messages,
-            api_key=api_key,
-            custom_llm_provider="clarifai",
-        )
-    else:
-        prompt = prompt_factory(
-            model=orig_model_name,
-            messages=messages,
-            api_key=api_key,
-            custom_llm_provider=custom_llm_provider,
-        )
+    prompt: str = prompt_factory(  # type: ignore
+        model=orig_model_name,
+        messages=messages,
+        api_key=api_key,
+        custom_llm_provider="clarifai",
+    )

     # print(prompt); exit(0)
     data = {
@ -300,7 +292,7 @@ def completion(
"api_base": model, "api_base": model,
}, },
) )
if acompletion == True: if acompletion is True:
return async_completion( return async_completion(
model=model, model=model,
prompt=prompt, prompt=prompt,
@ -331,7 +323,7 @@ def completion(
status_code=response.status_code, message=response.text, url=model status_code=response.status_code, message=response.text, url=model
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
completion_stream = response.iter_lines() completion_stream = response.iter_lines()
stream_response = CustomStreamWrapper( stream_response = CustomStreamWrapper(
completion_stream=completion_stream, completion_stream=completion_stream,

View file

@@ -80,8 +80,8 @@ def completion(
     encoding,
     api_key,
     logging_obj,
+    optional_params: dict,
     custom_prompt_dict={},
-    optional_params=None,
     litellm_params=None,
     logger_fn=None,
 ):
@ -97,7 +97,7 @@ def completion(
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( custom_prompt(
role_dict=model_prompt_details.get("roles", {}), role_dict=model_prompt_details.get("roles", {}),
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""), initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
final_prompt_value=model_prompt_details.get("final_prompt_value", ""), final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
@ -126,7 +126,7 @@ def completion(
) )
## COMPLETION CALL ## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
response = requests.post( response = requests.post(
api_base, api_base,
headers=headers, headers=headers,

View file

@ -268,7 +268,7 @@ def completion(
if response.status_code != 200: if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
@ -283,12 +283,12 @@ def completion(
completion_response = response.json() completion_response = response.json()
try: try:
model_response.choices[0].message.content = completion_response["text"] # type: ignore model_response.choices[0].message.content = completion_response["text"] # type: ignore
except Exception as e: except Exception:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
## Tool calling response ## Tool calling response
cohere_tools_response = completion_response.get("tool_calls", None) cohere_tools_response = completion_response.get("tool_calls", None)
if cohere_tools_response is not None and cohere_tools_response is not []: if cohere_tools_response is not None and cohere_tools_response != []:
# convert cohere_tools_response to OpenAI response format # convert cohere_tools_response to OpenAI response format
tool_calls = [] tool_calls = []
for tool in cohere_tools_response: for tool in cohere_tools_response:
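
The `is not []` → `!= []` change above is more than style: `is` compares object identity, and a newly created list literal is never the same object as an existing list, so `cohere_tools_response is not []` was always true, even for an empty response. Ruff flags `is` comparisons against literals; the equality check is what was actually intended. Illustrative only:

```
cohere_tools_response = []

print(cohere_tools_response is not [])  # True - identity against a fresh literal
print(cohere_tools_response != [])      # False - the comparison the commit switches to
```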

View file

@ -146,7 +146,7 @@ def completion(
api_key, api_key,
logging_obj, logging_obj,
headers: dict, headers: dict,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
@ -198,7 +198,7 @@ def completion(
if response.status_code != 200: if response.status_code != 200:
raise CohereError(message=response.text, status_code=response.status_code) raise CohereError(message=response.text, status_code=response.status_code)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
@ -231,7 +231,7 @@ def completion(
) )
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore model_response.choices = choices_list # type: ignore
except Exception as e: except Exception:
raise CohereError( raise CohereError(
message=response.text, status_code=response.status_code message=response.text, status_code=response.status_code
) )

View file

@ -17,7 +17,7 @@ else:
try: try:
from litellm._version import version from litellm._version import version
except: except Exception:
version = "0.0.0" version = "0.0.0"
headers = { headers = {

View file

@ -1,15 +1,17 @@
from typing import Optional from typing import Optional
import httpx import httpx
try: try:
from litellm._version import version from litellm._version import version
except: except Exception:
version = "0.0.0" version = "0.0.0"
headers = { headers = {
"User-Agent": f"litellm/{version}", "User-Agent": f"litellm/{version}",
} }
class HTTPHandler: class HTTPHandler:
def __init__(self, concurrent_limit=1000): def __init__(self, concurrent_limit=1000):
# Create a client with a connection pool # Create a client with a connection pool

View file

@ -113,7 +113,7 @@ class DatabricksConfig:
optional_params["max_tokens"] = value optional_params["max_tokens"] = value
if param == "n": if param == "n":
optional_params["n"] = value optional_params["n"] = value
if param == "stream" and value == True: if param == "stream" and value is True:
optional_params["stream"] = value optional_params["stream"] = value
if param == "temperature": if param == "temperature":
optional_params["temperature"] = value optional_params["temperature"] = value
@ -564,7 +564,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code, status_code=e.response.status_code,
message=e.response.text, message=e.response.text,
) )
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise DatabricksError( raise DatabricksError(
status_code=408, message="Timeout error occurred." status_code=408, message="Timeout error occurred."
) )
@ -614,7 +614,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code, status_code=e.response.status_code,
message=response.text if response else str(e), message=response.text if response else str(e),
) )
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise DatabricksError( raise DatabricksError(
status_code=408, message="Timeout error occurred." status_code=408, message="Timeout error occurred."
) )
@ -669,7 +669,7 @@ class DatabricksChatCompletion(BaseLLM):
additional_args={"complete_input_dict": data, "api_base": api_base}, additional_args={"complete_input_dict": data, "api_base": api_base},
) )
if aembedding == True: if aembedding is True:
return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore
if client is None or isinstance(client, AsyncHTTPHandler): if client is None or isinstance(client, AsyncHTTPHandler):
self.client = HTTPHandler(timeout=timeout) # type: ignore self.client = HTTPHandler(timeout=timeout) # type: ignore
@ -692,7 +692,7 @@ class DatabricksChatCompletion(BaseLLM):
status_code=e.response.status_code, status_code=e.response.status_code,
message=e.response.text, message=e.response.text,
) )
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise DatabricksError(status_code=408, message="Timeout error occurred.") raise DatabricksError(status_code=408, message="Timeout error occurred.")
except Exception as e: except Exception as e:
raise DatabricksError(status_code=500, message=str(e)) raise DatabricksError(status_code=500, message=str(e))

View file

@ -71,7 +71,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self, self,
_is_async: bool, _is_async: bool,
create_file_data: CreateFileRequest, create_file_data: CreateFileRequest,
api_base: str, api_base: Optional[str],
api_key: Optional[str], api_key: Optional[str],
api_version: Optional[str], api_version: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
@ -117,7 +117,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self, self,
_is_async: bool, _is_async: bool,
file_content_request: FileContentRequest, file_content_request: FileContentRequest,
api_base: str, api_base: Optional[str],
api_key: Optional[str], api_key: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
max_retries: Optional[int], max_retries: Optional[int],
@ -168,7 +168,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self, self,
_is_async: bool, _is_async: bool,
file_id: str, file_id: str,
api_base: str, api_base: Optional[str],
api_key: Optional[str], api_key: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
max_retries: Optional[int], max_retries: Optional[int],
@ -220,7 +220,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
self, self,
_is_async: bool, _is_async: bool,
file_id: str, file_id: str,
api_base: str, api_base: Optional[str],
api_key: Optional[str], api_key: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
max_retries: Optional[int], max_retries: Optional[int],
@ -275,7 +275,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
def list_files( def list_files(
self, self,
_is_async: bool, _is_async: bool,
api_base: str, api_base: Optional[str],
api_key: Optional[str], api_key: Optional[str],
timeout: Union[float, httpx.Timeout], timeout: Union[float, httpx.Timeout],
max_retries: Optional[int], max_retries: Optional[int],

View file

@ -41,7 +41,7 @@ class VertexFineTuningAPI(VertexLLM):
created_at = int(create_time_datetime.timestamp()) created_at = int(create_time_datetime.timestamp())
return created_at return created_at
except Exception as e: except Exception:
return 0 return 0
def convert_vertex_response_to_open_ai_response( def convert_vertex_response_to_open_ai_response(


@ -136,7 +136,7 @@ class GeminiConfig:
# ): # ):
# try: # try:
# import google.generativeai as genai # type: ignore # import google.generativeai as genai # type: ignore
# except: # except Exception:
# raise Exception( # raise Exception(
# "Importing google.generativeai failed, please run 'pip install -q google-generativeai" # "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
# ) # )
@ -282,7 +282,7 @@ class GeminiConfig:
# completion_response = model_response["choices"][0]["message"].get("content") # completion_response = model_response["choices"][0]["message"].get("content")
# if completion_response is None: # if completion_response is None:
# raise Exception # raise Exception
# except: # except Exception:
# original_response = f"response: {response}" # original_response = f"response: {response}"
# if hasattr(response, "candidates"): # if hasattr(response, "candidates"):
# original_response = f"response: {response.candidates}" # original_response = f"response: {response.candidates}"
@ -374,7 +374,7 @@ class GeminiConfig:
# completion_response = model_response["choices"][0]["message"].get("content") # completion_response = model_response["choices"][0]["message"].get("content")
# if completion_response is None: # if completion_response is None:
# raise Exception # raise Exception
# except: # except Exception:
# original_response = f"response: {response}" # original_response = f"response: {response}"
# if hasattr(response, "candidates"): # if hasattr(response, "candidates"):
# original_response = f"response: {response.candidates}" # original_response = f"response: {response.candidates}"


@ -13,6 +13,7 @@ import requests
import litellm import litellm
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
from litellm.secret_managers.main import get_secret_str
from litellm.types.completion import ChatCompletionMessageToolCallParam from litellm.types.completion import ChatCompletionMessageToolCallParam
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
@ -181,7 +182,7 @@ class HuggingfaceConfig:
return optional_params return optional_params
def get_hf_api_key(self) -> Optional[str]: def get_hf_api_key(self) -> Optional[str]:
return litellm.utils.get_secret("HUGGINGFACE_API_KEY") return get_secret_str("HUGGINGFACE_API_KEY")
def output_parser(generated_text: str): def output_parser(generated_text: str):
@ -240,7 +241,7 @@ def read_tgi_conv_models():
# Cache the set for future use # Cache the set for future use
conv_models_cache = conv_models conv_models_cache = conv_models
return tgi_models, conv_models return tgi_models, conv_models
except: except Exception:
return set(), set() return set(), set()
@ -372,7 +373,7 @@ class Huggingface(BaseLLM):
]["finish_reason"] ]["finish_reason"]
sum_logprob = 0 sum_logprob = 0
for token in completion_response[0]["details"]["tokens"]: for token in completion_response[0]["details"]["tokens"]:
if token["logprob"] != None: if token["logprob"] is not None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
if "best_of" in optional_params and optional_params["best_of"] > 1: if "best_of" in optional_params and optional_params["best_of"] > 1:
@ -386,7 +387,7 @@ class Huggingface(BaseLLM):
): ):
sum_logprob = 0 sum_logprob = 0
for token in item["tokens"]: for token in item["tokens"]:
if token["logprob"] != None: if token["logprob"] is not None:
sum_logprob += token["logprob"] sum_logprob += token["logprob"]
if len(item["generated_text"]) > 0: if len(item["generated_text"]) > 0:
message_obj = Message( message_obj = Message(
@ -417,7 +418,7 @@ class Huggingface(BaseLLM):
prompt_tokens = len( prompt_tokens = len(
encoding.encode(input_text) encoding.encode(input_text)
) ##[TODO] use the llama2 tokenizer here ) ##[TODO] use the llama2 tokenizer here
except: except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails # this should remain non blocking we should not block a response returning if calculating usage fails
pass pass
output_text = model_response["choices"][0]["message"].get("content", "") output_text = model_response["choices"][0]["message"].get("content", "")
@ -429,7 +430,7 @@ class Huggingface(BaseLLM):
model_response["choices"][0]["message"].get("content", "") model_response["choices"][0]["message"].get("content", "")
) )
) ##[TODO] use the llama2 tokenizer here ) ##[TODO] use the llama2 tokenizer here
except: except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails # this should remain non blocking we should not block a response returning if calculating usage fails
pass pass
else: else:
@ -559,7 +560,7 @@ class Huggingface(BaseLLM):
True True
if "stream" in optional_params if "stream" in optional_params
and isinstance(optional_params["stream"], bool) and isinstance(optional_params["stream"], bool)
and optional_params["stream"] == True # type: ignore and optional_params["stream"] is True # type: ignore
else False else False
), ),
} }
@ -595,7 +596,7 @@ class Huggingface(BaseLLM):
data["stream"] = ( # type: ignore data["stream"] = ( # type: ignore
True # type: ignore True # type: ignore
if "stream" in optional_params if "stream" in optional_params
and optional_params["stream"] == True and optional_params["stream"] is True
else False else False
) )
input_text = prompt input_text = prompt
@ -631,7 +632,7 @@ class Huggingface(BaseLLM):
### ASYNC COMPLETION ### ASYNC COMPLETION
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
### SYNC STREAMING ### SYNC STREAMING
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
response = requests.post( response = requests.post(
completion_url, completion_url,
headers=headers, headers=headers,
@ -691,7 +692,7 @@ class Huggingface(BaseLLM):
completion_response = response.json() completion_response = response.json()
if isinstance(completion_response, dict): if isinstance(completion_response, dict):
completion_response = [completion_response] completion_response = [completion_response]
except: except Exception:
import traceback import traceback
raise HuggingfaceError( raise HuggingfaceError(
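
Several hunks in this file swap a bare `except:` for `except Exception:` (ruff E722) while keeping the token-count fallback non-blocking. A small sketch of that fallback shape; `count_tokens` is a stand-in callable, not a litellm function:

def safe_prompt_tokens(text: str, count_tokens) -> int:
    # `except Exception` replaces the bare `except:` that E722 flags; a bare
    # except would also swallow KeyboardInterrupt and SystemExit.
    try:
        return count_tokens(text)
    except Exception:
        # token counting should never block a response from being returned
        return 0


print(safe_prompt_tokens("hello world", lambda t: len(t.split())))  # 2
print(safe_prompt_tokens("hello world", None))                      # 0 (None is not callable)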


@ -101,7 +101,7 @@ def completion(
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
@ -135,7 +135,7 @@ def completion(
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False, stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
@ -159,7 +159,7 @@ def completion(
model_response.choices[0].message.content = completion_response[ # type: ignore model_response.choices[0].message.content = completion_response[ # type: ignore
"answer" "answer"
] ]
except Exception as e: except Exception:
raise MaritalkError( raise MaritalkError(
message=response.text, status_code=response.status_code message=response.text, status_code=response.status_code
) )


@ -120,7 +120,7 @@ def completion(
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
default_max_tokens_to_sample=None, default_max_tokens_to_sample=None,
@ -164,7 +164,7 @@ def completion(
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False, stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return clean_and_iterate_chunks(response) return clean_and_iterate_chunks(response)
else: else:
## LOGGING ## LOGGING
@ -178,7 +178,7 @@ def completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
except: except Exception:
raise NLPCloudError(message=response.text, status_code=response.status_code) raise NLPCloudError(message=response.text, status_code=response.status_code)
if "error" in completion_response: if "error" in completion_response:
raise NLPCloudError( raise NLPCloudError(
@ -191,7 +191,7 @@ def completion(
model_response.choices[0].message.content = ( # type: ignore model_response.choices[0].message.content = ( # type: ignore
completion_response["generated_text"] completion_response["generated_text"]
) )
except: except Exception:
raise NLPCloudError( raise NLPCloudError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
status_code=response.status_code, status_code=response.status_code,


@ -14,7 +14,7 @@ import requests # type: ignore
import litellm import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.types.utils import ProviderField from litellm.types.utils import ProviderField, StreamingChoices
from .prompt_templates.factory import custom_prompt, prompt_factory from .prompt_templates.factory import custom_prompt, prompt_factory
@ -172,7 +172,7 @@ def _convert_image(image):
try: try:
from PIL import Image from PIL import Image
except: except Exception:
raise Exception( raise Exception(
"ollama image conversion failed please run `pip install Pillow`" "ollama image conversion failed please run `pip install Pillow`"
) )
@ -184,7 +184,7 @@ def _convert_image(image):
image_data = Image.open(io.BytesIO(base64.b64decode(image))) image_data = Image.open(io.BytesIO(base64.b64decode(image)))
if image_data.format in ["JPEG", "PNG"]: if image_data.format in ["JPEG", "PNG"]:
return image return image
except: except Exception:
return orig return orig
jpeg_image = io.BytesIO() jpeg_image = io.BytesIO()
image_data.convert("RGB").save(jpeg_image, "JPEG") image_data.convert("RGB").save(jpeg_image, "JPEG")
@ -195,13 +195,13 @@ def _convert_image(image):
# ollama implementation # ollama implementation
def get_ollama_response( def get_ollama_response(
model_response: litellm.ModelResponse, model_response: litellm.ModelResponse,
api_base="http://localhost:11434", model: str,
model="llama2", prompt: str,
prompt="Why is the sky blue?", optional_params: dict,
optional_params=None, logging_obj: Any,
logging_obj=None, encoding: Any,
acompletion: bool = False, acompletion: bool = False,
encoding=None, api_base="http://localhost:11434",
): ):
if api_base.endswith("/api/generate"): if api_base.endswith("/api/generate"):
url = api_base url = api_base
@ -242,7 +242,7 @@ def get_ollama_response(
}, },
) )
if acompletion is True: if acompletion is True:
if stream == True: if stream is True:
response = ollama_async_streaming( response = ollama_async_streaming(
url=url, url=url,
data=data, data=data,
@ -340,11 +340,16 @@ def ollama_completion_stream(url, data, logging_obj):
# Gather all chunks and return the function call as one delta to simplify parsing # Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json": if data.get("format", "") == "json":
first_chunk = next(streamwrapper) first_chunk = next(streamwrapper)
response_content = "".join( content_chunks = []
chunk.choices[0].delta.content for chunk in chain([first_chunk], streamwrapper):
for chunk in chain([first_chunk], streamwrapper) content_chunk = chunk.choices[0]
if chunk.choices[0].delta.content if (
) isinstance(content_chunk, StreamingChoices)
and hasattr(content_chunk, "delta")
and hasattr(content_chunk.delta, "content")
):
content_chunks.append(content_chunk.delta.content)
response_content = "".join(content_chunks)
function_call = json.loads(response_content) function_call = json.loads(response_content)
delta = litellm.utils.Delta( delta = litellm.utils.Delta(
@ -392,15 +397,27 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
# If format is JSON, this was a function call # If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing # Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json": if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper) first_chunk = await anext(streamwrapper) # noqa F821
first_chunk_content = first_chunk.choices[0].delta.content or "" chunk_choice = first_chunk.choices[0]
response_content = first_chunk_content + "".join( if (
[ isinstance(chunk_choice, StreamingChoices)
chunk.choices[0].delta.content and hasattr(chunk_choice, "delta")
async for chunk in streamwrapper and hasattr(chunk_choice.delta, "content")
if chunk.choices[0].delta.content ):
] first_chunk_content = chunk_choice.delta.content or ""
) else:
first_chunk_content = ""
content_chunks = []
async for chunk in streamwrapper:
chunk_choice = chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
content_chunks.append(chunk_choice.delta.content)
response_content = first_chunk_content + "".join(content_chunks)
function_call = json.loads(response_content) function_call = json.loads(response_content)
delta = litellm.utils.Delta( delta = litellm.utils.Delta(
content=None, content=None,
@ -501,8 +518,8 @@ async def ollama_aembeddings(
prompts: List[str], prompts: List[str],
model_response: litellm.EmbeddingResponse, model_response: litellm.EmbeddingResponse,
optional_params: dict, optional_params: dict,
logging_obj=None, logging_obj: Any,
encoding=None, encoding: Any,
): ):
if api_base.endswith("/api/embed"): if api_base.endswith("/api/embed"):
url = api_base url = api_base
@ -581,9 +598,9 @@ def ollama_embeddings(
api_base: str, api_base: str,
model: str, model: str,
prompts: list, prompts: list,
optional_params=None, optional_params: dict,
logging_obj=None, model_response: litellm.EmbeddingResponse,
model_response=None, logging_obj: Any,
encoding=None, encoding=None,
): ):
return asyncio.run( return asyncio.run(
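
The ollama streaming hunks above replace a one-line generator expression with an explicit loop plus `isinstance`/`hasattr` guards before joining delta content and parsing it as JSON. A self-contained sketch of that gathering pattern, using small dataclass stand-ins rather than litellm's actual types:

import json
from dataclasses import dataclass, field
from typing import Iterable, List, Optional


@dataclass
class Delta:
    content: Optional[str] = None


@dataclass
class StreamingChoices:  # stand-in for litellm.types.utils.StreamingChoices
    delta: Delta = field(default_factory=Delta)


@dataclass
class Chunk:
    choices: List[StreamingChoices] = field(default_factory=list)


def gather_json_function_call(stream: Iterable[Chunk]) -> dict:
    # Only choices that really are StreamingChoices with a delta carrying
    # content contribute text; empty/None deltas are skipped.
    content_chunks: List[str] = []
    for chunk in stream:
        choice = chunk.choices[0]
        if (
            isinstance(choice, StreamingChoices)
            and hasattr(choice, "delta")
            and hasattr(choice.delta, "content")
            and choice.delta.content
        ):
            content_chunks.append(choice.delta.content)
    return json.loads("".join(content_chunks))


chunks = [
    Chunk(choices=[StreamingChoices(delta=Delta(content='{"name": "get_weather", '))]),
    Chunk(choices=[StreamingChoices(delta=Delta(content='"arguments": {"city": "Paris"}}'))]),
]
print(gather_json_function_call(chunks))  # {'name': 'get_weather', 'arguments': {'city': 'Paris'}}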


@ -4,7 +4,7 @@ import traceback
import types import types
import uuid import uuid
from itertools import chain from itertools import chain
from typing import List, Optional from typing import Any, List, Optional
import aiohttp import aiohttp
import httpx import httpx
@ -15,6 +15,7 @@ import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
from litellm.types.llms.openai import ChatCompletionAssistantToolCall from litellm.types.llms.openai import ChatCompletionAssistantToolCall
from litellm.types.utils import StreamingChoices
class OllamaError(Exception): class OllamaError(Exception):
@ -216,10 +217,10 @@ def get_ollama_response(
model_response: litellm.ModelResponse, model_response: litellm.ModelResponse,
messages: list, messages: list,
optional_params: dict, optional_params: dict,
model: str,
logging_obj: Any,
api_base="http://localhost:11434", api_base="http://localhost:11434",
api_key: Optional[str] = None, api_key: Optional[str] = None,
model="llama2",
logging_obj=None,
acompletion: bool = False, acompletion: bool = False,
encoding=None, encoding=None,
): ):
@ -252,10 +253,13 @@ def get_ollama_response(
for tool in m["tool_calls"]: for tool in m["tool_calls"]:
typed_tool = ChatCompletionAssistantToolCall(**tool) # type: ignore typed_tool = ChatCompletionAssistantToolCall(**tool) # type: ignore
if typed_tool["type"] == "function": if typed_tool["type"] == "function":
arguments = {}
if "arguments" in typed_tool["function"]:
arguments = json.loads(typed_tool["function"]["arguments"])
ollama_tool_call = OllamaToolCall( ollama_tool_call = OllamaToolCall(
function=OllamaToolCallFunction( function=OllamaToolCallFunction(
name=typed_tool["function"]["name"], name=typed_tool["function"].get("name") or "",
arguments=json.loads(typed_tool["function"]["arguments"]), arguments=arguments,
) )
) )
new_tools.append(ollama_tool_call) new_tools.append(ollama_tool_call)
@ -401,12 +405,16 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
# If format is JSON, this was a function call # If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing # Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json": if data.get("format", "") == "json":
first_chunk = next(streamwrapper) content_chunks = []
response_content = "".join( for chunk in streamwrapper:
chunk.choices[0].delta.content chunk_choice = chunk.choices[0]
for chunk in chain([first_chunk], streamwrapper) if (
if chunk.choices[0].delta.content isinstance(chunk_choice, StreamingChoices)
) and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
content_chunks.append(chunk_choice.delta.content)
response_content = "".join(content_chunks)
function_call = json.loads(response_content) function_call = json.loads(response_content)
delta = litellm.utils.Delta( delta = litellm.utils.Delta(
@ -422,7 +430,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
} }
], ],
) )
model_response = first_chunk model_response = content_chunks[0]
model_response.choices[0].delta = delta # type: ignore model_response.choices[0].delta = delta # type: ignore
model_response.choices[0].finish_reason = "tool_calls" model_response.choices[0].finish_reason = "tool_calls"
yield model_response yield model_response
@ -462,15 +470,28 @@ async def ollama_async_streaming(
# If format is JSON, this was a function call # If format is JSON, this was a function call
# Gather all chunks and return the function call as one delta to simplify parsing # Gather all chunks and return the function call as one delta to simplify parsing
if data.get("format", "") == "json": if data.get("format", "") == "json":
first_chunk = await anext(streamwrapper) first_chunk = await anext(streamwrapper) # noqa F821
first_chunk_content = first_chunk.choices[0].delta.content or "" chunk_choice = first_chunk.choices[0]
response_content = first_chunk_content + "".join( if (
[ isinstance(chunk_choice, StreamingChoices)
chunk.choices[0].delta.content and hasattr(chunk_choice, "delta")
async for chunk in streamwrapper and hasattr(chunk_choice.delta, "content")
if chunk.choices[0].delta.content ):
] first_chunk_content = chunk_choice.delta.content or ""
) else:
first_chunk_content = ""
content_chunks = []
async for chunk in streamwrapper:
chunk_choice = chunk.choices[0]
if (
isinstance(chunk_choice, StreamingChoices)
and hasattr(chunk_choice, "delta")
and hasattr(chunk_choice.delta, "content")
):
content_chunks.append(chunk_choice.delta.content)
response_content = first_chunk_content + "".join(content_chunks)
function_call = json.loads(response_content) function_call = json.loads(response_content)
delta = litellm.utils.Delta( delta = litellm.utils.Delta(
content=None, content=None,
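
The tool-call hunk above only runs `json.loads` on the function arguments when the key is actually present, defaulting to an empty dict. A small sketch of the same guard with a plain dict in place of litellm's typed objects:

import json
from typing import Any, Dict


def parse_tool_call(tool: Dict[str, Any]) -> Dict[str, Any]:
    # Parse arguments only when present and non-empty; otherwise fall back to
    # an empty dict instead of raising KeyError or json.JSONDecodeError.
    function = tool.get("function", {})
    arguments: Dict[str, Any] = {}
    if "arguments" in function and function["arguments"]:
        arguments = json.loads(function["arguments"])
    return {"name": function.get("name") or "", "arguments": arguments}


print(parse_tool_call({"function": {"name": "lookup", "arguments": '{"q": "litellm"}'}}))
# {'name': 'lookup', 'arguments': {'q': 'litellm'}}
print(parse_tool_call({"function": {"name": "lookup"}}))
# {'name': 'lookup', 'arguments': {}}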


@ -39,8 +39,8 @@ def completion(
encoding, encoding,
api_key, api_key,
logging_obj, logging_obj,
optional_params: dict,
custom_prompt_dict={}, custom_prompt_dict={},
optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
default_max_tokens_to_sample=None, default_max_tokens_to_sample=None,
@ -77,7 +77,7 @@ def completion(
data=json.dumps(data), data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False, stream=optional_params["stream"] if "stream" in optional_params else False,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines() return response.iter_lines()
else: else:
## LOGGING ## LOGGING
@ -91,7 +91,7 @@ def completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
except: except Exception:
raise OobaboogaError( raise OobaboogaError(
message=response.text, status_code=response.status_code message=response.text, status_code=response.status_code
) )
@ -103,7 +103,7 @@ def completion(
else: else:
try: try:
model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
except: except Exception:
raise OobaboogaError( raise OobaboogaError(
message=json.dumps(completion_response), message=json.dumps(completion_response),
status_code=response.status_code, status_code=response.status_code,


@ -96,13 +96,13 @@ def completion(
api_key, api_key,
encoding, encoding,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
try: try:
import google.generativeai as palm # type: ignore import google.generativeai as palm # type: ignore
except: except Exception:
raise Exception( raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai" "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
) )
@ -167,14 +167,14 @@ def completion(
choice_obj = Choices(index=idx + 1, message=message_obj) choice_obj = Choices(index=idx + 1, message=message_obj)
choices_list.append(choice_obj) choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore model_response.choices = choices_list # type: ignore
except Exception as e: except Exception:
raise PalmError( raise PalmError(
message=traceback.format_exc(), status_code=response.status_code message=traceback.format_exc(), status_code=response.status_code
) )
try: try:
completion_response = model_response["choices"][0]["message"].get("content") completion_response = model_response["choices"][0]["message"].get("content")
except: except Exception:
raise PalmError( raise PalmError(
status_code=400, status_code=400,
message=f"No response received. Original response - {response}", message=f"No response received. Original response - {response}",


@ -98,7 +98,7 @@ def completion(
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params=None, optional_params: dict,
stream=False, stream=False,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
@ -123,6 +123,7 @@ def completion(
else: else:
prompt = prompt_factory(model=model, messages=messages) prompt = prompt_factory(model=model, messages=messages)
output_text: Optional[str] = None
if api_base: if api_base:
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
@ -157,7 +158,7 @@ def completion(
import torch import torch
from petals import AutoDistributedModelForCausalLM # type: ignore from petals import AutoDistributedModelForCausalLM # type: ignore
from transformers import AutoTokenizer from transformers import AutoTokenizer
except: except Exception:
raise Exception( raise Exception(
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals" "Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
) )
@ -192,7 +193,7 @@ def completion(
## RESPONSE OBJECT ## RESPONSE OBJECT
output_text = tokenizer.decode(outputs[0]) output_text = tokenizer.decode(outputs[0])
if len(output_text) > 0: if output_text is not None and len(output_text) > 0:
model_response.choices[0].message.content = output_text # type: ignore model_response.choices[0].message.content = output_text # type: ignore
prompt_tokens = len(encoding.encode(prompt)) prompt_tokens = len(encoding.encode(prompt))
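
Several files in this commit wrap optional imports (torch, petals, PIL, google.generativeai) in `try`/`except Exception` with an install hint. A generic sketch of that guard; `optional_import` is a hypothetical helper, not a litellm function:

import importlib


def optional_import(module_name: str, install_hint: str):
    # Lazily import an optional dependency; `except Exception` (rather than a
    # bare `except:`) satisfies ruff E722 while still surfacing an install hint.
    try:
        return importlib.import_module(module_name)
    except Exception:
        raise Exception(f"Importing {module_name} failed, please run '{install_hint}'")


math_mod = optional_import("math", "pip install nothing-needed")  # stdlib, always succeeds
print(math_mod.sqrt(16.0))  # 4.0
# optional_import("petals", "pip install petals") would raise with the hint if petals is missing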


@ -265,7 +265,7 @@ class PredibaseChatCompletion(BaseLLM):
## RESPONSE OBJECT ## RESPONSE OBJECT
try: try:
completion_response = response.json() completion_response = response.json()
except: except Exception:
raise PredibaseError(message=response.text, status_code=422) raise PredibaseError(message=response.text, status_code=422)
if "error" in completion_response: if "error" in completion_response:
raise PredibaseError( raise PredibaseError(
@ -348,7 +348,7 @@ class PredibaseChatCompletion(BaseLLM):
model_response["choices"][0]["message"].get("content", "") model_response["choices"][0]["message"].get("content", "")
) )
) ##[TODO] use a model-specific tokenizer ) ##[TODO] use a model-specific tokenizer
except: except Exception:
# this should remain non blocking we should not block a response returning if calculating usage fails # this should remain non blocking we should not block a response returning if calculating usage fails
pass pass
else: else:


@ -5,7 +5,7 @@ import traceback
import uuid import uuid
import xml.etree.ElementTree as ET import xml.etree.ElementTree as ET
from enum import Enum from enum import Enum
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, cast
from jinja2 import BaseLoader, Template, exceptions, meta from jinja2 import BaseLoader, Template, exceptions, meta
from jinja2.sandbox import ImmutableSandboxedEnvironment from jinja2.sandbox import ImmutableSandboxedEnvironment
@ -26,11 +26,14 @@ from litellm.types.completion import (
) )
from litellm.types.llms.anthropic import * from litellm.types.llms.anthropic import *
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
from litellm.types.llms.ollama import OllamaVisionModelObject
from litellm.types.llms.openai import ( from litellm.types.llms.openai import (
AllMessageValues, AllMessageValues,
ChatCompletionAssistantMessage, ChatCompletionAssistantMessage,
ChatCompletionAssistantToolCall, ChatCompletionAssistantToolCall,
ChatCompletionFunctionMessage, ChatCompletionFunctionMessage,
ChatCompletionImageObject,
ChatCompletionTextObject,
ChatCompletionToolCallFunctionChunk, ChatCompletionToolCallFunctionChunk,
ChatCompletionToolMessage, ChatCompletionToolMessage,
ChatCompletionUserMessage, ChatCompletionUserMessage,
@ -164,7 +167,9 @@ def convert_to_ollama_image(openai_image_url: str):
def ollama_pt( def ollama_pt(
model, messages model, messages
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template ) -> Union[
str, OllamaVisionModelObject
]: # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
if "instruct" in model: if "instruct" in model:
prompt = custom_prompt( prompt = custom_prompt(
role_dict={ role_dict={
@ -438,7 +443,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
def _is_system_in_template(): def _is_system_in_template():
try: try:
# Try rendering the template with a system message # Try rendering the template with a system message
response = template.render( template.render(
messages=[{"role": "system", "content": "test"}], messages=[{"role": "system", "content": "test"}],
eos_token="<eos>", eos_token="<eos>",
bos_token="<bos>", bos_token="<bos>",
@ -446,10 +451,11 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
return True return True
# This will be raised if Jinja attempts to render the system message and it can't # This will be raised if Jinja attempts to render the system message and it can't
except: except Exception:
return False return False
try: try:
rendered_text = ""
# Render the template with the provided values # Render the template with the provided values
if _is_system_in_template(): if _is_system_in_template():
rendered_text = template.render( rendered_text = template.render(
@ -460,8 +466,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
) )
else: else:
# treat a system message as a user message, if system not in template # treat a system message as a user message, if system not in template
reformatted_messages = []
try: try:
reformatted_messages = []
for message in messages: for message in messages:
if message["role"] == "system": if message["role"] == "system":
reformatted_messages.append( reformatted_messages.append(
@ -556,30 +562,31 @@ def get_model_info(token, model):
return None, None return None, None
else: else:
return None, None return None, None
except Exception as e: # safely fail a prompt template request except Exception: # safely fail a prompt template request
return None, None return None, None
def format_prompt_togetherai(messages, prompt_format, chat_template): ## OLD TOGETHER AI FLOW
if prompt_format is None: # def format_prompt_togetherai(messages, prompt_format, chat_template):
return default_pt(messages) # if prompt_format is None:
# return default_pt(messages)
human_prompt, assistant_prompt = prompt_format.split("{prompt}") # human_prompt, assistant_prompt = prompt_format.split("{prompt}")
if chat_template is not None: # if chat_template is not None:
prompt = hf_chat_template( # prompt = hf_chat_template(
model=None, messages=messages, chat_template=chat_template # model=None, messages=messages, chat_template=chat_template
) # )
elif prompt_format is not None: # elif prompt_format is not None:
prompt = custom_prompt( # prompt = custom_prompt(
role_dict={}, # role_dict={},
messages=messages, # messages=messages,
initial_prompt_value=human_prompt, # initial_prompt_value=human_prompt,
final_prompt_value=assistant_prompt, # final_prompt_value=assistant_prompt,
) # )
else: # else:
prompt = default_pt(messages) # prompt = default_pt(messages)
return prompt # return prompt
### IBM Granite ### IBM Granite
@ -1063,7 +1070,7 @@ def convert_to_gemini_tool_call_invoke(
else: # don't silently drop params. Make it clear to user what's happening. else: # don't silently drop params. Make it clear to user what's happening.
raise Exception( raise Exception(
"function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format( "function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
tool message
) )
) )
return _parts_list return _parts_list
@ -1216,12 +1223,14 @@ def convert_function_to_anthropic_tool_invoke(
function_call: Union[dict, ChatCompletionToolCallFunctionChunk], function_call: Union[dict, ChatCompletionToolCallFunctionChunk],
) -> List[AnthropicMessagesToolUseParam]: ) -> List[AnthropicMessagesToolUseParam]:
try: try:
_name = get_attribute_or_key(function_call, "name") or ""
_arguments = get_attribute_or_key(function_call, "arguments")
anthropic_tool_invoke = [ anthropic_tool_invoke = [
AnthropicMessagesToolUseParam( AnthropicMessagesToolUseParam(
type="tool_use", type="tool_use",
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
name=get_attribute_or_key(function_call, "name"), name=_name,
input=json.loads(get_attribute_or_key(function_call, "arguments")), input=json.loads(_arguments) if _arguments else {},
) )
] ]
return anthropic_tool_invoke return anthropic_tool_invoke
@ -1349,8 +1358,9 @@ def anthropic_messages_pt(
): ):
for m in user_message_types_block["content"]: for m in user_message_types_block["content"]:
if m.get("type", "") == "image_url": if m.get("type", "") == "image_url":
m = cast(ChatCompletionImageObject, m)
image_chunk = convert_to_anthropic_image_obj( image_chunk = convert_to_anthropic_image_obj(
m["image_url"]["url"] openai_image_url=m["image_url"]["url"] # type: ignore
) )
_anthropic_content_element = AnthropicMessagesImageParam( _anthropic_content_element = AnthropicMessagesImageParam(
@ -1362,21 +1372,31 @@ def anthropic_messages_pt(
), ),
) )
anthropic_content_element = add_cache_control_to_content( _content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_element, anthropic_content_element=_anthropic_content_element,
orignal_content_element=m, orignal_content_element=dict(m),
) )
user_content.append(anthropic_content_element)
if "cache_control" in _content_element:
_anthropic_content_element["cache_control"] = (
_content_element["cache_control"]
)
user_content.append(_anthropic_content_element)
elif m.get("type", "") == "text": elif m.get("type", "") == "text":
_anthropic_text_content_element = { m = cast(ChatCompletionTextObject, m)
"type": "text", _anthropic_text_content_element = AnthropicMessagesTextParam(
"text": m["text"], type="text",
} text=m["text"],
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=m,
) )
user_content.append(anthropic_content_element) _content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=dict(m),
)
_content_element = cast(
AnthropicMessagesTextParam, _content_element
)
user_content.append(_content_element)
elif ( elif (
user_message_types_block["role"] == "tool" user_message_types_block["role"] == "tool"
or user_message_types_block["role"] == "function" or user_message_types_block["role"] == "function"
@ -1390,12 +1410,17 @@ def anthropic_messages_pt(
"type": "text", "type": "text",
"text": user_message_types_block["content"], "text": user_message_types_block["content"],
} }
anthropic_content_element = add_cache_control_to_content( _content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_text_element, anthropic_content_element=_anthropic_content_text_element,
orignal_content_element=user_message_types_block, orignal_content_element=dict(user_message_types_block),
) )
user_content.append(anthropic_content_element) if "cache_control" in _content_element:
_anthropic_content_text_element["cache_control"] = _content_element[
"cache_control"
]
user_content.append(_anthropic_content_text_element)
msg_i += 1 msg_i += 1
@ -1417,11 +1442,14 @@ def anthropic_messages_pt(
anthropic_message = AnthropicMessagesTextParam( anthropic_message = AnthropicMessagesTextParam(
type="text", text=m.get("text") type="text", text=m.get("text")
) )
anthropic_message = add_cache_control_to_content( _cached_message = add_cache_control_to_content(
anthropic_content_element=anthropic_message, anthropic_content_element=anthropic_message,
orignal_content_element=m, orignal_content_element=dict(m),
)
assistant_content.append(
cast(AnthropicMessagesTextParam, _cached_message)
) )
assistant_content.append(anthropic_message)
elif ( elif (
"content" in assistant_content_block "content" in assistant_content_block
and isinstance(assistant_content_block["content"], str) and isinstance(assistant_content_block["content"], str)
@ -1430,16 +1458,22 @@ def anthropic_messages_pt(
] # don't pass empty text blocks. anthropic api raises errors. ] # don't pass empty text blocks. anthropic api raises errors.
): ):
_anthropic_text_content_element = { _anthropic_text_content_element = AnthropicMessagesTextParam(
"type": "text", type="text",
"text": assistant_content_block["content"], text=assistant_content_block["content"],
}
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=assistant_content_block,
) )
assistant_content.append(anthropic_content_element)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=dict(assistant_content_block),
)
if "cache_control" in _content_element:
_anthropic_text_content_element["cache_control"] = _content_element[
"cache_control"
]
assistant_content.append(_anthropic_text_content_element)
assistant_tool_calls = assistant_content_block.get("tool_calls") assistant_tool_calls = assistant_content_block.get("tool_calls")
if ( if (
@ -1566,30 +1600,6 @@ def get_system_prompt(messages):
return system_prompt, messages return system_prompt, messages
def convert_to_documents(
observations: Any,
) -> List[MutableMapping]:
"""Converts observations into a 'document' dict"""
documents: List[MutableMapping] = []
if isinstance(observations, str):
# strings are turned into a key/value pair and a key of 'output' is added.
observations = [{"output": observations}]
elif isinstance(observations, Mapping):
# single mappings are transformed into a list to simplify the rest of the code.
observations = [observations]
elif not isinstance(observations, Sequence):
# all other types are turned into a key/value pair within a list
observations = [{"output": observations}]
for doc in observations:
if not isinstance(doc, Mapping):
# types that aren't Mapping are turned into a key/value pair.
doc = {"output": doc}
documents.append(doc)
return documents
from litellm.types.llms.cohere import ( from litellm.types.llms.cohere import (
CallObject, CallObject,
ChatHistory, ChatHistory,
@ -1943,7 +1953,7 @@ def amazon_titan_pt(
def _load_image_from_url(image_url): def _load_image_from_url(image_url):
try: try:
from PIL import Image from PIL import Image
except: except Exception:
raise Exception("image conversion failed please run `pip install Pillow`") raise Exception("image conversion failed please run `pip install Pillow`")
from io import BytesIO from io import BytesIO
@ -2008,7 +2018,7 @@ def _gemini_vision_convert_messages(messages: list):
else: else:
try: try:
from PIL import Image from PIL import Image
except: except Exception:
raise Exception( raise Exception(
"gemini image conversion failed please run `pip install Pillow`" "gemini image conversion failed please run `pip install Pillow`"
) )
@ -2056,7 +2066,7 @@ def gemini_text_image_pt(messages: list):
""" """
try: try:
import google.generativeai as genai # type: ignore import google.generativeai as genai # type: ignore
except: except Exception:
raise Exception( raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai" "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
) )
@ -2331,7 +2341,7 @@ def _convert_to_bedrock_tool_call_result(
for content in content_list: for content in content_list:
if content["type"] == "text": if content["type"] == "text":
content_str += content["text"] content_str += content["text"]
name = message.get("name", "") message.get("name", "")
id = str(message.get("tool_call_id", str(uuid.uuid4()))) id = str(message.get("tool_call_id", str(uuid.uuid4())))
tool_result_content_block = BedrockToolResultContentBlock(text=content_str) tool_result_content_block = BedrockToolResultContentBlock(text=content_str)
@ -2575,7 +2585,7 @@ def function_call_prompt(messages: list, functions: list):
message["content"] += f""" {function_prompt}""" message["content"] += f""" {function_prompt}"""
function_added_to_prompt = True function_added_to_prompt = True
if function_added_to_prompt == False: if function_added_to_prompt is False:
messages.append({"role": "system", "content": f"""{function_prompt}"""}) messages.append({"role": "system", "content": f"""{function_prompt}"""})
return messages return messages
@ -2692,11 +2702,6 @@ def prompt_factory(
) )
elif custom_llm_provider == "anthropic_xml": elif custom_llm_provider == "anthropic_xml":
return anthropic_messages_pt_xml(messages=messages) return anthropic_messages_pt_xml(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(
messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini": elif custom_llm_provider == "gemini":
if ( if (
model == "gemini-pro-vision" model == "gemini-pro-vision"
@ -2810,7 +2815,7 @@ def prompt_factory(
) )
else: else:
return hf_chat_template(original_model_name, messages) return hf_chat_template(original_model_name, messages)
except Exception as e: except Exception:
return default_pt( return default_pt(
messages=messages messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2) ) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
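
The `cast(ChatCompletionTextObject, m)` and `cast(ChatCompletionImageObject, m)` calls above narrow a union of message-content TypedDicts for the type checker after a runtime check on the "type" field. A minimal sketch of the idea with locally defined stand-in types (`TextPart` and `ImagePart` are illustrative, not litellm's):

from typing import Literal, TypedDict, Union, cast


class TextPart(TypedDict):
    type: Literal["text"]
    text: str


class ImagePart(TypedDict):
    type: Literal["image_url"]
    image_url: str


def part_to_str(part: Union[TextPart, ImagePart]) -> str:
    # cast() only narrows the union for the type checker after the runtime
    # check on the "type" discriminator; it does nothing at runtime.
    if part["type"] == "text":
        text_part = cast(TextPart, part)
        return text_part["text"]
    image_part = cast(ImagePart, part)
    return f"[image: {image_part['image_url']}]"


print(part_to_str({"type": "text", "text": "hi"}))
print(part_to_str({"type": "image_url", "image_url": "https://example.com/cat.png"}))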


@ -61,7 +61,7 @@ async def async_convert_url_to_base64(url: str) -> str:
try: try:
response = await client.get(url, follow_redirects=True) response = await client.get(url, follow_redirects=True)
return _process_image_response(response, url) return _process_image_response(response, url)
except: except Exception:
pass pass
raise Exception( raise Exception(
f"Error: Unable to fetch image from URL after 3 attempts. url={url}" f"Error: Unable to fetch image from URL after 3 attempts. url={url}"


@ -297,7 +297,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
if "output" in response_data: if "output" in response_data:
try: try:
output_string = "".join(response_data["output"]) output_string = "".join(response_data["output"])
except Exception as e: except Exception:
raise ReplicateError( raise ReplicateError(
status_code=422, status_code=422,
message="Unable to parse response. Got={}".format( message="Unable to parse response. Got={}".format(
@ -344,7 +344,7 @@ async def async_handle_prediction_response_streaming(
if "output" in response_data: if "output" in response_data:
try: try:
output_string = "".join(response_data["output"]) output_string = "".join(response_data["output"])
except Exception as e: except Exception:
raise ReplicateError( raise ReplicateError(
status_code=422, status_code=422,
message="Unable to parse response. Got={}".format( message="Unable to parse response. Got={}".format(
@ -479,7 +479,7 @@ def completion(
else: else:
input_data = {"prompt": prompt, **optional_params} input_data = {"prompt": prompt, **optional_params}
if acompletion is not None and acompletion == True: if acompletion is not None and acompletion is True:
return async_completion( return async_completion(
model_response=model_response, model_response=model_response,
model=model, model=model,
@ -513,7 +513,7 @@ def completion(
print_verbose(prediction_url) print_verbose(prediction_url)
# Handle the prediction response (streaming or non-streaming) # Handle the prediction response (streaming or non-streaming)
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
print_verbose("streaming request") print_verbose("streaming request")
_response = handle_prediction_response_streaming( _response = handle_prediction_response_streaming(
prediction_url, api_key, print_verbose prediction_url, api_key, print_verbose
@ -571,7 +571,7 @@ async def async_completion(
http_handler=http_handler, http_handler=http_handler,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
_response = async_handle_prediction_response_streaming( _response = async_handle_prediction_response_streaming(
prediction_url, api_key, print_verbose prediction_url, api_key, print_verbose
) )


@ -8,7 +8,7 @@ import types
from copy import deepcopy from copy import deepcopy
from enum import Enum from enum import Enum
from functools import partial from functools import partial
from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Union from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
@ -112,7 +112,7 @@ class SagemakerLLM(BaseAWSLLM):
): ):
try: try:
from botocore.credentials import Credentials from botocore.credentials import Credentials
except ImportError as e: except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
## CREDENTIALS ## ## CREDENTIALS ##
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them # pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
@ -123,7 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
aws_role_name = optional_params.pop("aws_role_name", None) aws_role_name = optional_params.pop("aws_role_name", None)
aws_session_name = optional_params.pop("aws_session_name", None) aws_session_name = optional_params.pop("aws_session_name", None)
aws_profile_name = optional_params.pop("aws_profile_name", None) aws_profile_name = optional_params.pop("aws_profile_name", None)
aws_bedrock_runtime_endpoint = optional_params.pop( optional_params.pop(
"aws_bedrock_runtime_endpoint", None "aws_bedrock_runtime_endpoint", None
) # https://bedrock-runtime.{region_name}.amazonaws.com ) # https://bedrock-runtime.{region_name}.amazonaws.com
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None) aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
@ -175,7 +175,7 @@ class SagemakerLLM(BaseAWSLLM):
from botocore.auth import SigV4Auth from botocore.auth import SigV4Auth
from botocore.awsrequest import AWSRequest from botocore.awsrequest import AWSRequest
from botocore.credentials import Credentials from botocore.credentials import Credentials
except ImportError as e: except ImportError:
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.") raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name) sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
@ -244,7 +244,7 @@ class SagemakerLLM(BaseAWSLLM):
hf_model_name = ( hf_model_name = (
hf_model_name or model hf_model_name or model
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt) ) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
prompt = prompt_factory(model=hf_model_name, messages=messages) prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore
return prompt return prompt
@ -256,10 +256,10 @@ class SagemakerLLM(BaseAWSLLM):
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params: dict,
timeout: Optional[Union[float, httpx.Timeout]] = None, timeout: Optional[Union[float, httpx.Timeout]] = None,
custom_prompt_dict={}, custom_prompt_dict={},
hf_model_name=None, hf_model_name=None,
optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
acompletion: bool = False, acompletion: bool = False,
@ -277,7 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
openai_like_chat_completions = DatabricksChatCompletion() openai_like_chat_completions = DatabricksChatCompletion()
inference_params["stream"] = True if stream is True else False inference_params["stream"] = True if stream is True else False
_data = { _data: Dict[str, Any] = {
"model": model, "model": model,
"messages": messages, "messages": messages,
**inference_params, **inference_params,
@ -310,7 +310,7 @@ class SagemakerLLM(BaseAWSLLM):
logger_fn=logger_fn, logger_fn=logger_fn,
timeout=timeout, timeout=timeout,
encoding=encoding, encoding=encoding,
headers=prepared_request.headers, headers=prepared_request.headers, # type: ignore
custom_endpoint=True, custom_endpoint=True,
custom_llm_provider="sagemaker_chat", custom_llm_provider="sagemaker_chat",
streaming_decoder=custom_stream_decoder, # type: ignore streaming_decoder=custom_stream_decoder, # type: ignore
@ -474,7 +474,7 @@ class SagemakerLLM(BaseAWSLLM):
try: try:
sync_response = sync_handler.post( sync_response = sync_handler.post(
url=prepared_request.url, url=prepared_request.url,
headers=prepared_request.headers, headers=prepared_request.headers, # type: ignore
json=_data, json=_data,
timeout=timeout, timeout=timeout,
) )
@ -559,7 +559,7 @@ class SagemakerLLM(BaseAWSLLM):
self, self,
api_base: str, api_base: str,
headers: dict, headers: dict,
data: str, data: dict,
logging_obj, logging_obj,
client=None, client=None,
): ):
@ -598,7 +598,7 @@ class SagemakerLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err: except httpx.HTTPStatusError as err:
error_code = err.response.status_code error_code = err.response.status_code
raise SagemakerError(status_code=error_code, message=err.response.text) raise SagemakerError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e: except httpx.TimeoutException:
raise SagemakerError(status_code=408, message="Timeout error occurred.") raise SagemakerError(status_code=408, message="Timeout error occurred.")
except Exception as e: except Exception as e:
raise SagemakerError(status_code=500, message=str(e)) raise SagemakerError(status_code=500, message=str(e))
@ -638,7 +638,7 @@ class SagemakerLLM(BaseAWSLLM):
make_call=partial( make_call=partial(
self.make_async_call, self.make_async_call,
api_base=prepared_request.url, api_base=prepared_request.url,
headers=prepared_request.headers, headers=prepared_request.headers, # type: ignore
data=data, data=data,
logging_obj=logging_obj, logging_obj=logging_obj,
), ),
@ -716,7 +716,7 @@ class SagemakerLLM(BaseAWSLLM):
try: try:
response = await async_handler.post( response = await async_handler.post(
url=prepared_request.url, url=prepared_request.url,
headers=prepared_request.headers, headers=prepared_request.headers, # type: ignore
json=data, json=data,
timeout=timeout, timeout=timeout,
) )
@ -794,8 +794,8 @@ class SagemakerLLM(BaseAWSLLM):
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params: dict,
custom_prompt_dict={}, custom_prompt_dict={},
optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
@ -1032,7 +1032,7 @@ class AWSEventStreamDecoder:
yield self._chunk_parser_messages_api(chunk_data=_data) yield self._chunk_parser_messages_api(chunk_data=_data)
else: else:
yield self._chunk_parser(chunk_data=_data) yield self._chunk_parser(chunk_data=_data)
except json.JSONDecodeError as e: except json.JSONDecodeError:
# Handle or log any unparseable data at the end # Handle or log any unparseable data at the end
verbose_logger.error( verbose_logger.error(
f"Warning: Unparseable JSON data remained: {accumulated_json}" f"Warning: Unparseable JSON data remained: {accumulated_json}"


@ -17,6 +17,7 @@ import requests # type: ignore
import litellm import litellm
from litellm import verbose_logger from litellm import verbose_logger
from litellm.litellm_core_utils.core_helpers import map_finish_reason from litellm.litellm_core_utils.core_helpers import map_finish_reason
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
from litellm.types.llms.databricks import GenericStreamingChunk from litellm.types.llms.databricks import GenericStreamingChunk
from litellm.utils import ( from litellm.utils import (
@ -157,7 +158,7 @@ class MistralTextCompletionConfig:
optional_params["top_p"] = value optional_params["top_p"] = value
if param == "max_tokens" or param == "max_completion_tokens": if param == "max_tokens" or param == "max_completion_tokens":
optional_params["max_tokens"] = value optional_params["max_tokens"] = value
if param == "stream" and value == True: if param == "stream" and value is True:
optional_params["stream"] = value optional_params["stream"] = value
if param == "stop": if param == "stop":
optional_params["stop"] = value optional_params["stop"] = value
@ -249,7 +250,7 @@ class CodestralTextCompletion(BaseLLM):
response: Union[requests.Response, httpx.Response], response: Union[requests.Response, httpx.Response],
model_response: TextCompletionResponse, model_response: TextCompletionResponse,
stream: bool, stream: bool,
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging, logging_obj: LiteLLMLogging,
optional_params: dict, optional_params: dict,
api_key: str, api_key: str,
data: Union[dict, str], data: Union[dict, str],
@ -273,7 +274,7 @@ class CodestralTextCompletion(BaseLLM):
) )
try: try:
completion_response = response.json() completion_response = response.json()
except: except Exception:
raise TextCompletionCodestralError(message=response.text, status_code=422) raise TextCompletionCodestralError(message=response.text, status_code=422)
_original_choices = completion_response.get("choices", []) _original_choices = completion_response.get("choices", [])


@ -176,7 +176,7 @@ class VertexAIConfig:
if param == "top_p": if param == "top_p":
optional_params["top_p"] = value optional_params["top_p"] = value
if ( if (
param == "stream" and value == True param == "stream" and value is True
): # sending stream = False, can cause it to get passed unchecked and raise issues ): # sending stream = False, can cause it to get passed unchecked and raise issues
optional_params["stream"] = value optional_params["stream"] = value
if param == "n": if param == "n":
@ -1313,7 +1313,6 @@ class ModelResponseIterator:
text = "" text = ""
tool_use: Optional[ChatCompletionToolCallChunk] = None tool_use: Optional[ChatCompletionToolCallChunk] = None
is_finished = False
finish_reason = "" finish_reason = ""
usage: Optional[ChatCompletionUsageBlock] = None usage: Optional[ChatCompletionUsageBlock] = None
_candidates: Optional[List[Candidates]] = processed_chunk.get("candidates") _candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")


@ -268,7 +268,7 @@ def completion(
): ):
try: try:
import vertexai import vertexai
except: except Exception:
raise VertexAIError( raise VertexAIError(
status_code=400, status_code=400,
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""", message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",


@ -5,7 +5,7 @@ import time
import types import types
import uuid import uuid
from enum import Enum from enum import Enum
from typing import Any, Callable, List, Literal, Optional, Union from typing import Any, Callable, List, Literal, Optional, Union, cast
import httpx # type: ignore import httpx # type: ignore
import requests # type: ignore import requests # type: ignore
@ -25,7 +25,12 @@ from litellm.types.files import (
is_gemini_1_5_accepted_file_type, is_gemini_1_5_accepted_file_type,
is_video_file_type, is_video_file_type,
) )
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage from litellm.types.llms.openai import (
AllMessageValues,
ChatCompletionAssistantMessage,
ChatCompletionImageObject,
ChatCompletionTextObject,
)
from litellm.types.llms.vertex_ai import * from litellm.types.llms.vertex_ai import *
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
@ -150,30 +155,34 @@ def _gemini_convert_messages_with_history(
while ( while (
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
): ):
if messages[msg_i]["content"] is not None and isinstance( _message_content = messages[msg_i].get("content")
messages[msg_i]["content"], list if _message_content is not None and isinstance(_message_content, list):
):
_parts: List[PartType] = [] _parts: List[PartType] = []
for element in messages[msg_i]["content"]: # type: ignore for element in _message_content:
if isinstance(element, dict): if (
if element["type"] == "text" and len(element["text"]) > 0: # type: ignore element["type"] == "text"
_part = PartType(text=element["text"]) # type: ignore and "text" in element
_parts.append(_part) and len(element["text"]) > 0
elif element["type"] == "image_url": ):
img_element: ChatCompletionImageObject = element # type: ignore element = cast(ChatCompletionTextObject, element)
if isinstance(img_element["image_url"], dict): _part = PartType(text=element["text"])
image_url = img_element["image_url"]["url"] _parts.append(_part)
else: elif element["type"] == "image_url":
image_url = img_element["image_url"] element = cast(ChatCompletionImageObject, element)
_part = _process_gemini_image(image_url=image_url) img_element = element
_parts.append(_part) # type: ignore if isinstance(img_element["image_url"], dict):
image_url = img_element["image_url"]["url"]
else:
image_url = img_element["image_url"]
_part = _process_gemini_image(image_url=image_url)
_parts.append(_part)
user_content.extend(_parts) user_content.extend(_parts)
elif ( elif (
messages[msg_i]["content"] is not None _message_content is not None
and isinstance(messages[msg_i]["content"], str) and isinstance(_message_content, str)
and len(messages[msg_i]["content"]) > 0 # type: ignore and len(_message_content) > 0
): ):
_part = PartType(text=messages[msg_i]["content"]) # type: ignore _part = PartType(text=_message_content)
user_content.append(_part) user_content.append(_part)
msg_i += 1 msg_i += 1
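Note: a simplified sketch of what the refactored loop above does; this is not the library helper itself, and the image branch below is a plain placeholder where the real code calls _process_gemini_image(). OpenAI-style content parts are normalised into Gemini-style part dicts:

from typing import Any, Dict, List, Union

def to_gemini_parts(content: Union[str, List[Dict[str, Any]]]) -> List[Dict[str, Any]]:
    if isinstance(content, str):
        return [{"text": content}] if content else []
    parts: List[Dict[str, Any]] = []
    for element in content:
        if element.get("type") == "text" and element.get("text"):
            parts.append({"text": element["text"]})
        elif element.get("type") == "image_url":
            image_url = element["image_url"]
            if isinstance(image_url, dict):  # {"url": "..."} form
                image_url = image_url["url"]
            parts.append({"image_url": image_url})  # placeholder; real code builds a PartType here
    return parts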
@ -201,22 +210,21 @@ def _gemini_convert_messages_with_history(
else: else:
msg_dict = messages[msg_i] # type: ignore msg_dict = messages[msg_i] # type: ignore
assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
if assistant_msg.get("content", None) is not None and isinstance( _message_content = assistant_msg.get("content", None)
assistant_msg["content"], list if _message_content is not None and isinstance(_message_content, list):
):
_parts = [] _parts = []
for element in assistant_msg["content"]: for element in _message_content:
if isinstance(element, dict): if isinstance(element, dict):
if element["type"] == "text": if element["type"] == "text":
_part = PartType(text=element["text"]) # type: ignore _part = PartType(text=element["text"])
_parts.append(_part) _parts.append(_part)
assistant_content.extend(_parts) assistant_content.extend(_parts)
elif ( elif (
assistant_msg.get("content", None) is not None _message_content is not None
and isinstance(assistant_msg["content"], str) and isinstance(_message_content, str)
and assistant_msg["content"] and _message_content
): ):
assistant_text = assistant_msg["content"] # either string or none assistant_text = _message_content # either string or none
assistant_content.append(PartType(text=assistant_text)) # type: ignore assistant_content.append(PartType(text=assistant_text)) # type: ignore
## HANDLE ASSISTANT FUNCTION CALL ## HANDLE ASSISTANT FUNCTION CALL
@ -256,7 +264,9 @@ def _gemini_convert_messages_with_history(
raise e raise e
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str): def _get_client_cache_key(
model: str, vertex_project: Optional[str], vertex_location: Optional[str]
):
_cache_key = f"{model}-{vertex_project}-{vertex_location}" _cache_key = f"{model}-{vertex_project}-{vertex_location}"
return _cache_key return _cache_key
@ -294,7 +304,7 @@ def completion(
""" """
try: try:
import vertexai import vertexai
except: except Exception:
raise VertexAIError( raise VertexAIError(
status_code=400, status_code=400,
message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM", message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
@ -339,6 +349,8 @@ def completion(
_vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key) _vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
if _vertex_llm_model_object is None: if _vertex_llm_model_object is None:
from google.auth.credentials import Credentials
if vertex_credentials is not None and isinstance(vertex_credentials, str): if vertex_credentials is not None and isinstance(vertex_credentials, str):
import google.oauth2.service_account import google.oauth2.service_account
@ -356,7 +368,9 @@ def completion(
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}" f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
) )
vertexai.init( vertexai.init(
project=vertex_project, location=vertex_location, credentials=creds project=vertex_project,
location=vertex_location,
credentials=cast(Credentials, creds),
) )
## Load Config ## Load Config
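Note: a rough sketch of the credential flow the hunks above adjust, under the assumption that vertex_credentials may arrive as a JSON string; the cast(Credentials, creds) in the diff exists only to satisfy the type checker:

import json

import google.oauth2.service_account
import vertexai

def init_vertex(vertex_project, vertex_location, vertex_credentials=None):
    creds = None
    if isinstance(vertex_credentials, str):
        json_obj = json.loads(vertex_credentials)
        creds = google.oauth2.service_account.Credentials.from_service_account_info(json_obj)
    # falls back to application-default credentials when creds is None
    vertexai.init(project=vertex_project, location=vertex_location, credentials=creds)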
@ -391,7 +405,6 @@ def completion(
request_str = "" request_str = ""
response_obj = None response_obj = None
async_client = None
instances = None instances = None
client_options = { client_options = {
"api_endpoint": f"{vertex_location}-aiplatform.googleapis.com" "api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
@ -400,7 +413,7 @@ def completion(
model in litellm.vertex_language_models model in litellm.vertex_language_models
or model in litellm.vertex_vision_models or model in litellm.vertex_vision_models
): ):
llm_model = _vertex_llm_model_object or GenerativeModel(model) llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
mode = "vision" mode = "vision"
request_str += f"llm_model = GenerativeModel({model})\n" request_str += f"llm_model = GenerativeModel({model})\n"
elif model in litellm.vertex_chat_models: elif model in litellm.vertex_chat_models:
@ -459,7 +472,6 @@ def completion(
"model_response": model_response, "model_response": model_response,
"encoding": encoding, "encoding": encoding,
"messages": messages, "messages": messages,
"request_str": request_str,
"print_verbose": print_verbose, "print_verbose": print_verbose,
"client_options": client_options, "client_options": client_options,
"instances": instances, "instances": instances,
@ -474,6 +486,7 @@ def completion(
return async_completion(**data) return async_completion(**data)
completion_response = None
if mode == "vision": if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call") print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
print_verbose(f"\nProcessing input messages = {messages}") print_verbose(f"\nProcessing input messages = {messages}")
@ -529,7 +542,7 @@ def completion(
# Check if it's a RepeatedComposite instance # Check if it's a RepeatedComposite instance
for key, val in function_call.args.items(): for key, val in function_call.args.items():
if isinstance( if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
): ):
# If so, convert to list # If so, convert to list
args_dict[key] = [v for v in val] args_dict[key] = [v for v in val]
@ -560,9 +573,9 @@ def completion(
optional_params["tools"] = tools optional_params["tools"] = tools
elif mode == "chat": elif mode == "chat":
chat = llm_model.start_chat() chat = llm_model.start_chat()
request_str += f"chat = llm_model.start_chat()\n" request_str += "chat = llm_model.start_chat()\n"
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# NOTE: VertexAI does not accept stream=True as a param and raises an error, # NOTE: VertexAI does not accept stream=True as a param and raises an error,
# we handle this by removing 'stream' from optional params and sending the request # we handle this by removing 'stream' from optional params and sending the request
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format # after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
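Note: a minimal sketch of the workaround the NOTE above describes, with a hypothetical wrapper name; the SDK call must not receive the stream kwarg, but the flag is restored afterwards so main.py still wraps the result as a streaming response:

def send_without_stream_kwarg(chat, prompt: str, optional_params: dict):
    wants_stream = optional_params.pop("stream", False) is True
    response_text = chat.send_message(prompt, **optional_params).text
    if wants_stream:
        optional_params["stream"] = True  # re-add so callers know to wrap the response
    return response_text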
@ -597,7 +610,7 @@ def completion(
) )
completion_response = chat.send_message(prompt, **optional_params).text completion_response = chat.send_message(prompt, **optional_params).text
elif mode == "text": elif mode == "text":
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
optional_params.pop( optional_params.pop(
"stream", None "stream", None
) # See note above on handling streaming for vertex ai ) # See note above on handling streaming for vertex ai
@ -632,6 +645,12 @@ def completion(
""" """
Vertex AI Model Garden Vertex AI Model Garden
""" """
if vertex_project is None or vertex_location is None:
raise ValueError(
"Vertex project and location are required for custom endpoint"
)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
@ -661,13 +680,17 @@ def completion(
and "\nOutput:\n" in completion_response and "\nOutput:\n" in completion_response
): ):
completion_response = completion_response.split("\nOutput:\n", 1)[1] completion_response = completion_response.split("\nOutput:\n", 1)[1]
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
response = TextStreamer(completion_response) response = TextStreamer(completion_response)
return response return response
elif mode == "private": elif mode == "private":
""" """
Vertex AI Model Garden deployed on private endpoint Vertex AI Model Garden deployed on private endpoint
""" """
if instances is None:
raise ValueError("instances are required for private endpoint")
if llm_model is None:
raise ValueError("Unable to pick client for private endpoint")
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
@ -686,7 +709,7 @@ def completion(
and "\nOutput:\n" in completion_response and "\nOutput:\n" in completion_response
): ):
completion_response = completion_response.split("\nOutput:\n", 1)[1] completion_response = completion_response.split("\nOutput:\n", 1)[1]
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
response = TextStreamer(completion_response) response = TextStreamer(completion_response)
return response return response
@ -715,7 +738,7 @@ def completion(
else: else:
# init prompt tokens # init prompt tokens
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 prompt_tokens, completion_tokens, _ = 0, 0, 0
if response_obj is not None: if response_obj is not None:
if hasattr(response_obj, "usage_metadata") and hasattr( if hasattr(response_obj, "usage_metadata") and hasattr(
response_obj.usage_metadata, "prompt_token_count" response_obj.usage_metadata, "prompt_token_count"
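Note: the total_tokens to "_" change above only silences ruff's unused-variable rule; the surrounding block implements the fallback the comment describes. A condensed sketch (the usage_metadata attribute names follow Vertex's response objects and are assumptions here):

import litellm

def estimate_usage(response_obj, model: str, messages: list, completion_text: str):
    if response_obj is not None and hasattr(response_obj, "usage_metadata"):
        meta = response_obj.usage_metadata
        return meta.prompt_token_count, meta.candidates_token_count
    # fall back to litellm's tokenizer-based counter
    prompt_tokens = litellm.token_counter(model=model, messages=messages)
    completion_tokens = litellm.token_counter(model=model, text=completion_text)
    return prompt_tokens, completion_tokens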
@ -771,11 +794,13 @@ async def async_completion(
try: try:
import proto # type: ignore import proto # type: ignore
response_obj = None
completion_response = None
if mode == "vision": if mode == "vision":
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call") print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
print_verbose(f"\nProcessing input messages = {messages}") print_verbose(f"\nProcessing input messages = {messages}")
tools = optional_params.pop("tools", None) tools = optional_params.pop("tools", None)
stream = optional_params.pop("stream", False) optional_params.pop("stream", False)
content = _gemini_convert_messages_with_history(messages=messages) content = _gemini_convert_messages_with_history(messages=messages)
@ -817,7 +842,7 @@ async def async_completion(
# Check if it's a RepeatedComposite instance # Check if it's a RepeatedComposite instance
for key, val in function_call.args.items(): for key, val in function_call.args.items():
if isinstance( if isinstance(
val, proto.marshal.collections.repeated.RepeatedComposite val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
): ):
# If so, convert to list # If so, convert to list
args_dict[key] = [v for v in val] args_dict[key] = [v for v in val]
@ -880,6 +905,11 @@ async def async_completion(
""" """
from google.cloud import aiplatform # type: ignore from google.cloud import aiplatform # type: ignore
if vertex_project is None or vertex_location is None:
raise ValueError(
"Vertex project and location are required for custom endpoint"
)
## LOGGING ## LOGGING
logging_obj.pre_call( logging_obj.pre_call(
input=prompt, input=prompt,
@ -953,7 +983,7 @@ async def async_completion(
else: else:
# init prompt tokens # init prompt tokens
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter # this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0 prompt_tokens, completion_tokens, _ = 0, 0, 0
if response_obj is not None and ( if response_obj is not None and (
hasattr(response_obj, "usage_metadata") hasattr(response_obj, "usage_metadata")
and hasattr(response_obj.usage_metadata, "prompt_token_count") and hasattr(response_obj.usage_metadata, "prompt_token_count")
@ -1001,6 +1031,7 @@ async def async_streaming(
""" """
Add support for async streaming calls for gemini-pro Add support for async streaming calls for gemini-pro
""" """
response: Any = None
if mode == "vision": if mode == "vision":
stream = optional_params.pop("stream") stream = optional_params.pop("stream")
tools = optional_params.pop("tools", None) tools = optional_params.pop("tools", None)
@ -1065,6 +1096,11 @@ async def async_streaming(
elif mode == "custom": elif mode == "custom":
from google.cloud import aiplatform # type: ignore from google.cloud import aiplatform # type: ignore
if vertex_project is None or vertex_location is None:
raise ValueError(
"Vertex project and location are required for custom endpoint"
)
stream = optional_params.pop("stream", None) stream = optional_params.pop("stream", None)
## LOGGING ## LOGGING
@ -1102,6 +1138,8 @@ async def async_streaming(
response = TextStreamer(completion_response) response = TextStreamer(completion_response)
elif mode == "private": elif mode == "private":
if instances is None:
raise ValueError("Instances are required for private endpoint")
stream = optional_params.pop("stream", None) stream = optional_params.pop("stream", None)
_ = instances[0].pop("stream", None) _ = instances[0].pop("stream", None)
request_str += f"llm_model.predict_async(instances={instances})\n" request_str += f"llm_model.predict_async(instances={instances})\n"
@ -1118,6 +1156,9 @@ async def async_streaming(
if stream: if stream:
response = TextStreamer(completion_response) response = TextStreamer(completion_response)
if response is None:
raise ValueError("Unable to generate response")
logging_obj.post_call(input=prompt, api_key=None, original_response=response) logging_obj.post_call(input=prompt, api_key=None, original_response=response)
streamwrapper = CustomStreamWrapper( streamwrapper = CustomStreamWrapper(

View file

@ -1,7 +1,7 @@
import json import json
import os import os
import types import types
from typing import Literal, Optional, Union from typing import Any, Literal, Optional, Union, cast
import httpx import httpx
from pydantic import BaseModel from pydantic import BaseModel
@ -53,7 +53,7 @@ class VertexEmbedding(VertexBase):
gemini_api_key: Optional[str] = None, gemini_api_key: Optional[str] = None,
extra_headers: Optional[dict] = None, extra_headers: Optional[dict] = None,
): ):
if aembedding == True: if aembedding is True:
return self.async_embedding( return self.async_embedding(
model=model, model=model,
input=input, input=input,

View file

@ -45,8 +45,8 @@ def completion(
print_verbose: Callable, print_verbose: Callable,
encoding, encoding,
logging_obj, logging_obj,
optional_params: dict,
custom_prompt_dict={}, custom_prompt_dict={},
optional_params=None,
litellm_params=None, litellm_params=None,
logger_fn=None, logger_fn=None,
): ):
@ -83,7 +83,7 @@ def completion(
) )
## COMPLETION CALL ## COMPLETION CALL
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
return iter(outputs) return iter(outputs)
else: else:
## LOGGING ## LOGGING
@ -144,10 +144,7 @@ def batch_completions(
llm, SamplingParams = validate_environment(model=model) llm, SamplingParams = validate_environment(model=model)
except Exception as e: except Exception as e:
error_str = str(e) error_str = str(e)
if "data parallel group is already initialized" in error_str: raise VLLMError(status_code=0, message=error_str)
pass
else:
raise VLLMError(status_code=0, message=error_str)
sampling_params = SamplingParams(**optional_params) sampling_params = SamplingParams(**optional_params)
prompts = [] prompts = []
if model in custom_prompt_dict: if model in custom_prompt_dict:

View file

@ -106,6 +106,7 @@ from .llms.prompt_templates.factory import (
custom_prompt, custom_prompt,
function_call_prompt, function_call_prompt,
map_system_message_pt, map_system_message_pt,
ollama_pt,
prompt_factory, prompt_factory,
stringify_json_tool_call_content, stringify_json_tool_call_content,
) )
@ -150,7 +151,6 @@ from .types.utils import (
encoding = tiktoken.get_encoding("cl100k_base") encoding = tiktoken.get_encoding("cl100k_base")
from litellm.utils import ( from litellm.utils import (
Choices, Choices,
CustomStreamWrapper,
EmbeddingResponse, EmbeddingResponse,
ImageResponse, ImageResponse,
Message, Message,
@ -159,8 +159,6 @@ from litellm.utils import (
TextCompletionResponse, TextCompletionResponse,
TextCompletionStreamWrapper, TextCompletionStreamWrapper,
TranscriptionResponse, TranscriptionResponse,
get_secret,
read_config_args,
) )
####### ENVIRONMENT VARIABLES ################### ####### ENVIRONMENT VARIABLES ###################
@ -214,7 +212,7 @@ class LiteLLM:
class Chat: class Chat:
def __init__(self, params, router_obj: Optional[Any]): def __init__(self, params, router_obj: Optional[Any]):
self.params = params self.params = params
if self.params.get("acompletion", False) == True: if self.params.get("acompletion", False) is True:
self.params.pop("acompletion") self.params.pop("acompletion")
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions( self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
self.params, router_obj=router_obj self.params, router_obj=router_obj
@ -837,10 +835,10 @@ def completion(
model_response = ModelResponse() model_response = ModelResponse()
setattr(model_response, "usage", litellm.Usage()) setattr(model_response, "usage", litellm.Usage())
if ( if (
kwargs.get("azure", False) == True kwargs.get("azure", False) is True
): # don't remove flag check, to remain backwards compatible for repos like Codium ): # don't remove flag check, to remain backwards compatible for repos like Codium
custom_llm_provider = "azure" custom_llm_provider = "azure"
if deployment_id != None: # azure llms if deployment_id is not None: # azure llms
model = deployment_id model = deployment_id
custom_llm_provider = "azure" custom_llm_provider = "azure"
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider( model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
@ -1156,7 +1154,7 @@ def completion(
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
) )
if optional_params.get("stream", False) or acompletion == True: if optional_params.get("stream", False) or acompletion is True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -1278,7 +1276,7 @@ def completion(
if ( if (
len(messages) > 0 len(messages) > 0
and "content" in messages[0] and "content" in messages[0]
and type(messages[0]["content"]) == list and isinstance(messages[0]["content"], list)
): ):
# text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content'] # text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
# https://platform.openai.com/docs/api-reference/completions/create # https://platform.openai.com/docs/api-reference/completions/create
@ -1304,16 +1302,16 @@ def completion(
) )
if ( if (
optional_params.get("stream", False) == False optional_params.get("stream", False) is False
and acompletion == False and acompletion is False
and text_completion == False and text_completion is False
): ):
# convert to chat completion response # convert to chat completion response
_response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object( _response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
response_object=_response, model_response_object=model_response response_object=_response, model_response_object=model_response
) )
if optional_params.get("stream", False) or acompletion == True: if optional_params.get("stream", False) or acompletion is True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -1519,7 +1517,7 @@ def completion(
acompletion=acompletion, acompletion=acompletion,
) )
if optional_params.get("stream", False) == True: if optional_params.get("stream", False) is True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -1566,7 +1564,7 @@ def completion(
custom_prompt_dict=custom_prompt_dict, custom_prompt_dict=custom_prompt_dict,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
@ -1575,7 +1573,7 @@ def completion(
original_response=model_response, original_response=model_response,
) )
if optional_params.get("stream", False) or acompletion == True: if optional_params.get("stream", False) or acompletion is True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -1654,7 +1652,7 @@ def completion(
timeout=timeout, timeout=timeout,
client=client, client=client,
) )
if optional_params.get("stream", False) or acompletion == True: if optional_params.get("stream", False) or acompletion is True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -1691,7 +1689,7 @@ def completion(
logging_obj=logging, logging_obj=logging,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
response, response,
@ -1700,7 +1698,7 @@ def completion(
logging_obj=logging, logging_obj=logging,
) )
if optional_params.get("stream", False) or acompletion == True: if optional_params.get("stream", False) or acompletion is True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -1740,7 +1738,7 @@ def completion(
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model_response,
@ -1788,7 +1786,7 @@ def completion(
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model_response,
@ -1836,7 +1834,7 @@ def completion(
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model_response,
@ -1875,7 +1873,7 @@ def completion(
logging_obj=logging, logging_obj=logging,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model_response,
@ -1916,7 +1914,7 @@ def completion(
) )
if ( if (
"stream" in optional_params "stream" in optional_params
and optional_params["stream"] == True and optional_params["stream"] is True
and acompletion is False and acompletion is False
): ):
# don't try to access stream object, # don't try to access stream object,
@ -1943,7 +1941,7 @@ def completion(
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model_response,
@ -2095,7 +2093,7 @@ def completion(
logging_obj=logging, logging_obj=logging,
) )
# fake palm streaming # fake palm streaming
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# fake streaming for palm # fake streaming for palm
resp_string = model_response["choices"][0]["message"]["content"] resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper( response = CustomStreamWrapper(
@ -2390,7 +2388,7 @@ def completion(
logging_obj=logging, logging_obj=logging,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
model_response, model_response,
@ -2527,7 +2525,7 @@ def completion(
) )
if ( if (
"stream" in optional_params "stream" in optional_params
and optional_params["stream"] == True and optional_params["stream"] is True
and not isinstance(response, CustomStreamWrapper) and not isinstance(response, CustomStreamWrapper)
): ):
# don't try to access stream object, # don't try to access stream object,
@ -2563,7 +2561,7 @@ def completion(
) )
if ( if (
"stream" in optional_params and optional_params["stream"] == True "stream" in optional_params and optional_params["stream"] is True
): ## [BETA] ): ## [BETA]
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
@ -2587,38 +2585,38 @@ def completion(
if model in custom_prompt_dict: if model in custom_prompt_dict:
# check if the model has a registered custom prompt # check if the model has a registered custom prompt
model_prompt_details = custom_prompt_dict[model] model_prompt_details = custom_prompt_dict[model]
prompt = custom_prompt( ollama_prompt = custom_prompt(
role_dict=model_prompt_details["roles"], role_dict=model_prompt_details["roles"],
initial_prompt_value=model_prompt_details["initial_prompt_value"], initial_prompt_value=model_prompt_details["initial_prompt_value"],
final_prompt_value=model_prompt_details["final_prompt_value"], final_prompt_value=model_prompt_details["final_prompt_value"],
messages=messages, messages=messages,
) )
else: else:
prompt = prompt_factory( modified_prompt = ollama_pt(model=model, messages=messages)
model=model, if isinstance(modified_prompt, dict):
messages=messages,
custom_llm_provider=custom_llm_provider,
)
if isinstance(prompt, dict):
# for multimode models - ollama/llava prompt_factory returns a dict { # for multimode models - ollama/llava prompt_factory returns a dict {
# "prompt": prompt, # "prompt": prompt,
# "images": images # "images": images
# } # }
prompt, images = prompt["prompt"], prompt["images"] ollama_prompt, images = (
modified_prompt["prompt"],
modified_prompt["images"],
)
optional_params["images"] = images optional_params["images"] = images
else:
ollama_prompt = modified_prompt
## LOGGING ## LOGGING
generator = ollama.get_ollama_response( generator = ollama.get_ollama_response(
api_base=api_base, api_base=api_base,
model=model, model=model,
prompt=prompt, prompt=ollama_prompt,
optional_params=optional_params, optional_params=optional_params,
logging_obj=logging, logging_obj=logging,
acompletion=acompletion, acompletion=acompletion,
model_response=model_response, model_response=model_response,
encoding=encoding, encoding=encoding,
) )
if acompletion is True or optional_params.get("stream", False) == True: if acompletion is True or optional_params.get("stream", False) is True:
return generator return generator
response = generator response = generator
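Note: a simplified sketch of the ollama prompt handling after the refactor above; ollama_pt() may return a plain string or, for multimodal llava-style models, a dict carrying both the prompt and its images:

from typing import Union

def prepare_ollama_prompt(modified_prompt: Union[str, dict], optional_params: dict) -> str:
    if isinstance(modified_prompt, dict):
        optional_params["images"] = modified_prompt["images"]  # forwarded to the ollama call
        return modified_prompt["prompt"]
    return modified_prompt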
@ -2701,7 +2699,7 @@ def completion(
api_key=api_key, api_key=api_key,
logging_obj=logging, logging_obj=logging,
) )
if "stream" in optional_params and optional_params["stream"] == True: if "stream" in optional_params and optional_params["stream"] is True:
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
response, response,
@ -2710,7 +2708,7 @@ def completion(
logging_obj=logging, logging_obj=logging,
) )
if optional_params.get("stream", False) or acompletion == True: if optional_params.get("stream", False) or acompletion is True:
## LOGGING ## LOGGING
logging.post_call( logging.post_call(
input=messages, input=messages,
@ -2743,7 +2741,7 @@ def completion(
logging_obj=logging, logging_obj=logging,
) )
if inspect.isgenerator(model_response) or ( if inspect.isgenerator(model_response) or (
"stream" in optional_params and optional_params["stream"] == True "stream" in optional_params and optional_params["stream"] is True
): ):
# don't try to access stream object, # don't try to access stream object,
response = CustomStreamWrapper( response = CustomStreamWrapper(
@ -2771,7 +2769,7 @@ def completion(
encoding=encoding, encoding=encoding,
logging_obj=logging, logging_obj=logging,
) )
if stream == True: ## [BETA] if stream is True: ## [BETA]
# Fake streaming for petals # Fake streaming for petals
resp_string = model_response["choices"][0]["message"]["content"] resp_string = model_response["choices"][0]["message"]["content"]
response = CustomStreamWrapper( response = CustomStreamWrapper(
@ -2786,7 +2784,7 @@ def completion(
import requests import requests
url = litellm.api_base or api_base or "" url = litellm.api_base or api_base or ""
if url == None or url == "": if url is None or url == "":
raise ValueError( raise ValueError(
"api_base not set. Set api_base or litellm.api_base for custom endpoints" "api_base not set. Set api_base or litellm.api_base for custom endpoints"
) )
@ -3145,10 +3143,10 @@ def batch_completion_models(*args, **kwargs):
try: try:
result = future.result() result = future.result()
return result return result
except Exception as e: except Exception:
# if model 1 fails, continue with response from model 2, model3 # if model 1 fails, continue with response from model 2, model3
print_verbose( print_verbose(
f"\n\ngot an exception, ignoring, removing from futures" "\n\ngot an exception, ignoring, removing from futures"
) )
print_verbose(futures) print_verbose(futures)
new_futures = {} new_futures = {}
@ -3189,9 +3187,6 @@ def batch_completion_models_all_responses(*args, **kwargs):
import concurrent.futures import concurrent.futures
# ANSI escape codes for colored output # ANSI escape codes for colored output
GREEN = "\033[92m"
RED = "\033[91m"
RESET = "\033[0m"
if "model" in kwargs: if "model" in kwargs:
kwargs.pop("model") kwargs.pop("model")
@ -3520,7 +3515,7 @@ def embedding(
if api_base is None: if api_base is None:
raise ValueError( raise ValueError(
f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env" "No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
) )
## EMBEDDING CALL ## EMBEDDING CALL
@ -4106,7 +4101,6 @@ def text_completion(
*args, *args,
**kwargs, **kwargs,
): ):
global print_verbose
import copy import copy
""" """
@ -4136,7 +4130,7 @@ def text_completion(
Your example of how to use this function goes here. Your example of how to use this function goes here.
""" """
if "engine" in kwargs: if "engine" in kwargs:
if model == None: if model is None:
# only use engine when model not passed # only use engine when model not passed
model = kwargs["engine"] model = kwargs["engine"]
kwargs.pop("engine") kwargs.pop("engine")
@ -4189,18 +4183,18 @@ def text_completion(
if custom_llm_provider == "huggingface": if custom_llm_provider == "huggingface":
# if echo == True, for TGI llms we need to set top_n_tokens to 3 # if echo == True, for TGI llms we need to set top_n_tokens to 3
if echo == True: if echo is True:
# for tgi llms # for tgi llms
if "top_n_tokens" not in kwargs: if "top_n_tokens" not in kwargs:
kwargs["top_n_tokens"] = 3 kwargs["top_n_tokens"] = 3
# processing prompt - users can pass raw tokens to OpenAI Completion() # processing prompt - users can pass raw tokens to OpenAI Completion()
if type(prompt) == list: if isinstance(prompt, list):
import concurrent.futures import concurrent.futures
tokenizer = tiktoken.encoding_for_model("text-davinci-003") tokenizer = tiktoken.encoding_for_model("text-davinci-003")
## if it's a 2d list - each element in the list is a text_completion() request ## if it's a 2d list - each element in the list is a text_completion() request
if len(prompt) > 0 and type(prompt[0]) == list: if len(prompt) > 0 and isinstance(prompt[0], list):
responses = [None for x in prompt] # init responses responses = [None for x in prompt] # init responses
def process_prompt(i, individual_prompt): def process_prompt(i, individual_prompt):
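Note: a sketch of the raw-token prompt path above; splitting a 2-D token list into per-request prompts matches the diff, while the decode helper only illustrates what the tokenizer instantiated there can be used for:

import tiktoken

def split_token_batch(prompt):
    # a 2-D token list means one completion request per inner list
    if isinstance(prompt, list) and len(prompt) > 0 and isinstance(prompt[0], list):
        return list(prompt)
    return [prompt]

def decode_tokens(token_ids: list) -> str:
    # same tokenizer the hunk above creates for text-davinci-003 style prompts
    tokenizer = tiktoken.encoding_for_model("text-davinci-003")
    return tokenizer.decode(token_ids)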
@ -4299,7 +4293,7 @@ def text_completion(
raw_response = response._hidden_params.get("original_response", None) raw_response = response._hidden_params.get("original_response", None)
transformed_logprobs = litellm.utils.transform_logprobs(raw_response) transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
except Exception as e: except Exception as e:
print_verbose(f"LiteLLM non blocking exception: {e}") verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
if isinstance(response, TextCompletionResponse): if isinstance(response, TextCompletionResponse):
return response return response
@ -4813,12 +4807,12 @@ def transcription(
Allows router to load balance between them Allows router to load balance between them
""" """
atranscription = kwargs.get("atranscription", False) atranscription = kwargs.get("atranscription", False)
litellm_call_id = kwargs.get("litellm_call_id", None) kwargs.get("litellm_call_id", None)
logger_fn = kwargs.get("logger_fn", None) kwargs.get("logger_fn", None)
proxy_server_request = kwargs.get("proxy_server_request", None) kwargs.get("proxy_server_request", None)
model_info = kwargs.get("model_info", None) kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {}) kwargs.get("metadata", {})
tags = kwargs.pop("tags", []) kwargs.pop("tags", [])
drop_params = kwargs.get("drop_params", None) drop_params = kwargs.get("drop_params", None)
client: Optional[ client: Optional[
@ -4996,7 +4990,7 @@ def speech(
model_info = kwargs.get("model_info", None) model_info = kwargs.get("model_info", None)
metadata = kwargs.get("metadata", {}) metadata = kwargs.get("metadata", {})
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
tags = kwargs.pop("tags", []) kwargs.pop("tags", [])
optional_params = {} optional_params = {}
if response_format is not None: if response_format is not None:
@ -5345,12 +5339,12 @@ def print_verbose(print_statement):
verbose_logger.debug(print_statement) verbose_logger.debug(print_statement)
if litellm.set_verbose: if litellm.set_verbose:
print(print_statement) # noqa print(print_statement) # noqa
except: except Exception:
pass pass
def config_completion(**kwargs): def config_completion(**kwargs):
if litellm.config_path != None: if litellm.config_path is not None:
config_args = read_config_args(litellm.config_path) config_args = read_config_args(litellm.config_path)
# overwrite any args passed in with config args # overwrite any args passed in with config args
return completion(**kwargs, **config_args) return completion(**kwargs, **config_args)
@ -5408,16 +5402,18 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
response["choices"][0]["text"] = combined_content response["choices"][0]["text"] = combined_content
if len(combined_content) > 0: if len(combined_content) > 0:
completion_output = combined_content pass
else: else:
completion_output = "" pass
# # Update usage information if needed # # Update usage information if needed
try: try:
response["usage"]["prompt_tokens"] = token_counter( response["usage"]["prompt_tokens"] = token_counter(
model=model, messages=messages model=model, messages=messages
) )
except: # don't allow this failing to block a complete streaming response from being returned except (
print_verbose(f"token_counter failed, assuming prompt tokens is 0") Exception
): # don't allow this failing to block a complete streaming response from being returned
print_verbose("token_counter failed, assuming prompt tokens is 0")
response["usage"]["prompt_tokens"] = 0 response["usage"]["prompt_tokens"] = 0
response["usage"]["completion_tokens"] = token_counter( response["usage"]["completion_tokens"] = token_counter(
model=model, model=model,

View file

@ -128,14 +128,14 @@ class LiteLLMBase(BaseModel):
def json(self, **kwargs): # type: ignore def json(self, **kwargs): # type: ignore
try: try:
return self.model_dump(**kwargs) # noqa return self.model_dump(**kwargs) # noqa
except Exception as e: except Exception:
# if using pydantic v1 # if using pydantic v1
return self.dict(**kwargs) return self.dict(**kwargs)
def fields_set(self): def fields_set(self):
try: try:
return self.model_fields_set # noqa return self.model_fields_set # noqa
except: except Exception:
# if using pydantic v1 # if using pydantic v1
return self.__fields_set__ return self.__fields_set__
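Note: the two hunks above keep the pydantic v1 fallback while satisfying ruff (unused exception variable, bare except). A minimal sketch of the same compatibility pattern on a standalone model:

from pydantic import BaseModel

class CompatModel(BaseModel):
    def to_dict(self, **kwargs) -> dict:
        try:
            return self.model_dump(**kwargs)  # pydantic v2
        except Exception:
            return self.dict(**kwargs)  # pydantic v1 fallback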

View file

@ -1,202 +0,0 @@
def add_new_model():
import streamlit as st
import json, requests, uuid
model_name = st.text_input(
"Model Name - user-facing model name", placeholder="gpt-3.5-turbo"
)
st.subheader("LiteLLM Params")
litellm_model_name = st.text_input(
"Model", placeholder="azure/gpt-35-turbo-us-east"
)
litellm_api_key = st.text_input("API Key")
litellm_api_base = st.text_input(
"API Base",
placeholder="https://my-endpoint.openai.azure.com",
)
litellm_api_version = st.text_input("API Version", placeholder="2023-07-01-preview")
litellm_params = json.loads(
st.text_area(
"Additional Litellm Params (JSON dictionary). [See all possible inputs](https://github.com/BerriAI/litellm/blob/3f15d7230fe8e7492c95a752963e7fbdcaf7bf98/litellm/main.py#L293)",
value={},
)
)
st.subheader("Model Info")
mode_options = ("completion", "embedding", "image generation")
mode_selected = st.selectbox("Mode", mode_options)
model_info = json.loads(
st.text_area(
"Additional Model Info (JSON dictionary)",
value={},
)
)
if st.button("Submit"):
try:
model_info = {
"model_name": model_name,
"litellm_params": {
"model": litellm_model_name,
"api_key": litellm_api_key,
"api_base": litellm_api_base,
"api_version": litellm_api_version,
},
"model_info": {
"id": str(uuid.uuid4()),
"mode": mode_selected,
},
}
# Make the POST request to the specified URL
complete_url = ""
if st.session_state["api_url"].endswith("/"):
complete_url = f"{st.session_state['api_url']}model/new"
else:
complete_url = f"{st.session_state['api_url']}/model/new"
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
response = requests.post(complete_url, json=model_info, headers=headers)
if response.status_code == 200:
st.success("Model added successfully!")
else:
st.error(f"Failed to add model. Status code: {response.status_code}")
st.success("Form submitted successfully!")
except Exception as e:
raise e
def list_models():
import streamlit as st
import requests
# Check if the necessary configuration is available
if (
st.session_state.get("api_url", None) is not None
and st.session_state.get("proxy_key", None) is not None
):
# Make the GET request
try:
complete_url = ""
if isinstance(st.session_state["api_url"], str) and st.session_state[
"api_url"
].endswith("/"):
complete_url = f"{st.session_state['api_url']}models"
else:
complete_url = f"{st.session_state['api_url']}/models"
response = requests.get(
complete_url,
headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
)
# Check if the request was successful
if response.status_code == 200:
models = response.json()
st.write(models) # or st.json(models) to pretty print the JSON
else:
st.error(f"Failed to get models. Status code: {response.status_code}")
except Exception as e:
st.error(f"An error occurred while requesting models: {e}")
else:
st.warning(
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def create_key():
import streamlit as st
import json, requests, uuid
if (
st.session_state.get("api_url", None) is not None
and st.session_state.get("proxy_key", None) is not None
):
duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
models = st.text_input("Models it can access (separated by comma)", value="")
models = models.split(",") if models else []
additional_params = json.loads(
st.text_area(
"Additional Key Params (JSON dictionary). [See all possible inputs](https://litellm-api.up.railway.app/#/key%20management/generate_key_fn_key_generate_post)",
value={},
)
)
if st.button("Submit"):
try:
key_post_body = {
"duration": duration,
"models": models,
**additional_params,
}
# Make the POST request to the specified URL
complete_url = ""
if st.session_state["api_url"].endswith("/"):
complete_url = f"{st.session_state['api_url']}key/generate"
else:
complete_url = f"{st.session_state['api_url']}/key/generate"
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
response = requests.post(
complete_url, json=key_post_body, headers=headers
)
if response.status_code == 200:
st.success(f"Key added successfully! - {response.json()}")
else:
st.error(f"Failed to add Key. Status code: {response.status_code}")
st.success("Form submitted successfully!")
except Exception as e:
raise e
else:
st.warning(
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
)
def streamlit_ui():
import streamlit as st
st.header("Admin Configuration")
# Add a navigation sidebar
st.sidebar.title("Navigation")
page = st.sidebar.radio(
"Go to", ("Proxy Setup", "Add Models", "List Models", "Create Key")
)
# Initialize session state variables if not already present
if "api_url" not in st.session_state:
st.session_state["api_url"] = None
if "proxy_key" not in st.session_state:
st.session_state["proxy_key"] = None
# Display different pages based on navigation selection
if page == "Proxy Setup":
# Use text inputs with intermediary variables
input_api_url = st.text_input(
"Proxy Endpoint",
value=st.session_state.get("api_url", ""),
placeholder="http://0.0.0.0:8000",
)
input_proxy_key = st.text_input(
"Proxy Key",
value=st.session_state.get("proxy_key", ""),
placeholder="sk-...",
)
# When the "Save" button is clicked, update the session state
if st.button("Save"):
st.session_state["api_url"] = input_api_url
st.session_state["proxy_key"] = input_proxy_key
st.success("Configuration saved!")
elif page == "Add Models":
add_new_model()
elif page == "List Models":
list_models()
elif page == "Create Key":
create_key()
if __name__ == "__main__":
streamlit_ui()

View file

@ -69,7 +69,7 @@ async def get_global_activity(
try: try:
if prisma_client is None: if prisma_client is None:
raise ValueError( raise ValueError(
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
) )
sql_query = """ sql_query = """

View file

@ -132,7 +132,7 @@ def common_checks(
# 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
if ( if (
general_settings.get("enforce_user_param", None) is not None general_settings.get("enforce_user_param", None) is not None
and general_settings["enforce_user_param"] == True and general_settings["enforce_user_param"] is True
): ):
if is_llm_api_route(route=route) and "user" not in request_body: if is_llm_api_route(route=route) and "user" not in request_body:
raise Exception( raise Exception(
@ -557,7 +557,7 @@ async def get_team_object(
) )
return _response return _response
except Exception as e: except Exception:
raise Exception( raise Exception(
f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call." f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
) )
@ -664,7 +664,7 @@ async def get_org_object(
raise Exception raise Exception
return response return response
except Exception as e: except Exception:
raise Exception( raise Exception(
f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call." f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
) )

View file

@ -98,7 +98,7 @@ class LicenseCheck:
elif self._verify(license_str=self.license_str) is True: elif self._verify(license_str=self.license_str) is True:
return True return True
return False return False
except Exception as e: except Exception:
return False return False
def verify_license_without_api_request(self, public_key, license_key): def verify_license_without_api_request(self, public_key, license_key):

View file

@ -112,7 +112,6 @@ async def user_api_key_auth(
), ),
) -> UserAPIKeyAuth: ) -> UserAPIKeyAuth:
from litellm.proxy.proxy_server import ( from litellm.proxy.proxy_server import (
custom_db_client,
general_settings, general_settings,
jwt_handler, jwt_handler,
litellm_proxy_admin_name, litellm_proxy_admin_name,
@ -476,7 +475,7 @@ async def user_api_key_auth(
) )
if route == "/user/auth": if route == "/user/auth":
if general_settings.get("allow_user_auth", False) == True: if general_settings.get("allow_user_auth", False) is True:
return UserAPIKeyAuth() return UserAPIKeyAuth()
else: else:
raise HTTPException( raise HTTPException(
@ -597,7 +596,7 @@ async def user_api_key_auth(
## VALIDATE MASTER KEY ## ## VALIDATE MASTER KEY ##
try: try:
assert isinstance(master_key, str) assert isinstance(master_key, str)
except Exception as e: except Exception:
raise HTTPException( raise HTTPException(
status_code=500, status_code=500,
detail={ detail={
@ -648,7 +647,7 @@ async def user_api_key_auth(
) )
if ( if (
prisma_client is None and custom_db_client is None prisma_client is None
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error ): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
raise Exception("No connected db.") raise Exception("No connected db.")
@ -722,9 +721,9 @@ async def user_api_key_auth(
if config != {}: if config != {}:
model_list = config.get("model_list", []) model_list = config.get("model_list", [])
llm_model_list = model_list new_model_list = model_list
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"\n new llm router model list {llm_model_list}" f"\n new llm router model list {new_model_list}"
) )
if ( if (
len(valid_token.models) == 0 len(valid_token.models) == 0

View file

@ -2,6 +2,7 @@ import sys
from typing import Any, Dict, List, Optional, get_args from typing import Any, Dict, List, Optional, get_args
import litellm import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.utils import get_instance_fn from litellm.proxy.utils import get_instance_fn
@ -59,9 +60,15 @@ def initialize_callbacks_on_proxy(
presidio_logging_only presidio_logging_only
) # validate boolean given ) # validate boolean given
params = { _presidio_params = {}
if "presidio" in callback_specific_params and isinstance(
callback_specific_params["presidio"], dict
):
_presidio_params = callback_specific_params["presidio"]
params: Dict[str, Any] = {
"logging_only": presidio_logging_only, "logging_only": presidio_logging_only,
**callback_specific_params.get("presidio", {}), **_presidio_params,
} }
pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params) pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
imported_list.append(pii_masking_object) imported_list.append(pii_masking_object)
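Note: a sketch of the guarded merge introduced above; callback-specific presidio settings are only folded in when they are actually a dict, which keeps both the type checker and runtime happy if the config holds something unexpected:

from typing import Any, Dict

def build_presidio_params(callback_specific_params: Dict[str, Any], logging_only: bool) -> Dict[str, Any]:
    params: Dict[str, Any] = {"logging_only": logging_only}
    overrides = callback_specific_params.get("presidio")
    if isinstance(overrides, dict):
        params.update(overrides)
    return params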
@ -70,7 +77,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_LlamaGuard, _ENTERPRISE_LlamaGuard,
) )
if premium_user != True: if premium_user is not True:
raise Exception( raise Exception(
"Trying to use Llama Guard" "Trying to use Llama Guard"
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
@ -83,7 +90,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_SecretDetection, _ENTERPRISE_SecretDetection,
) )
if premium_user != True: if premium_user is not True:
raise Exception( raise Exception(
"Trying to use secret hiding" "Trying to use secret hiding"
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
@ -96,7 +103,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_OpenAI_Moderation, _ENTERPRISE_OpenAI_Moderation,
) )
if premium_user != True: if premium_user is not True:
raise Exception( raise Exception(
"Trying to use OpenAI Moderations Check" "Trying to use OpenAI Moderations Check"
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
@ -126,7 +133,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_GoogleTextModeration, _ENTERPRISE_GoogleTextModeration,
) )
if premium_user != True: if premium_user is not True:
raise Exception( raise Exception(
"Trying to use Google Text Moderation" "Trying to use Google Text Moderation"
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
@ -137,7 +144,7 @@ def initialize_callbacks_on_proxy(
elif isinstance(callback, str) and callback == "llmguard_moderations": elif isinstance(callback, str) and callback == "llmguard_moderations":
from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard
if premium_user != True: if premium_user is not True:
raise Exception( raise Exception(
"Trying to use Llm Guard" "Trying to use Llm Guard"
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
@ -150,7 +157,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_BlockedUserList, _ENTERPRISE_BlockedUserList,
) )
if premium_user != True: if premium_user is not True:
raise Exception( raise Exception(
"Trying to use ENTERPRISE BlockedUser" "Trying to use ENTERPRISE BlockedUser"
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
@ -165,7 +172,7 @@ def initialize_callbacks_on_proxy(
_ENTERPRISE_BannedKeywords, _ENTERPRISE_BannedKeywords,
) )
if premium_user != True: if premium_user is not True:
raise Exception( raise Exception(
"Trying to use ENTERPRISE BannedKeyword" "Trying to use ENTERPRISE BannedKeyword"
+ CommonProxyErrors.not_premium_user.value + CommonProxyErrors.not_premium_user.value
@ -212,7 +219,7 @@ def initialize_callbacks_on_proxy(
and isinstance(v, str) and isinstance(v, str)
and v.startswith("os.environ/") and v.startswith("os.environ/")
): ):
azure_content_safety_params[k] = litellm.get_secret(v) azure_content_safety_params[k] = get_secret(v)
azure_content_safety_obj = _PROXY_AzureContentSafety( azure_content_safety_obj = _PROXY_AzureContentSafety(
**azure_content_safety_params, **azure_content_safety_params,

View file

@ -6,13 +6,14 @@ import tracemalloc
from fastapi import APIRouter from fastapi import APIRouter
import litellm import litellm
from litellm import get_secret, get_secret_str
from litellm._logging import verbose_proxy_logger from litellm._logging import verbose_proxy_logger
router = APIRouter() router = APIRouter()
if os.environ.get("LITELLM_PROFILE", "false").lower() == "true": if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
try: try:
import objgraph import objgraph # type: ignore
print("growth of objects") # noqa print("growth of objects") # noqa
objgraph.show_growth() objgraph.show_growth()
@ -21,8 +22,10 @@ if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
roots = objgraph.get_leaking_objects() roots = objgraph.get_leaking_objects()
print("\n\nLeaking objects") # noqa print("\n\nLeaking objects") # noqa
objgraph.show_most_common_types(objects=roots) objgraph.show_most_common_types(objects=roots)
except: except ImportError:
pass raise ImportError(
"objgraph not found. Please install objgraph to use this feature."
)
tracemalloc.start(10) tracemalloc.start(10)
@ -57,15 +60,20 @@ async def memory_usage_in_mem_cache():
user_api_key_cache, user_api_key_cache,
) )
if llm_router is None:
num_items_in_llm_router_cache = 0
else:
num_items_in_llm_router_cache = len(
llm_router.cache.in_memory_cache.cache_dict
) + len(llm_router.cache.in_memory_cache.ttl_dict)
num_items_in_user_api_key_cache = len( num_items_in_user_api_key_cache = len(
user_api_key_cache.in_memory_cache.cache_dict user_api_key_cache.in_memory_cache.cache_dict
) + len(user_api_key_cache.in_memory_cache.ttl_dict) ) + len(user_api_key_cache.in_memory_cache.ttl_dict)
num_items_in_llm_router_cache = len(
llm_router.cache.in_memory_cache.cache_dict
) + len(llm_router.cache.in_memory_cache.ttl_dict)
num_items_in_proxy_logging_obj_cache = len( num_items_in_proxy_logging_obj_cache = len(
proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
) + len(proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict) ) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict)
return { return {
"num_items_in_user_api_key_cache": num_items_in_user_api_key_cache, "num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
@ -89,13 +97,20 @@ async def memory_usage_in_mem_cache_items():
user_api_key_cache, user_api_key_cache,
) )
if llm_router is None:
llm_router_in_memory_cache_dict = {}
llm_router_in_memory_ttl_dict = {}
else:
llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict
llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict
return { return {
"user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict, "user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
"user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict, "user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
"llm_router_cache": llm_router.cache.in_memory_cache.cache_dict, "llm_router_cache": llm_router_in_memory_cache_dict,
"llm_router_ttl": llm_router.cache.in_memory_cache.ttl_dict, "llm_router_ttl": llm_router_in_memory_ttl_dict,
"proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict, "proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
"proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict, "proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict,
} }
@ -104,9 +119,18 @@ async def get_otel_spans():
from litellm.integrations.opentelemetry import OpenTelemetry from litellm.integrations.opentelemetry import OpenTelemetry
from litellm.proxy.proxy_server import open_telemetry_logger from litellm.proxy.proxy_server import open_telemetry_logger
open_telemetry_logger: OpenTelemetry = open_telemetry_logger if open_telemetry_logger is None:
return {
"otel_spans": [],
"spans_grouped_by_parent": {},
"most_recent_parent": None,
}
otel_exporter = open_telemetry_logger.OTEL_EXPORTER otel_exporter = open_telemetry_logger.OTEL_EXPORTER
recorded_spans = otel_exporter.get_finished_spans() if hasattr(otel_exporter, "get_finished_spans"):
recorded_spans = otel_exporter.get_finished_spans() # type: ignore
else:
recorded_spans = []
print("Spans: ", recorded_spans) # noqa print("Spans: ", recorded_spans) # noqa
@ -137,11 +161,13 @@ async def get_otel_spans():
# Helper functions for debugging # Helper functions for debugging
def init_verbose_loggers(): def init_verbose_loggers():
try: try:
worker_config = litellm.get_secret("WORKER_CONFIG") worker_config = get_secret_str("WORKER_CONFIG")
# if not, assume it's a json string
if worker_config is None:
return
if os.path.isfile(worker_config): if os.path.isfile(worker_config):
return return
# if not, assume it's a json string _settings = json.loads(worker_config)
_settings = json.loads(os.getenv("WORKER_CONFIG"))
if not isinstance(_settings, dict): if not isinstance(_settings, dict):
return return
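The hunk above reorders init_verbose_loggers so the WORKER_CONFIG value is checked for None before it reaches os.path.isfile() or json.loads(), and reuses the already-fetched string instead of calling os.getenv() a second time. A reduced sketch of that flow under the assumption of a plain environment lookup (the proxy itself goes through get_secret_str, and the helper name here is invented):

import json
import os

def load_worker_settings() -> dict:
    # bail out early: os.getenv() may return None, which isfile()/json.loads() reject
    worker_config = os.getenv("WORKER_CONFIG")
    if worker_config is None:
        return {}
    if os.path.isfile(worker_config):
        return {}  # file-based configs are parsed elsewhere
    settings = json.loads(worker_config)  # if not a file, assume a JSON string
    return settings if isinstance(settings, dict) else {}

os.environ["WORKER_CONFIG"] = '{"detailed_debug": false}'
print(load_worker_settings())  # {'detailed_debug': False}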
@ -162,7 +188,7 @@ def init_verbose_loggers():
level=logging.INFO level=logging.INFO
) # set router logs to info ) # set router logs to info
verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info verbose_proxy_logger.setLevel(level=logging.INFO) # set proxy logs to info
if detailed_debug == True: if detailed_debug is True:
import logging import logging
from litellm._logging import ( from litellm._logging import (
@ -178,10 +204,10 @@ def init_verbose_loggers():
verbose_proxy_logger.setLevel( verbose_proxy_logger.setLevel(
level=logging.DEBUG level=logging.DEBUG
) # set proxy logs to debug ) # set proxy logs to debug
elif debug == False and detailed_debug == False: elif debug is False and detailed_debug is False:
# users can control proxy debugging using env variable = 'LITELLM_LOG' # users can control proxy debugging using env variable = 'LITELLM_LOG'
litellm_log_setting = os.environ.get("LITELLM_LOG", "") litellm_log_setting = os.environ.get("LITELLM_LOG", "")
if litellm_log_setting != None: if litellm_log_setting is not None:
if litellm_log_setting.upper() == "INFO": if litellm_log_setting.upper() == "INFO":
import logging import logging
@ -213,4 +239,6 @@ def init_verbose_loggers():
level=logging.DEBUG level=logging.DEBUG
) # set proxy logs to debug ) # set proxy logs to debug
except Exception as e: except Exception as e:
verbose_logger.warning(f"Failed to init verbose loggers: {str(e)}") import logging
logging.warning(f"Failed to init verbose loggers: {str(e)}")

Some files were not shown because too many files have changed in this diff.