forked from phoenix/litellm-mirror
Litellm ruff linting enforcement (#5992)
Litellm ruff linting enforcement (#5992)
* ci(config.yml): add a 'check_code_quality' step. Addresses https://github.com/BerriAI/litellm/issues/5991
* ci(config.yml): check why circle ci doesn't pick up this test
* ci(config.yml): fix to run 'check_code_quality' tests
* fix(__init__.py): fix unprotected import
* fix(__init__.py): don't remove unused imports
* build(ruff.toml): update ruff.toml to ignore unused imports
* fix: ruff + pyright - fix linting + type-checking errors
* fix: fix linting errors
* fix(lago.py): fix module init error
* fix: fix linting errors
* ci(config.yml): cd into correct dir for checks
* fix(proxy_server.py): fix linting error
* fix(utils.py): fix bare except causes ruff linting errors
* fix: ruff - fix remaining linting errors
* fix(clickhouse.py): use standard logging object
* fix(__init__.py): fix unprotected import
* fix: ruff - fix linting errors
* fix: fix linting errors
* ci(config.yml): cleanup code qa step (formatting handled in local_testing)
* fix(_health_endpoints.py): fix ruff linting errors
* ci(config.yml): just use ruff in check_code_quality pipeline for now
* build(custom_guardrail.py): include missing file
* style(embedding_handler.py): fix ruff check
This commit is contained in:
parent 3fc4ae0d65
commit d57be47b0f
263 changed files with 1687 additions and 3320 deletions
@@ -299,6 +299,27 @@ jobs:
           ls
           python -m pytest -vv tests/local_testing/test_python_38.py
 
+  check_code_quality:
+    docker:
+      - image: cimg/python:3.11
+        auth:
+          username: ${DOCKERHUB_USERNAME}
+          password: ${DOCKERHUB_PASSWORD}
+    working_directory: ~/project/litellm
+
+    steps:
+      - checkout
+      - run:
+          name: Install Dependencies
+          command: |
+            python -m pip install --upgrade pip
+            pip install ruff
+            pip install pylint
+            pip install .
+      - run: python -c "from litellm import *" || (echo '🚨 import failed, this means you introduced unprotected imports! 🚨'; exit 1)
+      - run: ruff check ./litellm
+
+
   build_and_test:
     machine:
       image: ubuntu-2204:2023.10.1
@@ -806,6 +827,12 @@ workflows:
             only:
               - main
               - /litellm_.*/
+      - check_code_quality:
+          filters:
+            branches:
+              only:
+                - main
+                - /litellm_.*/
       - ui_endpoint_testing:
          filters:
            branches:
@@ -867,6 +894,7 @@ workflows:
             - installing_litellm_on_python
             - proxy_logging_guardrails_model_info_tests
             - proxy_pass_through_endpoint_tests
+            - check_code_quality
           filters:
             branches:
               only:

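The `python -c "from litellm import *"` step in the new check_code_quality job exists to catch unprotected imports of optional dependencies at package import time. A minimal sketch of the guarded-import pattern that check encourages (illustrative only; the module name and messages are hypothetical, not litellm's actual code):

    # Guard an optional dependency so `from litellm import *` never fails at import time.
    try:
        import prisma  # hypothetical optional dependency
    except ImportError:
        prisma = None

    def use_prisma() -> None:
        # Fail at call time, with a clear message, instead of failing the whole package import.
        if prisma is None:
            raise RuntimeError("prisma is not installed; run `pip install prisma`")
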
@@ -630,18 +630,6 @@ general_settings:
   "database_url": "string",
   "database_connection_pool_limit": 0, # default 100
   "database_connection_timeout": 0, # default 60s
-  "database_type": "dynamo_db",
-  "database_args": {
-    "billing_mode": "PROVISIONED_THROUGHPUT",
-    "read_capacity_units": 0,
-    "write_capacity_units": 0,
-    "ssl_verify": true,
-    "region_name": "string",
-    "user_table_name": "LiteLLM_UserTable",
-    "key_table_name": "LiteLLM_VerificationToken",
-    "config_table_name": "LiteLLM_Config",
-    "spend_table_name": "LiteLLM_SpendLogs"
-  },
   "otel": true,
   "custom_auth": "string",
   "max_parallel_requests": 0, # the max parallel requests allowed per deployment

@@ -97,7 +97,7 @@ class GenericAPILogger:
         for key, value in payload.items():
             try:
                 payload[key] = str(value)
-            except:
+            except Exception:
                 # non blocking if it can't cast to a str
                 pass
 
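The `except:` → `except Exception:` change above recurs throughout this commit; ruff's E722 rule flags bare `except` because it also swallows `KeyboardInterrupt` and `SystemExit`. A small illustrative comparison (not taken from the diff):

    def to_str(value) -> str:
        try:
            return str(value)
        except:  # E722: bare except also catches KeyboardInterrupt / SystemExit
            return ""

    def to_str_narrow(value) -> str:
        try:
            return str(value)
        except Exception:  # lets interpreter-level signals propagate
            return ""
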
@@ -49,7 +49,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
     def __init__(self):
         try:
             from google.cloud import language_v1
-        except:
+        except Exception:
             raise Exception(
                 "Missing google.cloud package. Run `pip install --upgrade google-cloud-language`"
             )

@@ -90,7 +90,7 @@ class _ENTERPRISE_GoogleTextModeration(CustomLogger):
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement) # noqa
-        except:
+        except Exception:
             pass

    async def async_moderation_hook(

@@ -58,7 +58,7 @@ class _ENTERPRISE_LlamaGuard(CustomLogger):
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement) # noqa
-        except:
+        except Exception:
             pass

    def set_custom_prompt_template(self, messages: list):

@@ -49,7 +49,7 @@ class _ENTERPRISE_LLMGuard(CustomLogger):
             verbose_proxy_logger.debug(print_statement)
             if litellm.set_verbose:
                 print(print_statement) # noqa
-        except:
+        except Exception:
             pass

    async def moderation_check(self, text: str):

@@ -3,7 +3,8 @@ import warnings
 
 warnings.filterwarnings("ignore", message=".*conflict with protected namespace.*")
 ### INIT VARIABLES ###
-import threading, requests, os
+import threading
+import os
 from typing import Callable, List, Optional, Dict, Union, Any, Literal, get_args
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.caching import Cache
@@ -308,13 +309,13 @@ def get_model_cost_map(url: str):
         return content
 
     try:
-        with requests.get(
+        response = httpx.get(
             url, timeout=5
-        ) as response: # set a 5 second timeout for the get request
+        ) # set a 5 second timeout for the get request
         response.raise_for_status() # Raise an exception if the request is unsuccessful
         content = response.json()
         return content
-    except Exception as e:
+    except Exception:
         import importlib.resources
         import json
 
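Note on the `requests.get(...)` → `httpx.get(...)` change above: `httpx.get()` returns a `Response` object directly rather than acting as a context manager, so the `with ... as response:` wrapper is dropped. A hedged sketch of the resulting call shape (illustrative, not the actual litellm helper):

    import httpx

    def fetch_json(url: str) -> dict:
        response = httpx.get(url, timeout=5)  # 5 second timeout, mirroring the diff
        response.raise_for_status()  # raise on 4xx/5xx
        return response.json()
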
@@ -839,7 +840,7 @@ openai_image_generation_models = ["dall-e-2", "dall-e-3"]
 
 from .timeout import timeout
 from .cost_calculator import completion_cost
-from litellm.litellm_core_utils.litellm_logging import Logging
+from litellm.litellm_core_utils.litellm_logging import Logging, modify_integration
 from litellm.litellm_core_utils.get_llm_provider_logic import get_llm_provider
 from litellm.litellm_core_utils.core_helpers import remove_index_from_tool_calls
 from litellm.litellm_core_utils.token_counter import get_modified_max_tokens
@@ -848,7 +849,6 @@ from .utils import (
     exception_type,
     get_optional_params,
     get_response_string,
-    modify_integration,
     token_counter,
     create_pretrained_tokenizer,
     create_tokenizer,

@@ -98,5 +98,5 @@ def print_verbose(print_statement):
     try:
         if set_verbose:
             print(print_statement) # noqa
-    except:
+    except Exception:
         pass

@@ -2,5 +2,5 @@ import importlib_metadata
 
 try:
     version = importlib_metadata.version("litellm")
-except:
+except Exception:
     pass

@@ -12,7 +12,6 @@ from openai.types.beta.assistant import Assistant
 from openai.types.beta.assistant_deleted import AssistantDeleted
-
 import litellm
 from litellm import client
 from litellm.llms.AzureOpenAI import assistants
 from litellm.types.router import GenericLiteLLMParams
 from litellm.utils import (

@@ -96,7 +95,7 @@ def get_assistants(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

@@ -280,7 +279,7 @@ def create_assistants(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

@@ -464,7 +463,7 @@ def delete_assistant(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

@@ -649,7 +648,7 @@ def create_thread(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

@@ -805,7 +804,7 @@ def get_thread(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

@@ -991,7 +990,7 @@ def add_message(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

@@ -1149,7 +1148,7 @@ def get_messages(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

@@ -1347,7 +1346,7 @@ def run_thread(
     if (
         timeout is not None
         and isinstance(timeout, httpx.Timeout)
-        and supports_httpx_timeout(custom_llm_provider) == False
+        and supports_httpx_timeout(custom_llm_provider) is False
     ):
         read_timeout = timeout.read or 600
         timeout = read_timeout # default 10 min timeout

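The repeated `== False` → `is False` edits in this file correspond to ruff's E712 rule: comparing against the `False` singleton with `==` can be fooled by objects that override `__eq__`, while `is` checks identity. A small illustrative example (not from the diff):

    class AlwaysEqual:
        # Hypothetical class used only to show why E712 prefers identity checks.
        def __eq__(self, other) -> bool:
            return True

    flag = AlwaysEqual()
    print(flag == False)  # True -- the overridden __eq__ makes the equality test lie
    print(flag is False)  # False -- identity against the builtin False singleton
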
@@ -22,7 +22,7 @@ import litellm
 from litellm import client
 from litellm.llms.AzureOpenAI.azure import AzureBatchesAPI
 from litellm.llms.OpenAI.openai import OpenAIBatchesAPI
-from litellm.secret_managers.main import get_secret
+from litellm.secret_managers.main import get_secret, get_secret_str
 from litellm.types.llms.openai import (
     Batch,
     CancelBatchRequest,
@@ -131,7 +131,7 @@ def create_batch(
             extra_headers=extra_headers,
             extra_body=extra_body,
         )
 
+        api_base: Optional[str] = None
         if custom_llm_provider == "openai":
-
             # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@@ -165,27 +165,30 @@ def create_batch(
                 _is_async=_is_async,
             )
         elif custom_llm_provider == "azure":
-            api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
+            api_base = (
+                optional_params.api_base
+                or litellm.api_base
+                or get_secret_str("AZURE_API_BASE")
+            )
             api_version = (
                 optional_params.api_version
                 or litellm.api_version
-                or get_secret("AZURE_API_VERSION")
-            ) # type: ignore
+                or get_secret_str("AZURE_API_VERSION")
+            )
 
             api_key = (
                 optional_params.api_key
                 or litellm.api_key
                 or litellm.azure_key
-                or get_secret("AZURE_OPENAI_API_KEY")
-                or get_secret("AZURE_API_KEY")
-            ) # type: ignore
+                or get_secret_str("AZURE_OPENAI_API_KEY")
+                or get_secret_str("AZURE_API_KEY")
+            )
 
             extra_body = optional_params.get("extra_body", {})
-            azure_ad_token: Optional[str] = None
             if extra_body is not None:
-                azure_ad_token = extra_body.pop("azure_ad_token", None)
+                extra_body.pop("azure_ad_token", None)
             else:
-                azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
+                get_secret_str("AZURE_AD_TOKEN") # type: ignore
 
             response = azure_batches_instance.create_batch(
                 _is_async=_is_async,
@@ -293,7 +296,7 @@ def retrieve_batch(
         )
 
         _is_async = kwargs.pop("aretrieve_batch", False) is True
 
+        api_base: Optional[str] = None
         if custom_llm_provider == "openai":
-
             # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
@@ -327,27 +330,30 @@ def retrieve_batch(
                 max_retries=optional_params.max_retries,
             )
         elif custom_llm_provider == "azure":
-            api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
+            api_base = (
+                optional_params.api_base
+                or litellm.api_base
+                or get_secret_str("AZURE_API_BASE")
+            )
             api_version = (
                 optional_params.api_version
                 or litellm.api_version
-                or get_secret("AZURE_API_VERSION")
-            ) # type: ignore
+                or get_secret_str("AZURE_API_VERSION")
+            )
 
             api_key = (
                 optional_params.api_key
                 or litellm.api_key
                 or litellm.azure_key
-                or get_secret("AZURE_OPENAI_API_KEY")
-                or get_secret("AZURE_API_KEY")
-            ) # type: ignore
+                or get_secret_str("AZURE_OPENAI_API_KEY")
+                or get_secret_str("AZURE_API_KEY")
+            )
 
             extra_body = optional_params.get("extra_body", {})
-            azure_ad_token: Optional[str] = None
             if extra_body is not None:
-                azure_ad_token = extra_body.pop("azure_ad_token", None)
+                extra_body.pop("azure_ad_token", None)
             else:
-                azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
+                get_secret_str("AZURE_AD_TOKEN") # type: ignore
 
             response = azure_batches_instance.retrieve_batch(
                 _is_async=_is_async,
@@ -384,7 +390,7 @@ async def alist_batches(
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
     **kwargs,
-) -> Batch:
+):
     """
     Async: List your organization's batches.
     """
@@ -482,27 +488,26 @@ def list_batches(
                 max_retries=optional_params.max_retries,
             )
         elif custom_llm_provider == "azure":
-            api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
+            api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
             api_version = (
                 optional_params.api_version
                 or litellm.api_version
-                or get_secret("AZURE_API_VERSION")
-            ) # type: ignore
+                or get_secret_str("AZURE_API_VERSION")
+            )
 
             api_key = (
                 optional_params.api_key
                 or litellm.api_key
                 or litellm.azure_key
-                or get_secret("AZURE_OPENAI_API_KEY")
-                or get_secret("AZURE_API_KEY")
-            ) # type: ignore
+                or get_secret_str("AZURE_OPENAI_API_KEY")
+                or get_secret_str("AZURE_API_KEY")
+            )
 
             extra_body = optional_params.get("extra_body", {})
-            azure_ad_token: Optional[str] = None
             if extra_body is not None:
-                azure_ad_token = extra_body.pop("azure_ad_token", None)
+                extra_body.pop("azure_ad_token", None)
             else:
-                azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
+                get_secret_str("AZURE_AD_TOKEN") # type: ignore
 
             response = azure_batches_instance.list_batches(
                 _is_async=_is_async,

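The `get_secret(...)` → `get_secret_str(...)` swaps in this file support the PR's type-checking goal: a variant annotated to return `Optional[str]` removes the need for the trailing `# type: ignore` comments. A minimal sketch of the idea, using a hypothetical wrapper rather than litellm's actual implementation:

    import os
    from typing import Any, Optional

    def get_secret(name: str) -> Any:
        # Stand-in for a generic secret getter; the real helper also consults secret managers.
        return os.environ.get(name)

    def get_secret_str(name: str) -> Optional[str]:
        # Narrow the return type so callers and type checkers know it is a str or None.
        value = get_secret(name)
        return value if isinstance(value, str) else None

    api_base: Optional[str] = get_secret_str("AZURE_API_BASE")  # no `# type: ignore` needed
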
@ -7,11 +7,16 @@
|
|||
#
|
||||
# Thank you users! We ❤️ you! - Krrish & Ishaan
|
||||
|
||||
import os, json, time
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Literal, Optional, Union
|
||||
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm.utils import ModelResponse
|
||||
import requests, threading # type: ignore
|
||||
from typing import Optional, Union, Literal
|
||||
|
||||
|
||||
class BudgetManager:
|
||||
|
@ -35,7 +40,7 @@ class BudgetManager:
|
|||
import logging
|
||||
|
||||
logging.info(print_statement)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def load_data(self):
|
||||
|
@ -52,7 +57,6 @@ class BudgetManager:
|
|||
elif self.client_type == "hosted":
|
||||
# Load the user_dict from hosted db
|
||||
url = self.api_base + "/get_budget"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {"project_name": self.project_name}
|
||||
response = requests.post(url, headers=self.headers, json=data)
|
||||
response = response.json()
|
||||
|
@ -210,7 +214,6 @@ class BudgetManager:
|
|||
return {"status": "success"}
|
||||
elif self.client_type == "hosted":
|
||||
url = self.api_base + "/set_budget"
|
||||
headers = {"Content-Type": "application/json"}
|
||||
data = {"project_name": self.project_name, "user_dict": self.user_dict}
|
||||
response = requests.post(url, headers=self.headers, json=data)
|
||||
response = response.json()
|
||||
|
|
|
@ -33,7 +33,7 @@ def print_verbose(print_statement):
|
|||
verbose_logger.debug(print_statement)
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -96,15 +96,13 @@ class InMemoryCache(BaseCache):
|
|||
"""
|
||||
for key in list(self.ttl_dict.keys()):
|
||||
if time.time() > self.ttl_dict[key]:
|
||||
removed_item = self.cache_dict.pop(key, None)
|
||||
removed_ttl_item = self.ttl_dict.pop(key, None)
|
||||
self.cache_dict.pop(key, None)
|
||||
self.ttl_dict.pop(key, None)
|
||||
|
||||
# de-reference the removed item
|
||||
# https://www.geeksforgeeks.org/diagnosing-and-fixing-memory-leaks-in-python/
|
||||
# One of the most common causes of memory leaks in Python is the retention of objects that are no longer being used.
|
||||
# This can occur when an object is referenced by another object, but the reference is never removed.
|
||||
removed_item = None
|
||||
removed_ttl_item = None
|
||||
|
||||
def set_cache(self, key, value, **kwargs):
|
||||
print_verbose(
|
||||
|
@ -150,7 +148,7 @@ class InMemoryCache(BaseCache):
|
|||
original_cached_response = self.cache_dict[key]
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response)
|
||||
except:
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
@ -251,7 +249,7 @@ class RedisCache(BaseCache):
|
|||
self.redis_version = "Unknown"
|
||||
try:
|
||||
self.redis_version = self.redis_client.info()["redis_version"]
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
### ASYNC HEALTH PING ###
|
||||
|
@ -688,7 +686,7 @@ class RedisCache(BaseCache):
|
|||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except:
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
return cached_response
|
||||
|
||||
|
@ -844,7 +842,7 @@ class RedisCache(BaseCache):
|
|||
"""
|
||||
Tests if the sync redis client is correctly setup.
|
||||
"""
|
||||
print_verbose(f"Pinging Sync Redis Cache")
|
||||
print_verbose("Pinging Sync Redis Cache")
|
||||
start_time = time.time()
|
||||
try:
|
||||
response = self.redis_client.ping()
|
||||
|
@ -878,7 +876,7 @@ class RedisCache(BaseCache):
|
|||
_redis_client = self.init_async_client()
|
||||
start_time = time.time()
|
||||
async with _redis_client as redis_client:
|
||||
print_verbose(f"Pinging Async Redis Cache")
|
||||
print_verbose("Pinging Async Redis Cache")
|
||||
try:
|
||||
response = await redis_client.ping()
|
||||
## LOGGING ##
|
||||
|
@ -973,7 +971,6 @@ class RedisSemanticCache(BaseCache):
|
|||
},
|
||||
"fields": {
|
||||
"text": [{"name": "response"}],
|
||||
"text": [{"name": "prompt"}],
|
||||
"vector": [
|
||||
{
|
||||
"name": "litellm_embedding",
|
||||
|
@ -999,14 +996,14 @@ class RedisSemanticCache(BaseCache):
|
|||
|
||||
redis_url = "redis://:" + password + "@" + host + ":" + port
|
||||
print_verbose(f"redis semantic-cache redis_url: {redis_url}")
|
||||
if use_async == False:
|
||||
if use_async is False:
|
||||
self.index = SearchIndex.from_dict(schema)
|
||||
self.index.connect(redis_url=redis_url)
|
||||
try:
|
||||
self.index.create(overwrite=False) # don't overwrite existing index
|
||||
except Exception as e:
|
||||
print_verbose(f"Got exception creating semantic cache index: {str(e)}")
|
||||
elif use_async == True:
|
||||
elif use_async is True:
|
||||
schema["index"]["name"] = "litellm_semantic_cache_index_async"
|
||||
self.index = SearchIndex.from_dict(schema)
|
||||
self.index.connect(redis_url=redis_url, use_async=True)
|
||||
|
@ -1027,7 +1024,7 @@ class RedisSemanticCache(BaseCache):
|
|||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except:
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
return cached_response
|
||||
|
||||
|
@ -1060,7 +1057,7 @@ class RedisSemanticCache(BaseCache):
|
|||
]
|
||||
|
||||
# Add more data
|
||||
keys = self.index.load(new_data)
|
||||
self.index.load(new_data)
|
||||
|
||||
return
|
||||
|
||||
|
@ -1092,7 +1089,7 @@ class RedisSemanticCache(BaseCache):
|
|||
)
|
||||
|
||||
results = self.index.query(query)
|
||||
if results == None:
|
||||
if results is None:
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
if len(results) == 0:
|
||||
|
@ -1173,7 +1170,7 @@ class RedisSemanticCache(BaseCache):
|
|||
]
|
||||
|
||||
# Add more data
|
||||
keys = await self.index.aload(new_data)
|
||||
await self.index.aload(new_data)
|
||||
return
|
||||
|
||||
async def async_get_cache(self, key, **kwargs):
|
||||
|
@ -1222,7 +1219,7 @@ class RedisSemanticCache(BaseCache):
|
|||
return_fields=["response", "prompt", "vector_distance"],
|
||||
)
|
||||
results = await self.index.aquery(query)
|
||||
if results == None:
|
||||
if results is None:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
|
@ -1396,7 +1393,7 @@ class QdrantSemanticCache(BaseCache):
|
|||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except:
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
return cached_response
|
||||
|
||||
|
@ -1435,7 +1432,7 @@ class QdrantSemanticCache(BaseCache):
|
|||
},
|
||||
]
|
||||
}
|
||||
keys = self.sync_client.put(
|
||||
self.sync_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
|
@ -1481,7 +1478,7 @@ class QdrantSemanticCache(BaseCache):
|
|||
)
|
||||
results = search_response.json()["result"]
|
||||
|
||||
if results == None:
|
||||
if results is None:
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
if len(results) == 0:
|
||||
|
@ -1563,7 +1560,7 @@ class QdrantSemanticCache(BaseCache):
|
|||
]
|
||||
}
|
||||
|
||||
keys = await self.async_client.put(
|
||||
await self.async_client.put(
|
||||
url=f"{self.qdrant_api_base}/collections/{self.collection_name}/points",
|
||||
headers=self.headers,
|
||||
json=data,
|
||||
|
@ -1629,7 +1626,7 @@ class QdrantSemanticCache(BaseCache):
|
|||
|
||||
results = search_response.json()["result"]
|
||||
|
||||
if results == None:
|
||||
if results is None:
|
||||
kwargs.setdefault("metadata", {})["semantic-similarity"] = 0.0
|
||||
return None
|
||||
if isinstance(results, list):
|
||||
|
@ -1767,7 +1764,7 @@ class S3Cache(BaseCache):
|
|||
cached_response = json.loads(
|
||||
cached_response
|
||||
) # Convert string to dictionary
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response)
|
||||
if type(cached_response) is not dict:
|
||||
cached_response = dict(cached_response)
|
||||
|
@ -1845,7 +1842,7 @@ class DualCache(BaseCache):
|
|||
|
||||
self.in_memory_cache.set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
self.redis_cache.set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
print_verbose(e)
|
||||
|
@ -1865,7 +1862,7 @@ class DualCache(BaseCache):
|
|||
if self.in_memory_cache is not None:
|
||||
result = self.in_memory_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
result = self.redis_cache.increment_cache(key, value, **kwargs)
|
||||
|
||||
return result
|
||||
|
@ -1887,7 +1884,7 @@ class DualCache(BaseCache):
|
|||
if (
|
||||
(self.always_read_redis is True)
|
||||
and self.redis_cache is not None
|
||||
and local_only == False
|
||||
and local_only is False
|
||||
):
|
||||
# If not found in in-memory cache or always_read_redis is True, try fetching from Redis
|
||||
redis_result = self.redis_cache.get_cache(key, **kwargs)
|
||||
|
@ -1900,7 +1897,7 @@ class DualCache(BaseCache):
|
|||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
def batch_get_cache(self, keys: list, local_only: bool = False, **kwargs):
|
||||
|
@ -1913,7 +1910,7 @@ class DualCache(BaseCache):
|
|||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if None in result and self.redis_cache is not None and local_only == False:
|
||||
if None in result and self.redis_cache is not None and local_only is False:
|
||||
"""
|
||||
- for the none values in the result
|
||||
- check the redis cache
|
||||
|
@ -1933,7 +1930,7 @@ class DualCache(BaseCache):
|
|||
|
||||
print_verbose(f"async batch get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_get_cache(self, key, local_only: bool = False, **kwargs):
|
||||
|
@ -1952,7 +1949,7 @@ class DualCache(BaseCache):
|
|||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
|
||||
if result is None and self.redis_cache is not None and local_only == False:
|
||||
if result is None and self.redis_cache is not None and local_only is False:
|
||||
# If not found in in-memory cache, try fetching from Redis
|
||||
redis_result = await self.redis_cache.async_get_cache(key, **kwargs)
|
||||
|
||||
|
@ -1966,7 +1963,7 @@ class DualCache(BaseCache):
|
|||
|
||||
print_verbose(f"get cache: cache result: {result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_batch_get_cache(
|
||||
|
@ -1981,7 +1978,7 @@ class DualCache(BaseCache):
|
|||
|
||||
if in_memory_result is not None:
|
||||
result = in_memory_result
|
||||
if None in result and self.redis_cache is not None and local_only == False:
|
||||
if None in result and self.redis_cache is not None and local_only is False:
|
||||
"""
|
||||
- for the none values in the result
|
||||
- check the redis cache
|
||||
|
@ -2006,7 +2003,7 @@ class DualCache(BaseCache):
|
|||
result[index] = value
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.error(traceback.format_exc())
|
||||
|
||||
async def async_set_cache(self, key, value, local_only: bool = False, **kwargs):
|
||||
|
@ -2017,7 +2014,7 @@ class DualCache(BaseCache):
|
|||
if self.in_memory_cache is not None:
|
||||
await self.in_memory_cache.async_set_cache(key, value, **kwargs)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
await self.redis_cache.async_set_cache(key, value, **kwargs)
|
||||
except Exception as e:
|
||||
verbose_logger.exception(
|
||||
|
@ -2039,7 +2036,7 @@ class DualCache(BaseCache):
|
|||
cache_list=cache_list, **kwargs
|
||||
)
|
||||
|
||||
if self.redis_cache is not None and local_only == False:
|
||||
if self.redis_cache is not None and local_only is False:
|
||||
await self.redis_cache.async_set_cache_pipeline(
|
||||
cache_list=cache_list, ttl=kwargs.pop("ttl", None), **kwargs
|
||||
)
|
||||
|
@ -2459,7 +2456,7 @@ class Cache:
|
|||
cached_response = json.loads(
|
||||
cached_response # type: ignore
|
||||
) # Convert string to dictionary
|
||||
except:
|
||||
except Exception:
|
||||
cached_response = ast.literal_eval(cached_response) # type: ignore
|
||||
return cached_response
|
||||
return cached_result
|
||||
|
@ -2492,7 +2489,7 @@ class Cache:
|
|||
return self._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
|
@ -2506,7 +2503,7 @@ class Cache:
|
|||
if self.should_use_cache(*args, **kwargs) is not True:
|
||||
return
|
||||
|
||||
messages = kwargs.get("messages", [])
|
||||
kwargs.get("messages", [])
|
||||
if "cache_key" in kwargs:
|
||||
cache_key = kwargs["cache_key"]
|
||||
else:
|
||||
|
@ -2522,7 +2519,7 @@ class Cache:
|
|||
return self._get_cache_logic(
|
||||
cached_result=cached_result, max_age=max_age
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
print_verbose(f"An exception occurred: {traceback.format_exc()}")
|
||||
return None
|
||||
|
||||
|
@ -2701,7 +2698,7 @@ class DiskCache(BaseCache):
|
|||
if original_cached_response:
|
||||
try:
|
||||
cached_response = json.loads(original_cached_response) # type: ignore
|
||||
except:
|
||||
except Exception:
|
||||
cached_response = original_cached_response
|
||||
return cached_response
|
||||
return None
|
||||
|
@ -2803,7 +2800,7 @@ def enable_cache(
|
|||
if "cache" not in litellm._async_success_callback:
|
||||
litellm._async_success_callback.append("cache")
|
||||
|
||||
if litellm.cache == None:
|
||||
if litellm.cache is None:
|
||||
litellm.cache = Cache(
|
||||
type=type,
|
||||
host=host,
|
||||
|
|
|
@ -57,7 +57,7 @@
|
|||
# config = yaml.safe_load(file)
|
||||
# else:
|
||||
# pass
|
||||
# except:
|
||||
# except Exception:
|
||||
# pass
|
||||
|
||||
# ## SERVER SETTINGS (e.g. default completion model = 'ollama/mistral')
|
||||
|
|
|
@ -9,12 +9,12 @@ import asyncio
|
|||
import contextvars
|
||||
import os
|
||||
from functools import partial
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union
|
||||
from typing import Any, Coroutine, Dict, Literal, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import client, get_secret
|
||||
from litellm import client, get_secret_str
|
||||
from litellm.llms.files_apis.azure import AzureOpenAIFilesAPI
|
||||
from litellm.llms.OpenAI.openai import FileDeleted, FileObject, OpenAIFilesAPI
|
||||
from litellm.types.llms.openai import (
|
||||
|
@ -39,7 +39,7 @@ async def afile_retrieve(
|
|||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> Coroutine[Any, Any, FileObject]:
|
||||
):
|
||||
"""
|
||||
Async: Get file contents
|
||||
|
||||
|
@ -66,7 +66,7 @@ async def afile_retrieve(
|
|||
if asyncio.iscoroutine(init_response):
|
||||
response = await init_response
|
||||
else:
|
||||
response = init_response # type: ignore
|
||||
response = init_response
|
||||
|
||||
return response
|
||||
except Exception as e:
|
||||
|
@ -137,27 +137,26 @@ def file_retrieve(
|
|||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.retrieve_file(
|
||||
_is_async=_is_async,
|
||||
|
@ -181,7 +180,7 @@ def file_retrieve(
|
|||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
return cast(FileObject, response)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -222,7 +221,7 @@ async def afile_delete(
|
|||
else:
|
||||
response = init_response # type: ignore
|
||||
|
||||
return response
|
||||
return cast(FileDeleted, response) # type: ignore
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -248,7 +247,7 @@ def file_delete(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
@ -288,27 +287,26 @@ def file_delete(
|
|||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.delete_file(
|
||||
_is_async=_is_async,
|
||||
|
@ -332,7 +330,7 @@ def file_delete(
|
|||
request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"), # type: ignore
|
||||
),
|
||||
)
|
||||
return response
|
||||
return cast(FileDeleted, response)
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
@ -399,7 +397,7 @@ def file_list(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
@ -441,27 +439,26 @@ def file_list(
|
|||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.list_files(
|
||||
_is_async=_is_async,
|
||||
|
@ -556,7 +553,7 @@ def create_file(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
@ -603,27 +600,26 @@ def create_file(
|
|||
create_file_data=_create_file_request,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.create_file(
|
||||
_is_async=_is_async,
|
||||
|
@ -713,7 +709,7 @@ def file_content(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
@ -761,27 +757,26 @@ def file_content(
|
|||
organization=organization,
|
||||
)
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_files_instance.file_content(
|
||||
_is_async=_is_async,
|
||||
|
|
|
@ -25,7 +25,7 @@ from litellm.llms.fine_tuning_apis.openai import (
|
|||
OpenAIFineTuningAPI,
|
||||
)
|
||||
from litellm.llms.fine_tuning_apis.vertex_ai import VertexFineTuningAPI
|
||||
from litellm.secret_managers.main import get_secret
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.llms.openai import Hyperparameters
|
||||
from litellm.types.router import *
|
||||
from litellm.utils import supports_httpx_timeout
|
||||
|
@ -119,7 +119,7 @@ def create_fine_tuning_job(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
@ -177,28 +177,27 @@ def create_fine_tuning_job(
|
|||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
model=model,
|
||||
|
@ -228,14 +227,14 @@ def create_fine_tuning_job(
|
|||
vertex_ai_project = (
|
||||
optional_params.vertex_project
|
||||
or litellm.vertex_project
|
||||
or get_secret("VERTEXAI_PROJECT")
|
||||
or get_secret_str("VERTEXAI_PROJECT")
|
||||
)
|
||||
vertex_ai_location = (
|
||||
optional_params.vertex_location
|
||||
or litellm.vertex_location
|
||||
or get_secret("VERTEXAI_LOCATION")
|
||||
or get_secret_str("VERTEXAI_LOCATION")
|
||||
)
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret(
|
||||
vertex_credentials = optional_params.vertex_credentials or get_secret_str(
|
||||
"VERTEXAI_CREDENTIALS"
|
||||
)
|
||||
create_fine_tuning_job_data = FineTuningJobCreate(
|
||||
|
@ -315,7 +314,7 @@ async def acancel_fine_tuning_job(
|
|||
|
||||
def cancel_fine_tuning_job(
|
||||
fine_tuning_job_id: str,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -335,7 +334,7 @@ def cancel_fine_tuning_job(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
@ -386,23 +385,22 @@ def cancel_fine_tuning_job(
|
|||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret_str("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_fine_tuning_apis_instance.cancel_fine_tuning_job(
|
||||
api_base=api_base,
|
||||
|
@ -438,7 +436,7 @@ async def alist_fine_tuning_jobs(
|
|||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
) -> FineTuningJob:
|
||||
):
|
||||
"""
|
||||
Async: List your organization's fine-tuning jobs
|
||||
"""
|
||||
|
@ -473,7 +471,7 @@ async def alist_fine_tuning_jobs(
|
|||
def list_fine_tuning_jobs(
|
||||
after: Optional[str] = None,
|
||||
limit: Optional[int] = None,
|
||||
custom_llm_provider: Literal["openai"] = "openai",
|
||||
custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
|
||||
extra_headers: Optional[Dict[str, str]] = None,
|
||||
extra_body: Optional[Dict[str, str]] = None,
|
||||
**kwargs,
|
||||
|
@ -495,7 +493,7 @@ def list_fine_tuning_jobs(
|
|||
if (
|
||||
timeout is not None
|
||||
and isinstance(timeout, httpx.Timeout)
|
||||
and supports_httpx_timeout(custom_llm_provider) == False
|
||||
and supports_httpx_timeout(custom_llm_provider) is False
|
||||
):
|
||||
read_timeout = timeout.read or 600
|
||||
timeout = read_timeout # default 10 min timeout
|
||||
|
@ -542,28 +540,27 @@ def list_fine_tuning_jobs(
|
|||
)
|
||||
# Azure OpenAI
|
||||
elif custom_llm_provider == "azure":
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret("AZURE_API_BASE") # type: ignore
|
||||
api_base = optional_params.api_base or litellm.api_base or get_secret_str("AZURE_API_BASE") # type: ignore
|
||||
|
||||
api_version = (
|
||||
optional_params.api_version
|
||||
or litellm.api_version
|
||||
or get_secret("AZURE_API_VERSION")
|
||||
or get_secret_str("AZURE_API_VERSION")
|
||||
) # type: ignore
|
||||
|
||||
api_key = (
|
||||
optional_params.api_key
|
||||
or litellm.api_key
|
||||
or litellm.azure_key
|
||||
or get_secret("AZURE_OPENAI_API_KEY")
|
||||
or get_secret("AZURE_API_KEY")
|
||||
or get_secret_str("AZURE_OPENAI_API_KEY")
|
||||
or get_secret_str("AZURE_API_KEY")
|
||||
) # type: ignore
|
||||
|
||||
extra_body = optional_params.get("extra_body", {})
|
||||
azure_ad_token: Optional[str] = None
|
||||
if extra_body is not None:
|
||||
azure_ad_token = extra_body.pop("azure_ad_token", None)
|
||||
extra_body.pop("azure_ad_token", None)
|
||||
else:
|
||||
azure_ad_token = get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
get_secret("AZURE_AD_TOKEN") # type: ignore
|
||||
|
||||
response = azure_fine_tuning_apis_instance.list_fine_tuning_jobs(
|
||||
api_base=api_base,
|
||||
|
|
|
@ -23,6 +23,9 @@ import litellm.types
|
|||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
from litellm.caching import DualCache
|
||||
from litellm.integrations.custom_batch_logger import CustomBatchLogger
|
||||
from litellm.litellm_core_utils.exception_mapping_utils import (
|
||||
_add_key_name_and_team_to_alert,
|
||||
)
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
AsyncHTTPHandler,
|
||||
|
@ -219,7 +222,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
and "metadata" in kwargs["litellm_params"]
|
||||
):
|
||||
_metadata: dict = kwargs["litellm_params"]["metadata"]
|
||||
request_info = litellm.utils._add_key_name_and_team_to_alert(
|
||||
request_info = _add_key_name_and_team_to_alert(
|
||||
request_info=request_info, metadata=_metadata
|
||||
)
|
||||
|
||||
|
@ -281,7 +284,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
return_val += 1
|
||||
|
||||
return return_val
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
async def send_daily_reports(self, router) -> bool:
|
||||
|
@ -455,7 +458,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
try:
|
||||
messages = str(messages)
|
||||
messages = messages[:100]
|
||||
except:
|
||||
except Exception:
|
||||
messages = ""
|
||||
|
||||
if (
|
||||
|
@ -508,7 +511,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
_metadata: dict = request_data["metadata"]
|
||||
_api_base = _metadata.get("api_base", "")
|
||||
|
||||
request_info = litellm.utils._add_key_name_and_team_to_alert(
|
||||
request_info = _add_key_name_and_team_to_alert(
|
||||
request_info=request_info, metadata=_metadata
|
||||
)
|
||||
|
||||
|
@ -846,7 +849,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
|
||||
## MINOR OUTAGE ALERT SENT ##
|
||||
if (
|
||||
outage_value["minor_alert_sent"] == False
|
||||
outage_value["minor_alert_sent"] is False
|
||||
and len(outage_value["alerts"])
|
||||
>= self.alerting_args.minor_outage_alert_threshold
|
||||
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
|
||||
|
@ -871,7 +874,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
|
||||
## MAJOR OUTAGE ALERT SENT ##
|
||||
elif (
|
||||
outage_value["major_alert_sent"] == False
|
||||
outage_value["major_alert_sent"] is False
|
||||
and len(outage_value["alerts"])
|
||||
>= self.alerting_args.major_outage_alert_threshold
|
||||
and len(_deployment_set) > 1 # make sure it's not just 1 bad deployment
|
||||
|
@ -941,7 +944,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
if provider is None:
|
||||
try:
|
||||
model, provider, _, _ = litellm.get_llm_provider(model=model)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
provider = ""
|
||||
api_base = litellm.get_api_base(
|
||||
model=model, optional_params=deployment.litellm_params
|
||||
|
@ -976,7 +979,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
|
||||
## MINOR OUTAGE ALERT SENT ##
|
||||
if (
|
||||
outage_value["minor_alert_sent"] == False
|
||||
outage_value["minor_alert_sent"] is False
|
||||
and len(outage_value["alerts"])
|
||||
>= self.alerting_args.minor_outage_alert_threshold
|
||||
):
|
||||
|
@ -998,7 +1001,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
# set to true
|
||||
outage_value["minor_alert_sent"] = True
|
||||
elif (
|
||||
outage_value["major_alert_sent"] == False
|
||||
outage_value["major_alert_sent"] is False
|
||||
and len(outage_value["alerts"])
|
||||
>= self.alerting_args.major_outage_alert_threshold
|
||||
):
|
||||
|
@ -1024,7 +1027,7 @@ class SlackAlerting(CustomBatchLogger):
|
|||
await self.internal_usage_cache.async_set_cache(
|
||||
key=deployment_id, value=outage_value
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def model_added_alert(
|
||||
|
@ -1177,7 +1180,6 @@ Model Info:
|
|||
if user_row is not None:
|
||||
recipient_email = user_row.user_email
|
||||
|
||||
key_name = webhook_event.key_alias
|
||||
key_token = webhook_event.token
|
||||
key_budget = webhook_event.max_budget
|
||||
base_url = os.getenv("PROXY_BASE_URL", "http://0.0.0.0:4000")
|
||||
|
@ -1221,14 +1223,14 @@ Model Info:
|
|||
extra=webhook_event.model_dump(),
|
||||
)
|
||||
|
||||
payload = webhook_event.model_dump_json()
|
||||
webhook_event.model_dump_json()
|
||||
email_event = {
|
||||
"to": recipient_email,
|
||||
"subject": f"LiteLLM: {event_name}",
|
||||
"html": email_html_content,
|
||||
}
|
||||
|
||||
response = await send_email(
|
||||
await send_email(
|
||||
receiver_email=email_event["to"],
|
||||
subject=email_event["subject"],
|
||||
html=email_event["html"],
|
||||
|
@ -1292,14 +1294,14 @@ Model Info:
|
|||
The LiteLLM team <br />
|
||||
"""
|
||||
|
||||
payload = webhook_event.model_dump_json()
|
||||
webhook_event.model_dump_json()
|
||||
email_event = {
|
||||
"to": recipient_email,
|
||||
"subject": f"LiteLLM: {event_name}",
|
||||
"html": email_html_content,
|
||||
}
|
||||
|
||||
response = await send_email(
|
||||
await send_email(
|
||||
receiver_email=email_event["to"],
|
||||
subject=email_event["subject"],
|
||||
html=email_event["html"],
|
||||
|
@ -1446,7 +1448,6 @@ Model Info:
|
|||
response_s: timedelta = end_time - start_time
|
||||
|
||||
final_value = response_s
|
||||
total_tokens = 0
|
||||
|
||||
if isinstance(response_obj, litellm.ModelResponse) and (
|
||||
hasattr(response_obj, "usage")
|
||||
|
@ -1505,7 +1506,7 @@ Model Info:
|
|||
await self.region_outage_alerts(
|
||||
exception=kwargs["exception"], deployment_id=model_id
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def _run_scheduler_helper(self, llm_router) -> bool:
|
||||
|
|
|
@ -35,7 +35,7 @@ class LiteLLMBase(BaseModel):
|
|||
def json(self, **kwargs): # type: ignore
|
||||
try:
|
||||
return self.model_dump() # noqa
|
||||
except:
|
||||
except Exception:
|
||||
# if using pydantic v1
|
||||
return self.dict()
|
||||
|
||||
|
|
|
@ -1,8 +1,10 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to aispend.io
|
||||
import dotenv, os
|
||||
import traceback
|
||||
import datetime
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import dotenv
|
||||
|
||||
model_cost = {
|
||||
"gpt-3.5-turbo": {
|
||||
|
@ -118,8 +120,6 @@ class AISpendLogger:
|
|||
for model in model_cost:
|
||||
input_cost_sum += model_cost[model]["input_cost_per_token"]
|
||||
output_cost_sum += model_cost[model]["output_cost_per_token"]
|
||||
avg_input_cost = input_cost_sum / len(model_cost.keys())
|
||||
avg_output_cost = output_cost_sum / len(model_cost.keys())
|
||||
prompt_tokens_cost_usd_dollar = (
|
||||
model_cost[model]["input_cost_per_token"]
|
||||
* response_obj["usage"]["prompt_tokens"]
|
||||
|
@ -137,12 +137,6 @@ class AISpendLogger:
|
|||
f"AISpend Logging - Enters logging function for model {model}"
|
||||
)
|
||||
|
||||
url = f"https://aispend.io/api/v1/accounts/{self.account_id}/data"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {self.api_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
response_timestamp = datetime.datetime.fromtimestamp(
|
||||
int(response_obj["created"])
|
||||
).strftime("%Y-%m-%d")
|
||||
|
@ -168,6 +162,6 @@ class AISpendLogger:
|
|||
]
|
||||
|
||||
print_verbose(f"AISpend Logging - final data object: {data}")
|
||||
except:
|
||||
except Exception:
|
||||
print_verbose(f"AISpend Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@@ -23,7 +23,7 @@ def set_arize_ai_attributes(span: Span, kwargs, response_obj):
)

optional_params = kwargs.get("optional_params", {})
litellm_params = kwargs.get("litellm_params", {}) or {}
# litellm_params = kwargs.get("litellm_params", {}) or {}

#############################################
############ LLM CALL METADATA ##############

@@ -1,5 +1,6 @@
import datetime


class AthinaLogger:
def __init__(self):
import os

@@ -23,17 +24,20 @@ class AthinaLogger:
]

def log_event(self, kwargs, response_obj, start_time, end_time, print_verbose):
import requests # type: ignore
import json
import traceback

import requests # type: ignore

try:
is_stream = kwargs.get("stream", False)
if is_stream:
if "complete_streaming_response" in kwargs:
# Log the completion response in streaming mode
completion_response = kwargs["complete_streaming_response"]
response_json = completion_response.model_dump() if completion_response else {}
response_json = (
completion_response.model_dump() if completion_response else {}
)
else:
# Skip logging if the completion response is not available
return

@@ -52,8 +56,8 @@ class AthinaLogger:
}

if (
type(end_time) == datetime.datetime
and type(start_time) == datetime.datetime
type(end_time) is datetime.datetime
and type(start_time) is datetime.datetime
):
data["response_time"] = int(
(end_time - start_time).total_seconds() * 1000
@@ -1,10 +1,11 @@
#### What this does ####
# On success + failure, log events to aispend.io
import dotenv, os
import requests # type: ignore

import traceback
import datetime
import os
import traceback

import dotenv
import requests # type: ignore

model_cost = {
"gpt-3.5-turbo": {

@@ -92,91 +93,12 @@ class BerriSpendLogger:
self.account_id = os.getenv("BERRISPEND_ACCOUNT_ID")

def price_calculator(self, model, response_obj, start_time, end_time):
# try and find if the model is in the model_cost map
# else default to the average of the costs
prompt_tokens_cost_usd_dollar = 0
completion_tokens_cost_usd_dollar = 0
if model in model_cost:
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
)
completion_tokens_cost_usd_dollar = (
model_cost[model]["output_cost_per_token"]
* response_obj["usage"]["completion_tokens"]
)
elif "replicate" in model:
# replicate models are charged based on time
# llama 2 runs on an nvidia a100 which costs $0.0032 per second - https://replicate.com/replicate/llama-2-70b-chat
model_run_time = end_time - start_time # assuming time in seconds
cost_usd_dollar = model_run_time * 0.0032
prompt_tokens_cost_usd_dollar = cost_usd_dollar / 2
completion_tokens_cost_usd_dollar = cost_usd_dollar / 2
else:
# calculate average input cost
input_cost_sum = 0
output_cost_sum = 0
for model in model_cost:
input_cost_sum += model_cost[model]["input_cost_per_token"]
output_cost_sum += model_cost[model]["output_cost_per_token"]
avg_input_cost = input_cost_sum / len(model_cost.keys())
avg_output_cost = output_cost_sum / len(model_cost.keys())
prompt_tokens_cost_usd_dollar = (
model_cost[model]["input_cost_per_token"]
* response_obj["usage"]["prompt_tokens"]
)
completion_tokens_cost_usd_dollar = (
model_cost[model]["output_cost_per_token"]
* response_obj["usage"]["completion_tokens"]
)
return prompt_tokens_cost_usd_dollar, completion_tokens_cost_usd_dollar
return

def log_event(
self, model, messages, response_obj, start_time, end_time, print_verbose
):
# Method definition
try:
print_verbose(
f"BerriSpend Logging - Enters logging function for model {model}"
)

url = f"https://berrispend.berri.ai/spend"
headers = {"Content-Type": "application/json"}

(
prompt_tokens_cost_usd_dollar,
completion_tokens_cost_usd_dollar,
) = self.price_calculator(model, response_obj, start_time, end_time)
total_cost = (
prompt_tokens_cost_usd_dollar + completion_tokens_cost_usd_dollar
)

response_time = (end_time - start_time).total_seconds()
if "response" in response_obj:
data = [
{
"response_time": response_time,
"model_id": response_obj["model"],
"total_cost": total_cost,
"messages": messages,
"response": response_obj["choices"][0]["message"]["content"],
"account_id": self.account_id,
}
]
elif "error" in response_obj:
data = [
{
"response_time": response_time,
"model_id": response_obj["model"],
"total_cost": total_cost,
"messages": messages,
"error": response_obj["error"],
"account_id": self.account_id,
}
]

print_verbose(f"BerriSpend Logging - final data object: {data}")
response = requests.post(url, headers=headers, json=data)
except:
print_verbose(f"BerriSpend Logging Error - {traceback.format_exc()}")
pass
"""
This integration is not implemented yet.
"""
return
@@ -136,27 +136,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id

prompt = {"messages": kwargs.get("messages")}

output = None
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj["data"]

litellm_params = kwargs.get("litellm_params", {})

@@ -169,7 +165,7 @@ class BraintrustLogger(CustomLogger):
metadata = copy.deepcopy(
metadata
) # Avoid modifying the original metadata
except:
except Exception:
new_metadata = {}
for key, value in metadata.items():
if (

@@ -210,16 +206,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost

metrics: Optional[dict] = None
if (
response_obj is not None
and hasattr(response_obj, "usage")
and isinstance(response_obj.usage, litellm.Usage)
):
generation_id = litellm.utils.get_logging_id(start_time, response_obj)
usage_obj = getattr(response_obj, "usage", None)
if usage_obj and isinstance(usage_obj, litellm.Usage):
litellm.utils.get_logging_id(start_time, response_obj)
metrics = {
"prompt_tokens": response_obj.usage.prompt_tokens,
"completion_tokens": response_obj.usage.completion_tokens,
"total_tokens": response_obj.usage.total_tokens,
"prompt_tokens": usage_obj.prompt_tokens,
"completion_tokens": usage_obj.completion_tokens,
"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
}

@@ -255,27 +248,23 @@ class BraintrustLogger(CustomLogger):
project_id = self.default_project_id

prompt = {"messages": kwargs.get("messages")}

output = None
if response_obj is not None and (
kwargs.get("call_type", None) == "embedding"
or isinstance(response_obj, litellm.EmbeddingResponse)
):
input = prompt
output = None
elif response_obj is not None and isinstance(
response_obj, litellm.ModelResponse
):
input = prompt
output = response_obj["choices"][0]["message"].json()
elif response_obj is not None and isinstance(
response_obj, litellm.TextCompletionResponse
):
input = prompt
output = response_obj.choices[0].text
elif response_obj is not None and isinstance(
response_obj, litellm.ImageResponse
):
input = prompt
output = response_obj["data"]

litellm_params = kwargs.get("litellm_params", {})

@@ -331,16 +320,13 @@ class BraintrustLogger(CustomLogger):
clean_metadata["litellm_response_cost"] = cost

metrics: Optional[dict] = None
if (
response_obj is not None
and hasattr(response_obj, "usage")
and isinstance(response_obj.usage, litellm.Usage)
):
generation_id = litellm.utils.get_logging_id(start_time, response_obj)
usage_obj = getattr(response_obj, "usage", None)
if usage_obj and isinstance(usage_obj, litellm.Usage):
litellm.utils.get_logging_id(start_time, response_obj)
metrics = {
"prompt_tokens": response_obj.usage.prompt_tokens,
"completion_tokens": response_obj.usage.completion_tokens,
"total_tokens": response_obj.usage.total_tokens,
"prompt_tokens": usage_obj.prompt_tokens,
"completion_tokens": usage_obj.completion_tokens,
"total_tokens": usage_obj.total_tokens,
"total_cost": cost,
}
@@ -2,25 +2,24 @@

#### What this does ####
# On success, logs events to Promptlayer
import dotenv, os

from litellm.proxy._types import UserAPIKeyAuth
from litellm.caching import DualCache

from typing import Literal, Union
import datetime
import json
import os
import traceback
from typing import Literal, Optional, Union

import dotenv
import requests

import litellm
from litellm._logging import verbose_logger
from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.utils import StandardLoggingPayload

#### What this does ####
# On success + failure, log events to Supabase

import dotenv, os
import requests
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose, verbose_logger


def create_client():
try:

@@ -260,18 +259,12 @@ class ClickhouseLogger:
f"ClickhouseLogger Logging - Enters logging function for model {kwargs}"
)
# follows the same params as langfuse.py
from litellm.proxy.utils import get_logging_payload

payload = get_logging_payload(
kwargs=kwargs,
response_obj=response_obj,
start_time=start_time,
end_time=end_time,
payload: Optional[StandardLoggingPayload] = kwargs.get(
"standard_logging_object"
)
metadata = payload.get("metadata", "") or ""
request_tags = payload.get("request_tags", "") or ""
payload["metadata"] = str(metadata)
payload["request_tags"] = str(request_tags)
if payload is None:
return
# Build the initial payload

verbose_logger.debug(f"\nClickhouse Logger - Logging payload = {payload}")
@@ -12,7 +12,12 @@ from litellm.caching import DualCache
from litellm.proxy._types import UserAPIKeyAuth
from litellm.types.llms.openai import ChatCompletionRequest
from litellm.types.services import ServiceLoggerPayload
from litellm.types.utils import AdapterCompletionStreamWrapper, ModelResponse
from litellm.types.utils import (
AdapterCompletionStreamWrapper,
EmbeddingResponse,
ImageResponse,
ModelResponse,
)


class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callback#callback-class

@@ -140,8 +145,8 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
self,
data: dict,
user_api_key_dict: UserAPIKeyAuth,
response,
):
response: Union[Any, ModelResponse, EmbeddingResponse, ImageResponse],
) -> Any:
pass

async def async_logging_hook(

@@ -188,7 +193,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs,
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

async def async_log_input_event(

@@ -202,7 +207,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
kwargs,
)
print_verbose(f"Custom Logger - model call details: {kwargs}")
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")

def log_event(

@@ -217,7 +222,7 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time,
end_time,
)
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass

@@ -233,6 +238,6 @@ class CustomLogger: # https://docs.litellm.ai/docs/observability/custom_callbac
start_time,
end_time,
)
except:
except Exception:
print_verbose(f"Custom Logger Error - {traceback.format_exc()}")
pass
@@ -54,7 +54,7 @@ class DataDogLogger(CustomBatchLogger):
`DD_SITE` - your datadog site, example = `"us5.datadoghq.com"`
"""
try:
verbose_logger.debug(f"Datadog: in init datadog logger")
verbose_logger.debug("Datadog: in init datadog logger")
# check if the correct env variables are set
if os.getenv("DD_API_KEY", None) is None:
raise Exception("DD_API_KEY is not set, set 'DD_API_KEY=<>")

@@ -245,12 +245,12 @@ class DataDogLogger(CustomBatchLogger):
usage = dict(usage)
try:
response_time = (end_time - start_time).total_seconds() * 1000
except:
except Exception:
response_time = None

try:
response_obj = dict(response_obj)
except:
except Exception:
response_obj = response_obj

# Clean Metadata before logging - never log raw metadata

@@ -7,7 +7,7 @@ def make_json_serializable(payload):
elif not isinstance(value, (str, int, float, bool, type(None))):
# everything else becomes a string
payload[key] = str(value)
except:
except Exception:
# non blocking if it can't cast to a str
pass
return payload
@@ -1,12 +1,16 @@
#### What this does ####
# On success + failure, log events to Supabase

import dotenv, os
import requests # type: ignore
import datetime
import os
import traceback
import datetime, subprocess, sys
import litellm, uuid
from litellm._logging import print_verbose
import uuid
from typing import Any

import dotenv
import requests # type: ignore

import litellm


class DyanmoDBLogger:

@@ -16,7 +20,7 @@ class DyanmoDBLogger:
# Instance variables
import boto3

self.dynamodb = boto3.resource(
self.dynamodb: Any = boto3.resource(
"dynamodb", region_name=os.environ["AWS_REGION_NAME"]
)
if litellm.dynamodb_table_name is None:

@@ -67,7 +71,7 @@ class DyanmoDBLogger:
for key, value in payload.items():
try:
payload[key] = str(value)
except:
except Exception:
# non blocking if it can't cast to a str
pass

@@ -84,6 +88,6 @@ class DyanmoDBLogger:
f"DynamoDB Layer Logging - final response object: {response_obj}"
)
return response
except:
except Exception:
print_verbose(f"DynamoDB Layer Error - {traceback.format_exc()}")
pass
@@ -9,6 +9,7 @@ import litellm
from litellm._logging import verbose_logger
from litellm.integrations.custom_logger import CustomLogger
from litellm.llms.custom_httpx.http_handler import (
_get_httpx_client,
get_async_httpx_client,
httpxSpecialProvider,
)

@@ -41,7 +42,7 @@ class GalileoObserve(CustomLogger):
self.batch_size = 1
self.base_url = os.getenv("GALILEO_BASE_URL", None)
self.project_id = os.getenv("GALILEO_PROJECT_ID", None)
self.headers = None
self.headers: Optional[Dict[str, str]] = None
self.async_httpx_handler = get_async_httpx_client(
llm_provider=httpxSpecialProvider.LoggingCallback
)

@@ -54,7 +55,7 @@ class GalileoObserve(CustomLogger):
"accept": "application/json",
"Content-Type": "application/x-www-form-urlencoded",
}
galileo_login_response = self.async_httpx_handler.post(
galileo_login_response = litellm.module_level_client.post(
url=f"{self.base_url}/login",
headers=headers,
data={

@@ -94,13 +95,9 @@ class GalileoObserve(CustomLogger):
return output

async def async_log_success_event(
self,
kwargs,
start_time,
end_time,
response_obj,
self, kwargs: Any, response_obj: Any, start_time: Any, end_time: Any
):
verbose_logger.debug(f"On Async Success")
verbose_logger.debug("On Async Success")

_latency_ms = int((end_time - start_time).total_seconds() * 1000)
_call_type = kwargs.get("call_type", "litellm")

@@ -116,6 +113,7 @@ class GalileoObserve(CustomLogger):
response_obj=response_obj, kwargs=kwargs
)

if output_text is not None:
request_record = LLMResponse(
latency_ms=_latency_ms,
status_code=200,

@@ -159,4 +157,4 @@ class GalileoObserve(CustomLogger):
)

async def async_log_failure_event(self, kwargs, response_obj, start_time, end_time):
verbose_logger.debug(f"On Async Failure")
verbose_logger.debug("On Async Failure")
@@ -56,8 +56,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)

start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()

logging_payload: Optional[StandardLoggingPayload] = kwargs.get(

@@ -103,8 +103,8 @@ class GCSBucketLogger(GCSBucketBase):
response_obj,
)

start_time_str = start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time_str = end_time.strftime("%Y-%m-%d %H:%M:%S")
start_time.strftime("%Y-%m-%d %H:%M:%S")
end_time.strftime("%Y-%m-%d %H:%M:%S")
headers = await self.construct_request_headers()

logging_payload: Optional[StandardLoggingPayload] = kwargs.get(
@@ -1,8 +1,9 @@
import requests # type: ignore
import json
import traceback
from datetime import datetime, timezone

import requests # type: ignore


class GreenscaleLogger:
def __init__(self):

@@ -29,7 +30,7 @@ class GreenscaleLogger:
"%Y-%m-%dT%H:%M:%SZ"
)

if type(end_time) == datetime and type(start_time) == datetime:
if type(end_time) is datetime and type(start_time) is datetime:
data["invocationLatency"] = int(
(end_time - start_time).total_seconds() * 1000
)

@@ -50,6 +51,9 @@ class GreenscaleLogger:

data["tags"] = tags

if self.greenscale_logging_url is None:
raise Exception("Greenscale Logger Error - No logging URL found")

response = requests.post(
self.greenscale_logging_url,
headers=self.headers,
@ -1,15 +1,28 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Helicone
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
import litellm
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import dotenv
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm._logging import verbose_logger
|
||||
|
||||
|
||||
class HeliconeLogger:
|
||||
# Class variables or attributes
|
||||
helicone_model_list = ["gpt", "claude", "command-r", "command-r-plus", "command-light", "command-medium", "command-medium-beta", "command-xlarge-nightly", "command-nightly"]
|
||||
helicone_model_list = [
|
||||
"gpt",
|
||||
"claude",
|
||||
"command-r",
|
||||
"command-r-plus",
|
||||
"command-light",
|
||||
"command-medium",
|
||||
"command-medium-beta",
|
||||
"command-xlarge-nightly",
|
||||
"command-nightly",
|
||||
]
|
||||
|
||||
def __init__(self):
|
||||
# Instance variables
|
||||
|
@ -17,7 +30,7 @@ class HeliconeLogger:
|
|||
self.key = os.getenv("HELICONE_API_KEY")
|
||||
|
||||
def claude_mapping(self, model, messages, response_obj):
|
||||
from anthropic import HUMAN_PROMPT, AI_PROMPT
|
||||
from anthropic import AI_PROMPT, HUMAN_PROMPT
|
||||
|
||||
prompt = f"{HUMAN_PROMPT}"
|
||||
for message in messages:
|
||||
|
@ -29,7 +42,6 @@ class HeliconeLogger:
|
|||
else:
|
||||
prompt += f"{HUMAN_PROMPT}{message['content']}"
|
||||
prompt += f"{AI_PROMPT}"
|
||||
claude_provider_request = {"model": model, "prompt": prompt}
|
||||
|
||||
choice = response_obj["choices"][0]
|
||||
message = choice["message"]
|
||||
|
@ -37,12 +49,14 @@ class HeliconeLogger:
|
|||
content = []
|
||||
if "tool_calls" in message and message["tool_calls"]:
|
||||
for tool_call in message["tool_calls"]:
|
||||
content.append({
|
||||
content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": tool_call["id"],
|
||||
"name": tool_call["function"]["name"],
|
||||
"input": tool_call["function"]["arguments"]
|
||||
})
|
||||
"input": tool_call["function"]["arguments"],
|
||||
}
|
||||
)
|
||||
elif "content" in message and message["content"]:
|
||||
content = [{"type": "text", "text": message["content"]}]
|
||||
|
||||
|
@ -56,8 +70,8 @@ class HeliconeLogger:
|
|||
"stop_sequence": None,
|
||||
"usage": {
|
||||
"input_tokens": response_obj["usage"]["prompt_tokens"],
|
||||
"output_tokens": response_obj["usage"]["completion_tokens"]
|
||||
}
|
||||
"output_tokens": response_obj["usage"]["completion_tokens"],
|
||||
},
|
||||
}
|
||||
|
||||
return claude_response_obj
|
||||
|
@ -99,10 +113,8 @@ class HeliconeLogger:
|
|||
f"Helicone Logging - Enters logging function for model {model}"
|
||||
)
|
||||
litellm_params = kwargs.get("litellm_params", {})
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
metadata = (
|
||||
litellm_params.get("metadata", {}) or {}
|
||||
)
|
||||
kwargs.get("litellm_call_id", None)
|
||||
metadata = litellm_params.get("metadata", {}) or {}
|
||||
metadata = self.add_metadata_from_header(litellm_params, metadata)
|
||||
model = (
|
||||
model
|
||||
|
@ -175,6 +187,6 @@ class HeliconeLogger:
|
|||
f"Helicone Logging - Error Request was not successful. Status Code: {response.status_code}"
|
||||
)
|
||||
print_verbose(f"Helicone Logging - Error {response.text}")
|
||||
except:
|
||||
except Exception:
|
||||
print_verbose(f"Helicone Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -11,7 +11,7 @@ import dotenv
|
|||
import httpx
|
||||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm._logging import verbose_logger
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.llms.custom_httpx.http_handler import (
|
||||
HTTPHandler,
|
||||
|
@ -65,8 +65,8 @@ class LagoLogger(CustomLogger):
|
|||
raise Exception("Missing keys={} in environment.".format(missing_keys))
|
||||
|
||||
def _common_logic(self, kwargs: dict, response_obj) -> dict:
|
||||
call_id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
dt = get_utc_datetime().isoformat()
|
||||
response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
get_utc_datetime().isoformat()
|
||||
cost = kwargs.get("response_cost", None)
|
||||
model = kwargs.get("model")
|
||||
usage = {}
|
||||
|
@ -86,7 +86,7 @@ class LagoLogger(CustomLogger):
|
|||
end_user_id = proxy_server_request.get("body", {}).get("user", None)
|
||||
user_id = litellm_params["metadata"].get("user_api_key_user_id", None)
|
||||
team_id = litellm_params["metadata"].get("user_api_key_team_id", None)
|
||||
org_id = litellm_params["metadata"].get("user_api_key_org_id", None)
|
||||
litellm_params["metadata"].get("user_api_key_org_id", None)
|
||||
|
||||
charge_by: Literal["end_user_id", "team_id", "user_id"] = "end_user_id"
|
||||
external_customer_id: Optional[str] = None
|
||||
|
@ -158,8 +158,9 @@ class LagoLogger(CustomLogger):
|
|||
|
||||
response.raise_for_status()
|
||||
except Exception as e:
|
||||
if hasattr(response, "text"):
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
error_response = getattr(e, "response", None)
|
||||
if error_response is not None and hasattr(error_response, "text"):
|
||||
verbose_logger.debug(f"\nError Message: {error_response.text}")
|
||||
raise e
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
|
@ -199,5 +200,5 @@ class LagoLogger(CustomLogger):
|
|||
verbose_logger.debug(f"Logged Lago Object: {response.text}")
|
||||
except Exception as e:
|
||||
if response is not None and hasattr(response, "text"):
|
||||
litellm.print_verbose(f"\nError Message: {response.text}")
|
||||
verbose_logger.debug(f"\nError Message: {response.text}")
|
||||
raise e
|
||||
|
|
|
@ -67,7 +67,7 @@ class LangFuseLogger:
|
|||
try:
|
||||
project_id = self.Langfuse.client.projects.get().data[0].id
|
||||
os.environ["LANGFUSE_PROJECT_ID"] = project_id
|
||||
except:
|
||||
except Exception:
|
||||
project_id = None
|
||||
|
||||
if os.getenv("UPSTREAM_LANGFUSE_SECRET_KEY") is not None:
|
||||
|
@ -184,7 +184,7 @@ class LangFuseLogger:
|
|||
if not isinstance(value, (str, int, bool, float)):
|
||||
try:
|
||||
optional_params[param] = str(value)
|
||||
except:
|
||||
except Exception:
|
||||
# if casting value to str fails don't block logging
|
||||
pass
|
||||
|
||||
|
@ -275,7 +275,7 @@ class LangFuseLogger:
|
|||
print_verbose(
|
||||
f"Langfuse Layer Logging - final response object: {response_obj}"
|
||||
)
|
||||
verbose_logger.info(f"Langfuse Layer Logging - logging success")
|
||||
verbose_logger.info("Langfuse Layer Logging - logging success")
|
||||
|
||||
return {"trace_id": trace_id, "generation_id": generation_id}
|
||||
except Exception as e:
|
||||
|
@ -492,7 +492,7 @@ class LangFuseLogger:
|
|||
output if not mask_output else "redacted-by-litellm"
|
||||
)
|
||||
|
||||
if debug == True or (isinstance(debug, str) and debug.lower() == "true"):
|
||||
if debug is True or (isinstance(debug, str) and debug.lower() == "true"):
|
||||
if "metadata" in trace_params:
|
||||
# log the raw_metadata in the trace
|
||||
trace_params["metadata"]["metadata_passed_to_litellm"] = metadata
|
||||
|
@ -535,8 +535,8 @@ class LangFuseLogger:
|
|||
|
||||
proxy_server_request = litellm_params.get("proxy_server_request", None)
|
||||
if proxy_server_request:
|
||||
method = proxy_server_request.get("method", None)
|
||||
url = proxy_server_request.get("url", None)
|
||||
proxy_server_request.get("method", None)
|
||||
proxy_server_request.get("url", None)
|
||||
headers = proxy_server_request.get("headers", None)
|
||||
clean_headers = {}
|
||||
if headers:
|
||||
|
@ -625,7 +625,7 @@ class LangFuseLogger:
|
|||
generation_client = trace.generation(**generation_params)
|
||||
|
||||
return generation_client.trace_id, generation_id
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.error(f"Langfuse Layer Error - {traceback.format_exc()}")
|
||||
return None, None
|
||||
|
||||
|
|
|
@ -404,7 +404,7 @@ class LangsmithLogger(CustomBatchLogger):
|
|||
verbose_logger.exception(
|
||||
f"Langsmith HTTP Error: {e.response.status_code} - {e.response.text}"
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.exception(
|
||||
f"Langsmith Layer Error - {traceback.format_exc()}"
|
||||
)
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
import requests, traceback, json, os
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import types
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
class LiteDebugger:
|
||||
user_email = None
|
||||
|
@ -17,23 +21,17 @@ class LiteDebugger:
|
|||
email or os.getenv("LITELLM_TOKEN") or os.getenv("LITELLM_EMAIL")
|
||||
)
|
||||
if (
|
||||
self.user_email == None
|
||||
self.user_email is None
|
||||
): # if users are trying to use_client=True but token not set
|
||||
raise ValueError(
|
||||
"litellm.use_client = True but no token or email passed. Please set it in litellm.token"
|
||||
)
|
||||
self.dashboard_url = "https://admin.litellm.ai/" + self.user_email
|
||||
try:
|
||||
print(
|
||||
f"\033[92mHere's your LiteLLM Dashboard 👉 \033[94m\033[4m{self.dashboard_url}\033[0m"
|
||||
)
|
||||
except:
|
||||
print(f"Here's your LiteLLM Dashboard 👉 {self.dashboard_url}")
|
||||
if self.user_email == None:
|
||||
if self.user_email is None:
|
||||
raise ValueError(
|
||||
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ValueError(
|
||||
"[Non-Blocking Error] LiteLLMDebugger: Missing LITELLM_TOKEN. Set it in your environment. Eg.: os.environ['LITELLM_TOKEN']= <your_email>"
|
||||
)
|
||||
|
@ -49,123 +47,18 @@ class LiteDebugger:
|
|||
litellm_params,
|
||||
optional_params,
|
||||
):
|
||||
print_verbose(
|
||||
f"LiteDebugger: Pre-API Call Logging for call id {litellm_call_id}"
|
||||
)
|
||||
try:
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Logging - Enters input logging function for model {model}"
|
||||
)
|
||||
|
||||
def remove_key_value(dictionary, key):
|
||||
new_dict = dictionary.copy() # Create a copy of the original dictionary
|
||||
new_dict.pop(key) # Remove the specified key-value pair from the copy
|
||||
return new_dict
|
||||
|
||||
updated_litellm_params = remove_key_value(litellm_params, "logger_fn")
|
||||
|
||||
if call_type == "embedding":
|
||||
for (
|
||||
message
|
||||
) in (
|
||||
messages
|
||||
): # assuming the input is a list as required by the embedding function
|
||||
litellm_data_obj = {
|
||||
"model": model,
|
||||
"messages": [{"role": "user", "content": message}],
|
||||
"end_user": end_user,
|
||||
"status": "initiated",
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
"litellm_params": updated_litellm_params,
|
||||
"optional_params": optional_params,
|
||||
}
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
|
||||
)
|
||||
response = requests.post(
|
||||
url=self.api_url,
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
print_verbose(f"LiteDebugger: embedding api response - {response.text}")
|
||||
elif call_type == "completion":
|
||||
litellm_data_obj = {
|
||||
"model": model,
|
||||
"messages": messages
|
||||
if isinstance(messages, list)
|
||||
else [{"role": "user", "content": messages}],
|
||||
"end_user": end_user,
|
||||
"status": "initiated",
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
"litellm_params": updated_litellm_params,
|
||||
"optional_params": optional_params,
|
||||
}
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Logging - logged data obj {litellm_data_obj}"
|
||||
)
|
||||
response = requests.post(
|
||||
url=self.api_url,
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
print_verbose(
|
||||
f"LiteDebugger: completion api response - {response.text}"
|
||||
)
|
||||
except:
|
||||
print_verbose(
|
||||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
"""
|
||||
This integration is not implemented yet.
|
||||
"""
|
||||
return
|
||||
|
||||
def post_call_log_event(
|
||||
self, original_response, litellm_call_id, print_verbose, call_type, stream
|
||||
):
|
||||
print_verbose(
|
||||
f"LiteDebugger: Post-API Call Logging for call id {litellm_call_id}"
|
||||
)
|
||||
try:
|
||||
if call_type == "embedding":
|
||||
litellm_data_obj = {
|
||||
"status": "received",
|
||||
"additional_details": {
|
||||
"original_response": str(
|
||||
original_response["data"][0]["embedding"][:5]
|
||||
)
|
||||
}, # don't store the entire vector
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
}
|
||||
elif call_type == "completion" and not stream:
|
||||
litellm_data_obj = {
|
||||
"status": "received",
|
||||
"additional_details": {"original_response": original_response},
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
}
|
||||
elif call_type == "completion" and stream:
|
||||
litellm_data_obj = {
|
||||
"status": "received",
|
||||
"additional_details": {
|
||||
"original_response": "Streamed response"
|
||||
if isinstance(original_response, types.GeneratorType)
|
||||
else original_response
|
||||
},
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"user_email": self.user_email,
|
||||
}
|
||||
print_verbose(f"litedebugger post-call data object - {litellm_data_obj}")
|
||||
response = requests.post(
|
||||
url=self.api_url,
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
print_verbose(f"LiteDebugger: api response - {response.text}")
|
||||
except:
|
||||
print_verbose(
|
||||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
||||
)
|
||||
"""
|
||||
This integration is not implemented yet.
|
||||
"""
|
||||
return
|
||||
|
||||
def log_event(
|
||||
self,
|
||||
|
@ -178,85 +71,7 @@ class LiteDebugger:
|
|||
call_type,
|
||||
stream=False,
|
||||
):
|
||||
print_verbose(
|
||||
f"LiteDebugger: Success/Failure Call Logging for call id {litellm_call_id}"
|
||||
)
|
||||
try:
|
||||
print_verbose(
|
||||
f"LiteLLMDebugger: Success/Failure Logging - Enters handler logging function for function {call_type} and stream set to {stream} with response object {response_obj}"
|
||||
)
|
||||
total_cost = 0 # [TODO] implement cost tracking
|
||||
response_time = (end_time - start_time).total_seconds()
|
||||
if call_type == "completion" and stream == False:
|
||||
litellm_data_obj = {
|
||||
"response_time": response_time,
|
||||
"total_cost": total_cost,
|
||||
"response": response_obj["choices"][0]["message"]["content"],
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"status": "success",
|
||||
}
|
||||
print_verbose(
|
||||
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
|
||||
)
|
||||
response = requests.post(
|
||||
url=self.api_url,
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
elif call_type == "embedding":
|
||||
litellm_data_obj = {
|
||||
"response_time": response_time,
|
||||
"total_cost": total_cost,
|
||||
"response": str(response_obj["data"][0]["embedding"][:5]),
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"status": "success",
|
||||
}
|
||||
response = requests.post(
|
||||
url=self.api_url,
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
elif call_type == "completion" and stream == True:
|
||||
if len(response_obj["content"]) > 0: # don't log the empty strings
|
||||
litellm_data_obj = {
|
||||
"response_time": response_time,
|
||||
"total_cost": total_cost,
|
||||
"response": response_obj["content"],
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"status": "success",
|
||||
}
|
||||
print_verbose(
|
||||
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
|
||||
)
|
||||
response = requests.post(
|
||||
url=self.api_url,
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
elif "error" in response_obj:
|
||||
if "Unable to map your input to a model." in response_obj["error"]:
|
||||
total_cost = 0
|
||||
litellm_data_obj = {
|
||||
"response_time": response_time,
|
||||
"model": response_obj["model"],
|
||||
"total_cost": total_cost,
|
||||
"error": response_obj["error"],
|
||||
"end_user": end_user,
|
||||
"litellm_call_id": litellm_call_id,
|
||||
"status": "failure",
|
||||
"user_email": self.user_email,
|
||||
}
|
||||
print_verbose(
|
||||
f"LiteDebugger: Logging - final data object: {litellm_data_obj}"
|
||||
)
|
||||
response = requests.post(
|
||||
url=self.api_url,
|
||||
headers={"content-type": "application/json"},
|
||||
data=json.dumps(litellm_data_obj),
|
||||
)
|
||||
print_verbose(f"LiteDebugger: api response - {response.text}")
|
||||
except:
|
||||
print_verbose(
|
||||
f"[Non-Blocking Error] LiteDebugger: Logging Error - {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
"""
|
||||
This integration is not implemented yet.
|
||||
"""
|
||||
return
|
||||
|
|
|
@ -27,7 +27,7 @@ class LogfireLogger:
|
|||
# Class variables or attributes
|
||||
def __init__(self):
|
||||
try:
|
||||
verbose_logger.debug(f"in init logfire logger")
|
||||
verbose_logger.debug("in init logfire logger")
|
||||
import logfire
|
||||
|
||||
# only setting up logfire if we are sending to logfire
|
||||
|
@ -116,7 +116,7 @@ class LogfireLogger:
|
|||
id = response_obj.get("id", str(uuid.uuid4()))
|
||||
try:
|
||||
response_time = (end_time - start_time).total_seconds()
|
||||
except:
|
||||
except Exception:
|
||||
response_time = None
|
||||
|
||||
# Clean Metadata before logging - never log raw metadata
|
||||
|
|
|
@ -1,8 +1,8 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to lunary.ai
|
||||
from datetime import datetime, timezone
|
||||
import traceback
|
||||
import importlib
|
||||
import traceback
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import packaging
|
||||
|
||||
|
@ -74,9 +74,9 @@ class LunaryLogger:
|
|||
try:
|
||||
import lunary
|
||||
|
||||
version = importlib.metadata.version("lunary")
|
||||
version = importlib.metadata.version("lunary") # type: ignore
|
||||
# if version < 0.1.43 then raise ImportError
|
||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"):
|
||||
if packaging.version.Version(version) < packaging.version.Version("0.1.43"): # type: ignore
|
||||
print( # noqa
|
||||
"Lunary version outdated. Required: >= 0.1.43. Upgrade via 'pip install lunary --upgrade'"
|
||||
)
|
||||
|
@ -97,7 +97,7 @@ class LunaryLogger:
|
|||
run_id,
|
||||
model,
|
||||
print_verbose,
|
||||
extra=None,
|
||||
extra={},
|
||||
input=None,
|
||||
user_id=None,
|
||||
response_obj=None,
|
||||
|
@ -128,7 +128,7 @@ class LunaryLogger:
|
|||
if not isinstance(value, (str, int, bool, float)) and param != "tools":
|
||||
try:
|
||||
extra[param] = str(value)
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if response_obj:
|
||||
|
@ -175,6 +175,6 @@ class LunaryLogger:
|
|||
token_usage=usage,
|
||||
)
|
||||
|
||||
except:
|
||||
except Exception:
|
||||
print_verbose(f"Lunary Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -98,7 +98,7 @@ class OpenTelemetry(CustomLogger):
|
|||
import logging
|
||||
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
logger = logging.getLogger(__name__)
|
||||
logging.getLogger(__name__)
|
||||
|
||||
# Enable OpenTelemetry logging
|
||||
otel_exporter_logger = logging.getLogger("opentelemetry.sdk.trace.export")
|
||||
|
@ -520,7 +520,7 @@ class OpenTelemetry(CustomLogger):
|
|||
def set_raw_request_attributes(self, span: Span, kwargs, response_obj):
|
||||
from litellm.proxy._types import SpanAttributes
|
||||
|
||||
optional_params = kwargs.get("optional_params", {})
|
||||
kwargs.get("optional_params", {})
|
||||
litellm_params = kwargs.get("litellm_params", {}) or {}
|
||||
custom_llm_provider = litellm_params.get("custom_llm_provider", "Unknown")
|
||||
|
||||
|
@ -769,6 +769,6 @@ class OpenTelemetry(CustomLogger):
|
|||
management_endpoint_span.set_attribute(f"request.{key}", value)
|
||||
|
||||
_exception = logging_payload.exception
|
||||
management_endpoint_span.set_attribute(f"exception", str(_exception))
|
||||
management_endpoint_span.set_attribute("exception", str(_exception))
|
||||
management_endpoint_span.set_status(Status(StatusCode.ERROR))
|
||||
management_endpoint_span.end(end_time=_end_time_ns)
|
||||
|
|
|
@ -520,7 +520,7 @@ class PrometheusLogger(CustomLogger):
|
|||
user_api_team_alias = standard_logging_payload["metadata"][
|
||||
"user_api_key_team_alias"
|
||||
]
|
||||
exception = kwargs.get("exception", None)
|
||||
kwargs.get("exception", None)
|
||||
|
||||
try:
|
||||
self.litellm_llm_api_failed_requests_metric.labels(
|
||||
|
@ -679,7 +679,7 @@ class PrometheusLogger(CustomLogger):
|
|||
).inc()
|
||||
|
||||
pass
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def set_llm_deployment_success_metrics(
|
||||
|
@ -800,7 +800,7 @@ class PrometheusLogger(CustomLogger):
|
|||
|
||||
if (
|
||||
request_kwargs.get("stream", None) is not None
|
||||
and request_kwargs["stream"] == True
|
||||
and request_kwargs["stream"] is True
|
||||
):
|
||||
# only log ttft for streaming request
|
||||
time_to_first_token_response_time = (
|
||||
|
|
|
@ -3,11 +3,17 @@
|
|||
# On success + failure, log events to Prometheus for litellm / adjacent services (litellm, redis, postgres, llm api providers)
|
||||
|
||||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
import datetime
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
import litellm, uuid
|
||||
import uuid
|
||||
|
||||
import dotenv
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
from litellm._logging import print_verbose, verbose_logger
|
||||
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
|
||||
|
||||
|
@ -23,7 +29,7 @@ class PrometheusServicesLogger:
|
|||
):
|
||||
try:
|
||||
try:
|
||||
from prometheus_client import Counter, Histogram, REGISTRY
|
||||
from prometheus_client import REGISTRY, Counter, Histogram
|
||||
except ImportError:
|
||||
raise Exception(
|
||||
"Missing prometheus_client. Run `pip install prometheus-client`"
|
||||
|
@ -33,7 +39,7 @@ class PrometheusServicesLogger:
|
|||
self.Counter = Counter
|
||||
self.REGISTRY = REGISTRY
|
||||
|
||||
verbose_logger.debug(f"in init prometheus services metrics")
|
||||
verbose_logger.debug("in init prometheus services metrics")
|
||||
|
||||
self.services = [item.value for item in ServiceTypes]
|
||||
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
#### What this does ####
|
||||
# On success, logs events to Promptlayer
|
||||
import dotenv, os
|
||||
import os
|
||||
import traceback
|
||||
|
||||
import dotenv
|
||||
import requests # type: ignore
|
||||
from pydantic import BaseModel
|
||||
import traceback
|
||||
|
||||
|
||||
class PromptLayerLogger:
|
||||
|
@ -84,6 +86,6 @@ class PromptLayerLogger:
|
|||
f"Prompt Layer Logging: success - metadata post response object: {response.text}"
|
||||
)
|
||||
|
||||
except:
|
||||
except Exception:
|
||||
print_verbose(f"error: Prompt Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -1,10 +1,15 @@
|
|||
#### What this does ####
|
||||
# On success + failure, log events to Supabase
|
||||
|
||||
import dotenv, os
|
||||
import requests # type: ignore
|
||||
import datetime
|
||||
import os
|
||||
import subprocess
|
||||
import sys
|
||||
import traceback
|
||||
import datetime, subprocess, sys
|
||||
|
||||
import dotenv
|
||||
import requests # type: ignore
|
||||
|
||||
import litellm
|
||||
|
||||
|
||||
|
@ -21,7 +26,12 @@ class Supabase:
|
|||
except ImportError:
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "supabase"])
|
||||
import supabase
|
||||
self.supabase_client = supabase.create_client(
|
||||
|
||||
if self.supabase_url is None or self.supabase_key is None:
|
||||
raise ValueError(
|
||||
"LiteLLM Error, trying to use Supabase but url or key not passed. Create a table and set `litellm.supabase_url=<your-url>` and `litellm.supabase_key=<your-key>`"
|
||||
)
|
||||
self.supabase_client = supabase.create_client( # type: ignore
|
||||
self.supabase_url, self.supabase_key
|
||||
)
|
||||
|
||||
|
@ -45,7 +55,7 @@ class Supabase:
|
|||
.execute()
|
||||
)
|
||||
print_verbose(f"data: {data}")
|
||||
except:
|
||||
except Exception:
|
||||
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
||||
|
@ -109,6 +119,6 @@ class Supabase:
|
|||
.execute()
|
||||
)
|
||||
|
||||
except:
|
||||
except Exception:
|
||||
print_verbose(f"Supabase Logging Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -167,18 +167,17 @@ try:
|
|||
trace = self.results_to_trace_tree(request, response, results, time_elapsed)
|
||||
return trace
|
||||
|
||||
except:
|
||||
except Exception:
|
||||
imported_openAIResponse = False
|
||||
|
||||
|
||||
#### What this does ####
|
||||
# On success, logs events to Langfuse
|
||||
import os
|
||||
import requests
|
||||
import requests
|
||||
import traceback
|
||||
from datetime import datetime
|
||||
|
||||
import traceback
|
||||
import requests
|
||||
|
||||
|
||||
class WeightsBiasesLogger:
|
||||
|
@ -186,11 +185,11 @@ class WeightsBiasesLogger:
|
|||
def __init__(self):
|
||||
try:
|
||||
import wandb
|
||||
except:
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
|
||||
)
|
||||
if imported_openAIResponse == False:
|
||||
if imported_openAIResponse is False:
|
||||
raise Exception(
|
||||
"\033[91m wandb not installed, try running 'pip install wandb' to fix this error\033[0m"
|
||||
)
|
||||
|
@ -209,13 +208,14 @@ class WeightsBiasesLogger:
|
|||
kwargs, response_obj, (end_time - start_time).total_seconds()
|
||||
)
|
||||
|
||||
if trace is not None:
|
||||
if trace is not None and run is not None:
|
||||
run.log({"trace": trace})
|
||||
|
||||
if run is not None:
|
||||
run.finish()
|
||||
print_verbose(
|
||||
f"W&B Logging Logging - final response object: {response_obj}"
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
print_verbose(f"W&B Logging Layer Error - {traceback.format_exc()}")
|
||||
pass
|
||||
|
|
|
@ -62,7 +62,7 @@ def get_error_message(error_obj) -> Optional[str]:
|
|||
|
||||
# If all else fails, return None
|
||||
return None
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
|
@ -910,7 +910,7 @@ def exception_type( # type: ignore
|
|||
):
|
||||
exception_mapping_worked = True
|
||||
raise BadRequestError(
|
||||
message=f"SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
|
||||
message="SagemakerException - the value of 'n' must be > 0 and <= 2 for sagemaker endpoints",
|
||||
model=model,
|
||||
llm_provider="sagemaker",
|
||||
response=original_exception.response,
|
||||
|
@ -1122,7 +1122,7 @@ def exception_type( # type: ignore
|
|||
# 503 Getting metadata from plugin failed with error: Reauthentication is needed. Please run `gcloud auth application-default login` to reauthenticate.
|
||||
exception_mapping_worked = True
|
||||
raise BadRequestError(
|
||||
message=f"GeminiException - Invalid api key",
|
||||
message="GeminiException - Invalid api key",
|
||||
model=model,
|
||||
llm_provider="palm",
|
||||
response=original_exception.response,
|
||||
|
@ -2067,12 +2067,34 @@ def exception_logging(
|
|||
logger_fn(
|
||||
model_call_details
|
||||
) # Expectation: any logger function passed in by the user should accept a dict object
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.debug(
|
||||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.debug(
|
||||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while logging {traceback.format_exc()}"
|
||||
)
|
||||
pass
|
||||
|
||||
|
||||
def _add_key_name_and_team_to_alert(request_info: str, metadata: dict) -> str:
|
||||
"""
|
||||
Internal helper function for litellm proxy
|
||||
Add the Key Name + Team Name to the error
|
||||
Only gets added if the metadata contains the user_api_key_alias and user_api_key_team_alias
|
||||
|
||||
[Non-Blocking helper function]
|
||||
"""
|
||||
try:
|
||||
_api_key_name = metadata.get("user_api_key_alias", None)
|
||||
_user_api_key_team_alias = metadata.get("user_api_key_team_alias", None)
|
||||
if _api_key_name is not None:
|
||||
request_info = (
|
||||
f"\n\nKey Name: `{_api_key_name}`\nTeam: `{_user_api_key_team_alias}`"
|
||||
+ request_info
|
||||
)
|
||||
|
||||
return request_info
|
||||
except Exception:
|
||||
return request_info
|
||||
|
|
|
@ -476,7 +476,7 @@ def get_llm_provider(
|
|||
elif model == "*":
|
||||
custom_llm_provider = "openai"
|
||||
if custom_llm_provider is None or custom_llm_provider == "":
|
||||
if litellm.suppress_debug_info == False:
|
||||
if litellm.suppress_debug_info is False:
|
||||
print() # noqa
|
||||
print( # noqa
|
||||
"\033[1;31mProvider List: https://docs.litellm.ai/docs/providers\033[0m" # noqa
|
||||
|
|
|
@ -52,18 +52,8 @@ from litellm.types.utils import (
|
|||
)
|
||||
from litellm.utils import (
|
||||
_get_base_model_from_metadata,
|
||||
add_breadcrumb,
|
||||
capture_exception,
|
||||
customLogger,
|
||||
liteDebuggerClient,
|
||||
logfireLogger,
|
||||
lunaryLogger,
|
||||
print_verbose,
|
||||
prometheusLogger,
|
||||
prompt_token_calculator,
|
||||
promptLayerLogger,
|
||||
supabaseClient,
|
||||
weightsBiasesLogger,
|
||||
)
|
||||
|
||||
from ..integrations.aispend import AISpendLogger
|
||||
|
@ -71,7 +61,6 @@ from ..integrations.athina import AthinaLogger
|
|||
from ..integrations.berrispend import BerriSpendLogger
|
||||
from ..integrations.braintrust_logging import BraintrustLogger
|
||||
from ..integrations.clickhouse import ClickhouseLogger
|
||||
from ..integrations.custom_logger import CustomLogger
|
||||
from ..integrations.datadog.datadog import DataDogLogger
|
||||
from ..integrations.dynamodb import DyanmoDBLogger
|
||||
from ..integrations.galileo import GalileoObserve
|
||||
|
@ -423,7 +412,7 @@ class Logging:
|
|||
elif callback == "sentry" and add_breadcrumb:
|
||||
try:
|
||||
details_to_log = copy.deepcopy(self.model_call_details)
|
||||
except:
|
||||
except Exception:
|
||||
details_to_log = self.model_call_details
|
||||
if litellm.turn_off_message_logging:
|
||||
# make a copy of the _model_Call_details and log it
|
||||
|
@ -528,7 +517,7 @@ class Logging:
|
|||
verbose_logger.debug("reaches sentry breadcrumbing")
|
||||
try:
|
||||
details_to_log = copy.deepcopy(self.model_call_details)
|
||||
except:
|
||||
except Exception:
|
||||
details_to_log = self.model_call_details
|
||||
if litellm.turn_off_message_logging:
|
||||
# make a copy of the _model_Call_details and log it
|
||||
|
@ -1326,7 +1315,7 @@ class Logging:
|
|||
and customLogger is not None
|
||||
): # custom logger functions
|
||||
print_verbose(
|
||||
f"success callbacks: Running Custom Callback Function"
|
||||
"success callbacks: Running Custom Callback Function"
|
||||
)
|
||||
customLogger.log_event(
|
||||
kwargs=self.model_call_details,
|
||||
|
@ -1400,7 +1389,7 @@ class Logging:
|
|||
self.model_call_details["response_cost"] = 0.0
|
||||
else:
|
||||
# check if base_model set on azure
|
||||
base_model = _get_base_model_from_metadata(
|
||||
_get_base_model_from_metadata(
|
||||
model_call_details=self.model_call_details
|
||||
)
|
||||
# base_model defaults to None if not set on model_info
|
||||
|
@ -1483,7 +1472,7 @@ class Logging:
|
|||
for callback in callbacks:
|
||||
# check if callback can run for this request
|
||||
litellm_params = self.model_call_details.get("litellm_params", {})
|
||||
if litellm_params.get("no-log", False) == True:
|
||||
if litellm_params.get("no-log", False) is True:
|
||||
# proxy cost tracking cal backs should run
|
||||
if not (
|
||||
isinstance(callback, CustomLogger)
|
||||
|
@ -1492,7 +1481,7 @@ class Logging:
|
|||
print_verbose("no-log request, skipping logging")
|
||||
continue
|
||||
try:
|
||||
if kwargs.get("no-log", False) == True:
|
||||
if kwargs.get("no-log", False) is True:
|
||||
print_verbose("no-log request, skipping logging")
|
||||
continue
|
||||
if (
|
||||
|
@ -1641,7 +1630,7 @@ class Logging:
|
|||
end_time=end_time,
|
||||
print_verbose=print_verbose,
|
||||
)
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
verbose_logger.error(
|
||||
f"LiteLLM.LoggingError: [Non-Blocking] Exception occurred while success logging {traceback.format_exc()}"
|
||||
)
|
||||
|
@ -2433,7 +2422,7 @@ def get_standard_logging_object_payload(
|
|||
call_type = kwargs.get("call_type")
|
||||
cache_hit = kwargs.get("cache_hit", False)
|
||||
usage = response_obj.get("usage", None) or {}
|
||||
if type(usage) == litellm.Usage:
|
||||
if type(usage) is litellm.Usage:
|
||||
usage = dict(usage)
|
||||
id = response_obj.get("id", kwargs.get("litellm_call_id"))
|
||||
|
||||
|
@ -2656,3 +2645,11 @@ def scrub_sensitive_keys_in_metadata(litellm_params: Optional[dict]):
|
|||
litellm_params["metadata"] = metadata
|
||||
|
||||
return litellm_params
|
||||
|
||||
|
||||
# integration helper function
|
||||
def modify_integration(integration_name, integration_params):
|
||||
global supabaseClient
|
||||
if integration_name == "supabase":
|
||||
if "table_name" in integration_params:
|
||||
Supabase.supabase_table_name = integration_params["table_name"]
|
||||
|
|
|
@ -45,7 +45,7 @@ def pick_cheapest_chat_model_from_llm_provider(custom_llm_provider: str):
|
|||
model_info = litellm.get_model_info(
|
||||
model=model, custom_llm_provider=custom_llm_provider
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
continue
|
||||
if model_info.get("mode") != "chat":
|
||||
continue
|
||||
|
|
|
@ -123,7 +123,7 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -165,7 +165,7 @@ def completion(
|
|||
)
|
||||
if response.status_code != 200:
|
||||
raise AI21Error(status_code=response.status_code, message=response.text)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -191,7 +191,7 @@ def completion(
|
|||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response.choices = choices_list # type: ignore
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise AI21Error(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
|
|
@@ -151,7 +151,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
aget_assistants=None,
):
if aget_assistants is not None and aget_assistants == True:
if aget_assistants is not None and aget_assistants is True:
return self.async_get_assistants(
api_key=api_key,
api_base=api_base,

@@ -260,7 +260,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
a_add_message: Optional[bool] = None,
):
if a_add_message is not None and a_add_message == True:
if a_add_message is not None and a_add_message is True:
return self.a_add_message(
thread_id=thread_id,
message_data=message_data,

@@ -365,7 +365,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
aget_messages=None,
):
if aget_messages is not None and aget_messages == True:
if aget_messages is not None and aget_messages is True:
return self.async_get_messages(
thread_id=thread_id,
api_key=api_key,

@@ -483,7 +483,7 @@ class AzureAssistantsAPI(BaseLLM):
openai_api.create_thread(messages=[message])
```
"""
if acreate_thread is not None and acreate_thread == True:
if acreate_thread is not None and acreate_thread is True:
return self.async_create_thread(
metadata=metadata,
api_key=api_key,

@@ -586,7 +586,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
aget_thread=None,
):
if aget_thread is not None and aget_thread == True:
if aget_thread is not None and aget_thread is True:
return self.async_get_thread(
thread_id=thread_id,
api_key=api_key,

@@ -774,8 +774,8 @@ class AzureAssistantsAPI(BaseLLM):
arun_thread=None,
event_handler: Optional[AssistantEventHandler] = None,
):
if arun_thread is not None and arun_thread == True:
if stream is not None and stream == True:
if arun_thread is not None and arun_thread is True:
if stream is not None and stream is True:
azure_client = self.async_get_azure_client(
api_key=api_key,
api_base=api_base,

@@ -823,7 +823,7 @@ class AzureAssistantsAPI(BaseLLM):
client=client,
)

if stream is not None and stream == True:
if stream is not None and stream is True:
return self.run_thread_stream(
client=openai_client,
thread_id=thread_id,

@@ -887,7 +887,7 @@ class AzureAssistantsAPI(BaseLLM):
client=None,
async_create_assistants=None,
):
if async_create_assistants is not None and async_create_assistants == True:
if async_create_assistants is not None and async_create_assistants is True:
return self.async_create_assistants(
api_key=api_key,
api_base=api_base,

@@ -950,7 +950,7 @@ class AzureAssistantsAPI(BaseLLM):
async_delete_assistants: Optional[bool] = None,
client=None,
):
if async_delete_assistants is not None and async_delete_assistants == True:
if async_delete_assistants is not None and async_delete_assistants is True:
return self.async_delete_assistant(
api_key=api_key,
api_base=api_base,
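The assistants hunks above all rewrite boolean flag checks from == True to is True (ruff E712). A short sketch of the pattern, using a hypothetical optional_params dict:

optional_params = {"stream": True}

# Flagged by E712: equality comparison against True.
if "stream" in optional_params and optional_params["stream"] == True:
    pass

# Preferred: identity check; only the singleton True passes, not 1 or other truthy values.
if "stream" in optional_params and optional_params["stream"] is True:
    print("streaming enabled")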
|
|
@@ -317,7 +317,7 @@ class AzureOpenAIAssistantsAPIConfig:
if "file_id" in item:
file_ids.append(item["file_id"])
else:
if litellm.drop_params == True:
if litellm.drop_params is True:
pass
else:
raise litellm.utils.UnsupportedParamsError(

@@ -580,7 +580,7 @@ class AzureChatCompletion(BaseLLM):
try:
if model is None or messages is None:
raise AzureOpenAIError(
status_code=422, message=f"Missing model or messages"
status_code=422, message="Missing model or messages"
)

max_retries = optional_params.pop("max_retries", 2)

@@ -1240,12 +1240,6 @@ class AzureChatCompletion(BaseLLM):
)
while response.json()["status"] not in ["succeeded", "failed"]:
if time.time() - start_time > timeout_secs:
timeout_msg = {
"error": {
"code": "Timeout",
"message": "Operation polling timed out.",
}
}

raise AzureOpenAIError(
status_code=408, message="Operation polling timed out."

@@ -1493,7 +1487,6 @@ class AzureChatCompletion(BaseLLM):
client=None,
aimg_generation=None,
):
exception_mapping_worked = False
try:
if model and len(model) > 0:
model = model

@@ -1534,7 +1527,7 @@ class AzureChatCompletion(BaseLLM):
azure_ad_token = get_azure_ad_token_from_oidc(azure_ad_token)
azure_client_params["azure_ad_token"] = azure_ad_token

if aimg_generation == True:
if aimg_generation is True:
response = self.aimage_generation(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_key=api_key, client=client, azure_client_params=azure_client_params, timeout=timeout) # type: ignore
return response
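This file also drops the f-prefix from a string with no placeholders, ruff's F541 rule. A tiny sketch:

status_code = 422

message = f"Missing model or messages"  # flagged by F541: no placeholders, the f-prefix does nothing
message = "Missing model or messages"   # preferred

detail = f"status={status_code}"        # fine: the placeholder justifies the f-string
print(message, detail)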
|
|
@ -1263,7 +1263,6 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
client=None,
|
||||
aimg_generation=None,
|
||||
):
|
||||
exception_mapping_worked = False
|
||||
data = {}
|
||||
try:
|
||||
model = model
|
||||
|
@ -1272,7 +1271,7 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
if not isinstance(max_retries, int):
|
||||
raise OpenAIError(status_code=422, message="max retries must be an int")
|
||||
|
||||
if aimg_generation == True:
|
||||
if aimg_generation is True:
|
||||
response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries) # type: ignore
|
||||
return response
|
||||
|
||||
|
@ -1311,7 +1310,6 @@ class OpenAIChatCompletion(BaseLLM):
|
|||
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
|
||||
except OpenAIError as e:
|
||||
|
||||
exception_mapping_worked = True
|
||||
## LOGGING
|
||||
logging_obj.post_call(
|
||||
input=prompt,
|
||||
|
@ -1543,7 +1541,7 @@ class OpenAITextCompletion(BaseLLM):
|
|||
if (
|
||||
len(messages) > 0
|
||||
and "content" in messages[0]
|
||||
and type(messages[0]["content"]) == list
|
||||
and isinstance(messages[0]["content"], list)
|
||||
):
|
||||
prompt = messages[0]["content"]
|
||||
else:
|
||||
|
@ -2413,7 +2411,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=None,
|
||||
aget_assistants=None,
|
||||
):
|
||||
if aget_assistants is not None and aget_assistants == True:
|
||||
if aget_assistants is not None and aget_assistants is True:
|
||||
return self.async_get_assistants(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -2470,7 +2468,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=None,
|
||||
async_create_assistants=None,
|
||||
):
|
||||
if async_create_assistants is not None and async_create_assistants == True:
|
||||
if async_create_assistants is not None and async_create_assistants is True:
|
||||
return self.async_create_assistants(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -2527,7 +2525,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=None,
|
||||
async_delete_assistants=None,
|
||||
):
|
||||
if async_delete_assistants is not None and async_delete_assistants == True:
|
||||
if async_delete_assistants is not None and async_delete_assistants is True:
|
||||
return self.async_delete_assistant(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -2629,7 +2627,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=None,
|
||||
a_add_message: Optional[bool] = None,
|
||||
):
|
||||
if a_add_message is not None and a_add_message == True:
|
||||
if a_add_message is not None and a_add_message is True:
|
||||
return self.a_add_message(
|
||||
thread_id=thread_id,
|
||||
message_data=message_data,
|
||||
|
@ -2727,7 +2725,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=None,
|
||||
aget_messages=None,
|
||||
):
|
||||
if aget_messages is not None and aget_messages == True:
|
||||
if aget_messages is not None and aget_messages is True:
|
||||
return self.async_get_messages(
|
||||
thread_id=thread_id,
|
||||
api_key=api_key,
|
||||
|
@ -2838,7 +2836,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
openai_api.create_thread(messages=[message])
|
||||
```
|
||||
"""
|
||||
if acreate_thread is not None and acreate_thread == True:
|
||||
if acreate_thread is not None and acreate_thread is True:
|
||||
return self.async_create_thread(
|
||||
metadata=metadata,
|
||||
api_key=api_key,
|
||||
|
@ -2934,7 +2932,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=None,
|
||||
aget_thread=None,
|
||||
):
|
||||
if aget_thread is not None and aget_thread == True:
|
||||
if aget_thread is not None and aget_thread is True:
|
||||
return self.async_get_thread(
|
||||
thread_id=thread_id,
|
||||
api_key=api_key,
|
||||
|
@ -3117,8 +3115,8 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
arun_thread=None,
|
||||
event_handler: Optional[AssistantEventHandler] = None,
|
||||
):
|
||||
if arun_thread is not None and arun_thread == True:
|
||||
if stream is not None and stream == True:
|
||||
if arun_thread is not None and arun_thread is True:
|
||||
if stream is not None and stream is True:
|
||||
_client = self.async_get_openai_client(
|
||||
api_key=api_key,
|
||||
api_base=api_base,
|
||||
|
@ -3163,7 +3161,7 @@ class OpenAIAssistantsAPI(BaseLLM):
|
|||
client=client,
|
||||
)
|
||||
|
||||
if stream is not None and stream == True:
|
||||
if stream is not None and stream is True:
|
||||
return self.run_thread_stream(
|
||||
client=openai_client,
|
||||
thread_id=thread_id,
|
||||
|
|
|
@@ -191,7 +191,7 @@ def completion(
encoding,
api_key,
logging_obj,
optional_params=None,
optional_params: dict,
litellm_params=None,
logger_fn=None,
default_max_tokens_to_sample=None,

@@ -246,7 +246,7 @@ def completion(
data=json.dumps(data),
stream=optional_params["stream"] if "stream" in optional_params else False,
)
if "stream" in optional_params and optional_params["stream"] == True:
if "stream" in optional_params and optional_params["stream"] is True:
return response.iter_lines()
else:
## LOGGING

@@ -279,7 +279,7 @@ def completion(
)
choices_list.append(choice_obj)
model_response.choices = choices_list # type: ignore
except:
except Exception:
raise AlephAlphaError(
message=json.dumps(completion_response),
status_code=response.status_code,
|
|
@ -607,7 +607,6 @@ class ModelResponseIterator:
|
|||
def _handle_usage(
|
||||
self, anthropic_usage_chunk: Union[dict, UsageDelta]
|
||||
) -> AnthropicChatCompletionUsageBlock:
|
||||
special_fields = ["input_tokens", "output_tokens"]
|
||||
|
||||
usage_block = AnthropicChatCompletionUsageBlock(
|
||||
prompt_tokens=anthropic_usage_chunk.get("input_tokens", 0),
|
||||
|
@ -683,7 +682,7 @@ class ModelResponseIterator:
|
|||
"index": self.tool_index,
|
||||
}
|
||||
elif type_chunk == "content_block_stop":
|
||||
content_block_stop = ContentBlockStop(**chunk) # type: ignore
|
||||
ContentBlockStop(**chunk) # type: ignore
|
||||
# check if tool call content block
|
||||
is_empty = self.check_empty_tool_call_args()
|
||||
if is_empty:
|
||||
|
|
|
@ -114,7 +114,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
except Exception:
|
||||
raise AnthropicError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
|
@ -229,7 +229,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
headers={},
|
||||
|
@ -276,8 +276,8 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if acompletion == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
if acompletion is True:
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
api_base=api_base,
|
||||
|
@ -309,7 +309,7 @@ class AnthropicTextCompletion(BaseLLM):
|
|||
logging_obj=logging_obj,
|
||||
)
|
||||
return stream_response
|
||||
elif acompletion == True:
|
||||
elif acompletion is True:
|
||||
return self.async_completion(
|
||||
model=model,
|
||||
model_response=model_response,
|
||||
|
|
|
@ -233,7 +233,7 @@ class AzureTextCompletion(BaseLLM):
|
|||
client=client,
|
||||
logging_obj=logging_obj,
|
||||
)
|
||||
elif "stream" in optional_params and optional_params["stream"] == True:
|
||||
elif "stream" in optional_params and optional_params["stream"] is True:
|
||||
return self.streaming(
|
||||
logging_obj=logging_obj,
|
||||
api_base=api_base,
|
||||
|
|
|
@ -36,7 +36,7 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -59,7 +59,7 @@ def completion(
|
|||
"parameters": optional_params,
|
||||
"stream": (
|
||||
True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
if "stream" in optional_params and optional_params["stream"] is True
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
@ -77,12 +77,12 @@ def completion(
|
|||
data=json.dumps(data),
|
||||
stream=(
|
||||
True
|
||||
if "stream" in optional_params and optional_params["stream"] == True
|
||||
if "stream" in optional_params and optional_params["stream"] is True
|
||||
else False
|
||||
),
|
||||
)
|
||||
if "text/event-stream" in response.headers["Content-Type"] or (
|
||||
"stream" in optional_params and optional_params["stream"] == True
|
||||
"stream" in optional_params and optional_params["stream"] is True
|
||||
):
|
||||
return response.iter_lines()
|
||||
else:
|
||||
|
|
|
@@ -183,7 +183,7 @@ class BedrockConverseLLM(BaseAWSLLM):
except httpx.HTTPStatusError as err:
error_code = err.response.status_code
raise BedrockError(status_code=error_code, message=err.response.text)
except httpx.TimeoutException as e:
except httpx.TimeoutException:
raise BedrockError(status_code=408, message="Timeout error occurred.")

return litellm.AmazonConverseConfig()._transform_response(
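The Bedrock hunk above removes an exception binding that was never used. A minimal sketch of the same cleanup, assuming httpx is available; the helper name is illustrative.

import httpx

def fetch_status(url: str) -> int:
    try:
        return httpx.get(url, timeout=1.0).status_code
    except httpx.TimeoutException:  # previously `except httpx.TimeoutException as e:` with `e` unused
        return 408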
|
|
|
@ -251,9 +251,7 @@ class AmazonConverseConfig:
|
|||
supported_converse_params = AmazonConverseConfig.__annotations__.keys()
|
||||
supported_tool_call_params = ["tools", "tool_choice"]
|
||||
supported_guardrail_params = ["guardrailConfig"]
|
||||
json_mode: Optional[bool] = inference_params.pop(
|
||||
"json_mode", None
|
||||
) # used for handling json_schema
|
||||
inference_params.pop("json_mode", None) # used for handling json_schema
|
||||
## TRANSFORMATION ##
|
||||
|
||||
bedrock_messages: List[MessageBlock] = _bedrock_converse_messages_pt(
|
||||
|
|
|
@ -234,7 +234,7 @@ async def make_call(
|
|||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
except Exception as e:
|
||||
raise BedrockError(status_code=500, message=str(e))
|
||||
|
@ -335,7 +335,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
except Exception:
|
||||
raise BedrockError(message=response.text, status_code=422)
|
||||
|
||||
outputText: Optional[str] = None
|
||||
|
@ -394,12 +394,12 @@ class BedrockLLM(BaseAWSLLM):
|
|||
outputText # allow user to access raw anthropic tool calling response
|
||||
)
|
||||
if (
|
||||
_is_function_call == True
|
||||
_is_function_call is True
|
||||
and stream is not None
|
||||
and stream == True
|
||||
and stream is True
|
||||
):
|
||||
print_verbose(
|
||||
f"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
|
||||
"INSIDE BEDROCK STREAMING TOOL CALLING CONDITION BLOCK"
|
||||
)
|
||||
# return an iterator
|
||||
streaming_model_response = ModelResponse(stream=True)
|
||||
|
@ -440,7 +440,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
model_response=streaming_model_response
|
||||
)
|
||||
print_verbose(
|
||||
f"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
"Returns anthropic CustomStreamWrapper with 'cached_response' streaming object"
|
||||
)
|
||||
return litellm.CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
|
@ -597,7 +597,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
## SETUP ##
|
||||
|
@ -700,7 +700,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
k not in inference_params
|
||||
): # completion(top_k=3) > anthropic_config(top_k=3) <- allows for dynamic variables to be passed in
|
||||
inference_params[k] = v
|
||||
if stream == True:
|
||||
if stream is True:
|
||||
inference_params["stream"] = (
|
||||
True # cohere requires stream = True in inference params
|
||||
)
|
||||
|
@ -845,7 +845,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
if acompletion:
|
||||
if isinstance(client, HTTPHandler):
|
||||
client = None
|
||||
if stream == True and provider != "ai21":
|
||||
if stream is True and provider != "ai21":
|
||||
return self.async_streaming(
|
||||
model=model,
|
||||
messages=messages,
|
||||
|
@ -891,7 +891,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
self.client = _get_httpx_client(_params) # type: ignore
|
||||
else:
|
||||
self.client = client
|
||||
if (stream is not None and stream == True) and provider != "ai21":
|
||||
if (stream is not None and stream is True) and provider != "ai21":
|
||||
response = self.client.post(
|
||||
url=proxy_endpoint_url,
|
||||
headers=prepped.headers, # type: ignore
|
||||
|
@ -929,7 +929,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
|
@ -980,7 +980,7 @@ class BedrockLLM(BaseAWSLLM):
|
|||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise BedrockError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise BedrockError(status_code=408, message="Timeout error occurred.")
|
||||
|
||||
return self.process_response(
|
||||
|
|
|
@ -260,7 +260,7 @@ class AmazonAnthropicConfig:
|
|||
optional_params["top_p"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop_sequences"] = value
|
||||
if param == "stream" and value == True:
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
return optional_params
|
||||
|
||||
|
|
|
@ -39,7 +39,7 @@ class BedrockEmbedding(BaseAWSLLM):
|
|||
) -> Tuple[Any, str]:
|
||||
try:
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
|
|
|
@ -130,7 +130,7 @@ def process_response(
|
|||
choices_list.append(choice_obj)
|
||||
model_response.choices = choices_list # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ClarifaiError(
|
||||
message=traceback.format_exc(), status_code=response.status_code, url=model
|
||||
)
|
||||
|
@ -219,7 +219,7 @@ async def async_completion(
|
|||
choices_list.append(choice_obj)
|
||||
model_response.choices = choices_list # type: ignore
|
||||
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ClarifaiError(
|
||||
message=traceback.format_exc(), status_code=response.status_code, url=model
|
||||
)
|
||||
|
@ -251,9 +251,9 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict={},
|
||||
acompletion=False,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -268,20 +268,12 @@ def completion(
|
|||
optional_params[k] = v
|
||||
|
||||
custom_llm_provider, orig_model_name = get_prompt_model_name(model)
|
||||
if custom_llm_provider == "anthropic":
|
||||
prompt = prompt_factory(
|
||||
prompt: str = prompt_factory( # type: ignore
|
||||
model=orig_model_name,
|
||||
messages=messages,
|
||||
api_key=api_key,
|
||||
custom_llm_provider="clarifai",
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=orig_model_name,
|
||||
messages=messages,
|
||||
api_key=api_key,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
# print(prompt); exit(0)
|
||||
|
||||
data = {
|
||||
|
@ -300,7 +292,7 @@ def completion(
|
|||
"api_base": model,
|
||||
},
|
||||
)
|
||||
if acompletion == True:
|
||||
if acompletion is True:
|
||||
return async_completion(
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
|
@ -331,7 +323,7 @@ def completion(
|
|||
status_code=response.status_code, message=response.text, url=model
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
completion_stream = response.iter_lines()
|
||||
stream_response = CustomStreamWrapper(
|
||||
completion_stream=completion_stream,
|
||||
|
|
|
@ -80,8 +80,8 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -97,7 +97,7 @@ def completion(
|
|||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
custom_prompt(
|
||||
role_dict=model_prompt_details.get("roles", {}),
|
||||
initial_prompt_value=model_prompt_details.get("initial_prompt_value", ""),
|
||||
final_prompt_value=model_prompt_details.get("final_prompt_value", ""),
|
||||
|
@ -126,7 +126,7 @@ def completion(
|
|||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
response = requests.post(
|
||||
api_base,
|
||||
headers=headers,
|
||||
|
|
|
@ -268,7 +268,7 @@ def completion(
|
|||
if response.status_code != 200:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -283,12 +283,12 @@ def completion(
|
|||
completion_response = response.json()
|
||||
try:
|
||||
model_response.choices[0].message.content = completion_response["text"] # type: ignore
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
## Tool calling response
|
||||
cohere_tools_response = completion_response.get("tool_calls", None)
|
||||
if cohere_tools_response is not None and cohere_tools_response is not []:
|
||||
if cohere_tools_response is not None and cohere_tools_response != []:
|
||||
# convert cohere_tools_response to OpenAI response format
|
||||
tool_calls = []
|
||||
for tool in cohere_tools_response:
|
||||
|
|
|
@ -146,7 +146,7 @@ def completion(
|
|||
api_key,
|
||||
logging_obj,
|
||||
headers: dict,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -198,7 +198,7 @@ def completion(
|
|||
if response.status_code != 200:
|
||||
raise CohereError(message=response.text, status_code=response.status_code)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -231,7 +231,7 @@ def completion(
|
|||
)
|
||||
choices_list.append(choice_obj)
|
||||
model_response.choices = choices_list # type: ignore
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise CohereError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
|
|
|
@@ -17,7 +17,7 @@ else:

try:
from litellm._version import version
except:
except Exception:
version = "0.0.0"

headers = {

@@ -1,15 +1,17 @@
from typing import Optional

import httpx

try:
from litellm._version import version
except:
except Exception:
version = "0.0.0"

headers = {
"User-Agent": f"litellm/{version}",
}


class HTTPHandler:
def __init__(self, concurrent_limit=1000):
# Create a client with a connection pool
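Both HTTP-handler files guard the generated version import and fall back to a placeholder. A self-contained sketch of that pattern; the fallback value mirrors the diff.

try:
    from litellm._version import version  # generated at build time; may be absent in a source checkout
except Exception:
    version = "0.0.0"

headers = {"User-Agent": f"litellm/{version}"}
print(headers)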
||||
|
|
|
@ -113,7 +113,7 @@ class DatabricksConfig:
|
|||
optional_params["max_tokens"] = value
|
||||
if param == "n":
|
||||
optional_params["n"] = value
|
||||
if param == "stream" and value == True:
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
if param == "temperature":
|
||||
optional_params["temperature"] = value
|
||||
|
@ -564,7 +564,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise DatabricksError(
|
||||
status_code=408, message="Timeout error occurred."
|
||||
)
|
||||
|
@ -614,7 +614,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
status_code=e.response.status_code,
|
||||
message=response.text if response else str(e),
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise DatabricksError(
|
||||
status_code=408, message="Timeout error occurred."
|
||||
)
|
||||
|
@ -669,7 +669,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
additional_args={"complete_input_dict": data, "api_base": api_base},
|
||||
)
|
||||
|
||||
if aembedding == True:
|
||||
if aembedding is True:
|
||||
return self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, headers=headers) # type: ignore
|
||||
if client is None or isinstance(client, AsyncHTTPHandler):
|
||||
self.client = HTTPHandler(timeout=timeout) # type: ignore
|
||||
|
@ -692,7 +692,7 @@ class DatabricksChatCompletion(BaseLLM):
|
|||
status_code=e.response.status_code,
|
||||
message=e.response.text,
|
||||
)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise DatabricksError(status_code=408, message="Timeout error occurred.")
|
||||
except Exception as e:
|
||||
raise DatabricksError(status_code=500, message=str(e))
|
||||
|
|
|
@ -71,7 +71,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
self,
|
||||
_is_async: bool,
|
||||
create_file_data: CreateFileRequest,
|
||||
api_base: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
api_version: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
|
@ -117,7 +117,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
self,
|
||||
_is_async: bool,
|
||||
file_content_request: FileContentRequest,
|
||||
api_base: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
|
@ -168,7 +168,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
|
@ -220,7 +220,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
self,
|
||||
_is_async: bool,
|
||||
file_id: str,
|
||||
api_base: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
|
@ -275,7 +275,7 @@ class AzureOpenAIFilesAPI(BaseLLM):
|
|||
def list_files(
|
||||
self,
|
||||
_is_async: bool,
|
||||
api_base: str,
|
||||
api_base: Optional[str],
|
||||
api_key: Optional[str],
|
||||
timeout: Union[float, httpx.Timeout],
|
||||
max_retries: Optional[int],
|
||||
|
|
|
@ -41,7 +41,7 @@ class VertexFineTuningAPI(VertexLLM):
|
|||
created_at = int(create_time_datetime.timestamp())
|
||||
|
||||
return created_at
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
return 0
|
||||
|
||||
def convert_vertex_response_to_open_ai_response(
|
||||
|
|
|
@ -136,7 +136,7 @@ class GeminiConfig:
|
|||
# ):
|
||||
# try:
|
||||
# import google.generativeai as genai # type: ignore
|
||||
# except:
|
||||
# except Exception:
|
||||
# raise Exception(
|
||||
# "Importing google.generativeai failed, please run 'pip install -q google-generativeai"
|
||||
# )
|
||||
|
@ -282,7 +282,7 @@ class GeminiConfig:
|
|||
# completion_response = model_response["choices"][0]["message"].get("content")
|
||||
# if completion_response is None:
|
||||
# raise Exception
|
||||
# except:
|
||||
# except Exception:
|
||||
# original_response = f"response: {response}"
|
||||
# if hasattr(response, "candidates"):
|
||||
# original_response = f"response: {response.candidates}"
|
||||
|
@ -374,7 +374,7 @@ class GeminiConfig:
|
|||
# completion_response = model_response["choices"][0]["message"].get("content")
|
||||
# if completion_response is None:
|
||||
# raise Exception
|
||||
# except:
|
||||
# except Exception:
|
||||
# original_response = f"response: {response}"
|
||||
# if hasattr(response, "candidates"):
|
||||
# original_response = f"response: {response.candidates}"
|
||||
|
|
|
@ -13,6 +13,7 @@ import requests
|
|||
|
||||
import litellm
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLoggingObj
|
||||
from litellm.secret_managers.main import get_secret_str
|
||||
from litellm.types.completion import ChatCompletionMessageToolCallParam
|
||||
from litellm.utils import Choices, CustomStreamWrapper, Message, ModelResponse, Usage
|
||||
|
||||
|
@ -181,7 +182,7 @@ class HuggingfaceConfig:
|
|||
return optional_params
|
||||
|
||||
def get_hf_api_key(self) -> Optional[str]:
|
||||
return litellm.utils.get_secret("HUGGINGFACE_API_KEY")
|
||||
return get_secret_str("HUGGINGFACE_API_KEY")
|
||||
|
||||
|
||||
def output_parser(generated_text: str):
|
||||
|
@ -240,7 +241,7 @@ def read_tgi_conv_models():
|
|||
# Cache the set for future use
|
||||
conv_models_cache = conv_models
|
||||
return tgi_models, conv_models
|
||||
except:
|
||||
except Exception:
|
||||
return set(), set()
|
||||
|
||||
|
||||
|
@ -372,7 +373,7 @@ class Huggingface(BaseLLM):
|
|||
]["finish_reason"]
|
||||
sum_logprob = 0
|
||||
for token in completion_response[0]["details"]["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
if token["logprob"] is not None:
|
||||
sum_logprob += token["logprob"]
|
||||
setattr(model_response.choices[0].message, "_logprob", sum_logprob) # type: ignore
|
||||
if "best_of" in optional_params and optional_params["best_of"] > 1:
|
||||
|
@ -386,7 +387,7 @@ class Huggingface(BaseLLM):
|
|||
):
|
||||
sum_logprob = 0
|
||||
for token in item["tokens"]:
|
||||
if token["logprob"] != None:
|
||||
if token["logprob"] is not None:
|
||||
sum_logprob += token["logprob"]
|
||||
if len(item["generated_text"]) > 0:
|
||||
message_obj = Message(
|
||||
|
@ -417,7 +418,7 @@ class Huggingface(BaseLLM):
|
|||
prompt_tokens = len(
|
||||
encoding.encode(input_text)
|
||||
) ##[TODO] use the llama2 tokenizer here
|
||||
except:
|
||||
except Exception:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
output_text = model_response["choices"][0]["message"].get("content", "")
|
||||
|
@ -429,7 +430,7 @@ class Huggingface(BaseLLM):
|
|||
model_response["choices"][0]["message"].get("content", "")
|
||||
)
|
||||
) ##[TODO] use the llama2 tokenizer here
|
||||
except:
|
||||
except Exception:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
else:
|
||||
|
@ -559,7 +560,7 @@ class Huggingface(BaseLLM):
|
|||
True
|
||||
if "stream" in optional_params
|
||||
and isinstance(optional_params["stream"], bool)
|
||||
and optional_params["stream"] == True # type: ignore
|
||||
and optional_params["stream"] is True # type: ignore
|
||||
else False
|
||||
),
|
||||
}
|
||||
|
@ -595,7 +596,7 @@ class Huggingface(BaseLLM):
|
|||
data["stream"] = ( # type: ignore
|
||||
True # type: ignore
|
||||
if "stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and optional_params["stream"] is True
|
||||
else False
|
||||
)
|
||||
input_text = prompt
|
||||
|
@ -631,7 +632,7 @@ class Huggingface(BaseLLM):
|
|||
### ASYNC COMPLETION
|
||||
return self.acompletion(api_base=completion_url, data=data, headers=headers, model_response=model_response, task=task, encoding=encoding, input_text=input_text, model=model, optional_params=optional_params, timeout=timeout) # type: ignore
|
||||
### SYNC STREAMING
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
response = requests.post(
|
||||
completion_url,
|
||||
headers=headers,
|
||||
|
@ -691,7 +692,7 @@ class Huggingface(BaseLLM):
|
|||
completion_response = response.json()
|
||||
if isinstance(completion_response, dict):
|
||||
completion_response = [completion_response]
|
||||
except:
|
||||
except Exception:
|
||||
import traceback
|
||||
|
||||
raise HuggingfaceError(
|
||||
|
|
|
@ -101,7 +101,7 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -135,7 +135,7 @@ def completion(
|
|||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -159,7 +159,7 @@ def completion(
|
|||
model_response.choices[0].message.content = completion_response[ # type: ignore
|
||||
"answer"
|
||||
]
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise MaritalkError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
|
|
|
@ -120,7 +120,7 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
default_max_tokens_to_sample=None,
|
||||
|
@ -164,7 +164,7 @@ def completion(
|
|||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
return clean_and_iterate_chunks(response)
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -178,7 +178,7 @@ def completion(
|
|||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
except Exception:
|
||||
raise NLPCloudError(message=response.text, status_code=response.status_code)
|
||||
if "error" in completion_response:
|
||||
raise NLPCloudError(
|
||||
|
@ -191,7 +191,7 @@ def completion(
|
|||
model_response.choices[0].message.content = ( # type: ignore
|
||||
completion_response["generated_text"]
|
||||
)
|
||||
except:
|
||||
except Exception:
|
||||
raise NLPCloudError(
|
||||
message=json.dumps(completion_response),
|
||||
status_code=response.status_code,
|
||||
|
|
|
@ -14,7 +14,7 @@ import requests # type: ignore
|
|||
|
||||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.types.utils import ProviderField
|
||||
from litellm.types.utils import ProviderField, StreamingChoices
|
||||
|
||||
from .prompt_templates.factory import custom_prompt, prompt_factory
|
||||
|
||||
|
@ -172,7 +172,7 @@ def _convert_image(image):
|
|||
|
||||
try:
|
||||
from PIL import Image
|
||||
except:
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"ollama image conversion failed please run `pip install Pillow`"
|
||||
)
|
||||
|
@ -184,7 +184,7 @@ def _convert_image(image):
|
|||
image_data = Image.open(io.BytesIO(base64.b64decode(image)))
|
||||
if image_data.format in ["JPEG", "PNG"]:
|
||||
return image
|
||||
except:
|
||||
except Exception:
|
||||
return orig
|
||||
jpeg_image = io.BytesIO()
|
||||
image_data.convert("RGB").save(jpeg_image, "JPEG")
|
||||
|
@ -195,13 +195,13 @@ def _convert_image(image):
|
|||
# ollama implementation
|
||||
def get_ollama_response(
|
||||
model_response: litellm.ModelResponse,
|
||||
api_base="http://localhost:11434",
|
||||
model="llama2",
|
||||
prompt="Why is the sky blue?",
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
model: str,
|
||||
prompt: str,
|
||||
optional_params: dict,
|
||||
logging_obj: Any,
|
||||
encoding: Any,
|
||||
acompletion: bool = False,
|
||||
encoding=None,
|
||||
api_base="http://localhost:11434",
|
||||
):
|
||||
if api_base.endswith("/api/generate"):
|
||||
url = api_base
|
||||
|
@ -242,7 +242,7 @@ def get_ollama_response(
|
|||
},
|
||||
)
|
||||
if acompletion is True:
|
||||
if stream == True:
|
||||
if stream is True:
|
||||
response = ollama_async_streaming(
|
||||
url=url,
|
||||
data=data,
|
||||
|
@ -340,11 +340,16 @@ def ollama_completion_stream(url, data, logging_obj):
|
|||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = next(streamwrapper)
|
||||
response_content = "".join(
|
||||
chunk.choices[0].delta.content
|
||||
for chunk in chain([first_chunk], streamwrapper)
|
||||
if chunk.choices[0].delta.content
|
||||
)
|
||||
content_chunks = []
|
||||
for chunk in chain([first_chunk], streamwrapper):
|
||||
content_chunk = chunk.choices[0]
|
||||
if (
|
||||
isinstance(content_chunk, StreamingChoices)
|
||||
and hasattr(content_chunk, "delta")
|
||||
and hasattr(content_chunk.delta, "content")
|
||||
):
|
||||
content_chunks.append(content_chunk.delta.content)
|
||||
response_content = "".join(content_chunks)
|
||||
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
|
@ -392,15 +397,27 @@ async def ollama_async_streaming(url, data, model_response, encoding, logging_ob
|
|||
# If format is JSON, this was a function call
|
||||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = await anext(streamwrapper)
|
||||
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||
response_content = first_chunk_content + "".join(
|
||||
[
|
||||
chunk.choices[0].delta.content
|
||||
async for chunk in streamwrapper
|
||||
if chunk.choices[0].delta.content
|
||||
]
|
||||
)
|
||||
first_chunk = await anext(streamwrapper) # noqa F821
|
||||
chunk_choice = first_chunk.choices[0]
|
||||
if (
|
||||
isinstance(chunk_choice, StreamingChoices)
|
||||
and hasattr(chunk_choice, "delta")
|
||||
and hasattr(chunk_choice.delta, "content")
|
||||
):
|
||||
first_chunk_content = chunk_choice.delta.content or ""
|
||||
else:
|
||||
first_chunk_content = ""
|
||||
|
||||
content_chunks = []
|
||||
async for chunk in streamwrapper:
|
||||
chunk_choice = chunk.choices[0]
|
||||
if (
|
||||
isinstance(chunk_choice, StreamingChoices)
|
||||
and hasattr(chunk_choice, "delta")
|
||||
and hasattr(chunk_choice.delta, "content")
|
||||
):
|
||||
content_chunks.append(chunk_choice.delta.content)
|
||||
response_content = first_chunk_content + "".join(content_chunks)
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
content=None,
|
||||
|
@ -501,8 +518,8 @@ async def ollama_aembeddings(
|
|||
prompts: List[str],
|
||||
model_response: litellm.EmbeddingResponse,
|
||||
optional_params: dict,
|
||||
logging_obj=None,
|
||||
encoding=None,
|
||||
logging_obj: Any,
|
||||
encoding: Any,
|
||||
):
|
||||
if api_base.endswith("/api/embed"):
|
||||
url = api_base
|
||||
|
@ -581,9 +598,9 @@ def ollama_embeddings(
|
|||
api_base: str,
|
||||
model: str,
|
||||
prompts: list,
|
||||
optional_params=None,
|
||||
logging_obj=None,
|
||||
model_response=None,
|
||||
optional_params: dict,
|
||||
model_response: litellm.EmbeddingResponse,
|
||||
logging_obj: Any,
|
||||
encoding=None,
|
||||
):
|
||||
return asyncio.run(
|
||||
|
|
|
@ -4,7 +4,7 @@ import traceback
|
|||
import types
|
||||
import uuid
|
||||
from itertools import chain
|
||||
from typing import List, Optional
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import aiohttp
|
||||
import httpx
|
||||
|
@ -15,6 +15,7 @@ import litellm
|
|||
from litellm import verbose_logger
|
||||
from litellm.types.llms.ollama import OllamaToolCall, OllamaToolCallFunction
|
||||
from litellm.types.llms.openai import ChatCompletionAssistantToolCall
|
||||
from litellm.types.utils import StreamingChoices
|
||||
|
||||
|
||||
class OllamaError(Exception):
|
||||
|
@ -216,10 +217,10 @@ def get_ollama_response(
|
|||
model_response: litellm.ModelResponse,
|
||||
messages: list,
|
||||
optional_params: dict,
|
||||
model: str,
|
||||
logging_obj: Any,
|
||||
api_base="http://localhost:11434",
|
||||
api_key: Optional[str] = None,
|
||||
model="llama2",
|
||||
logging_obj=None,
|
||||
acompletion: bool = False,
|
||||
encoding=None,
|
||||
):
|
||||
|
@ -252,10 +253,13 @@ def get_ollama_response(
|
|||
for tool in m["tool_calls"]:
|
||||
typed_tool = ChatCompletionAssistantToolCall(**tool) # type: ignore
|
||||
if typed_tool["type"] == "function":
|
||||
arguments = {}
|
||||
if "arguments" in typed_tool["function"]:
|
||||
arguments = json.loads(typed_tool["function"]["arguments"])
|
||||
ollama_tool_call = OllamaToolCall(
|
||||
function=OllamaToolCallFunction(
|
||||
name=typed_tool["function"]["name"],
|
||||
arguments=json.loads(typed_tool["function"]["arguments"]),
|
||||
name=typed_tool["function"].get("name") or "",
|
||||
arguments=arguments,
|
||||
)
|
||||
)
|
||||
new_tools.append(ollama_tool_call)
|
||||
|
@ -401,12 +405,16 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
|
|||
# If format is JSON, this was a function call
|
||||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = next(streamwrapper)
|
||||
response_content = "".join(
|
||||
chunk.choices[0].delta.content
|
||||
for chunk in chain([first_chunk], streamwrapper)
|
||||
if chunk.choices[0].delta.content
|
||||
)
|
||||
content_chunks = []
|
||||
for chunk in streamwrapper:
|
||||
chunk_choice = chunk.choices[0]
|
||||
if (
|
||||
isinstance(chunk_choice, StreamingChoices)
|
||||
and hasattr(chunk_choice, "delta")
|
||||
and hasattr(chunk_choice.delta, "content")
|
||||
):
|
||||
content_chunks.append(chunk_choice.delta.content)
|
||||
response_content = "".join(content_chunks)
|
||||
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
|
@ -422,7 +430,7 @@ def ollama_completion_stream(url, api_key, data, logging_obj):
|
|||
}
|
||||
],
|
||||
)
|
||||
model_response = first_chunk
|
||||
model_response = content_chunks[0]
|
||||
model_response.choices[0].delta = delta # type: ignore
|
||||
model_response.choices[0].finish_reason = "tool_calls"
|
||||
yield model_response
|
||||
|
@ -462,15 +470,28 @@ async def ollama_async_streaming(
|
|||
# If format is JSON, this was a function call
|
||||
# Gather all chunks and return the function call as one delta to simplify parsing
|
||||
if data.get("format", "") == "json":
|
||||
first_chunk = await anext(streamwrapper)
|
||||
first_chunk_content = first_chunk.choices[0].delta.content or ""
|
||||
response_content = first_chunk_content + "".join(
|
||||
[
|
||||
chunk.choices[0].delta.content
|
||||
async for chunk in streamwrapper
|
||||
if chunk.choices[0].delta.content
|
||||
]
|
||||
)
|
||||
first_chunk = await anext(streamwrapper) # noqa F821
|
||||
chunk_choice = first_chunk.choices[0]
|
||||
if (
|
||||
isinstance(chunk_choice, StreamingChoices)
|
||||
and hasattr(chunk_choice, "delta")
|
||||
and hasattr(chunk_choice.delta, "content")
|
||||
):
|
||||
first_chunk_content = chunk_choice.delta.content or ""
|
||||
else:
|
||||
first_chunk_content = ""
|
||||
|
||||
content_chunks = []
|
||||
async for chunk in streamwrapper:
|
||||
chunk_choice = chunk.choices[0]
|
||||
if (
|
||||
isinstance(chunk_choice, StreamingChoices)
|
||||
and hasattr(chunk_choice, "delta")
|
||||
and hasattr(chunk_choice.delta, "content")
|
||||
):
|
||||
content_chunks.append(chunk_choice.delta.content)
|
||||
response_content = first_chunk_content + "".join(content_chunks)
|
||||
|
||||
function_call = json.loads(response_content)
|
||||
delta = litellm.utils.Delta(
|
||||
content=None,
|
||||
|
|
|
@ -39,8 +39,8 @@ def completion(
|
|||
encoding,
|
||||
api_key,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
default_max_tokens_to_sample=None,
|
||||
|
@ -77,7 +77,7 @@ def completion(
|
|||
data=json.dumps(data),
|
||||
stream=optional_params["stream"] if "stream" in optional_params else False,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
return response.iter_lines()
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -91,7 +91,7 @@ def completion(
|
|||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
except Exception:
|
||||
raise OobaboogaError(
|
||||
message=response.text, status_code=response.status_code
|
||||
)
|
||||
|
@ -103,7 +103,7 @@ def completion(
|
|||
else:
|
||||
try:
|
||||
model_response.choices[0].message.content = completion_response["choices"][0]["message"]["content"] # type: ignore
|
||||
except:
|
||||
except Exception:
|
||||
raise OobaboogaError(
|
||||
message=json.dumps(completion_response),
|
||||
status_code=response.status_code,
|
||||
|
|
|
@ -96,13 +96,13 @@ def completion(
|
|||
api_key,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
try:
|
||||
import google.generativeai as palm # type: ignore
|
||||
except:
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
|
||||
)
|
||||
|
@ -167,14 +167,14 @@ def completion(
|
|||
choice_obj = Choices(index=idx + 1, message=message_obj)
|
||||
choices_list.append(choice_obj)
|
||||
model_response.choices = choices_list # type: ignore
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise PalmError(
|
||||
message=traceback.format_exc(), status_code=response.status_code
|
||||
)
|
||||
|
||||
try:
|
||||
completion_response = model_response["choices"][0]["message"].get("content")
|
||||
except:
|
||||
except Exception:
|
||||
raise PalmError(
|
||||
status_code=400,
|
||||
message=f"No response received. Original response - {response}",
|
||||
|
|
|
@ -98,7 +98,7 @@ def completion(
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params=None,
|
||||
optional_params: dict,
|
||||
stream=False,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
|
@ -123,6 +123,7 @@ def completion(
|
|||
else:
|
||||
prompt = prompt_factory(model=model, messages=messages)
|
||||
|
||||
output_text: Optional[str] = None
|
||||
if api_base:
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
|
@ -157,7 +158,7 @@ def completion(
|
|||
import torch
|
||||
from petals import AutoDistributedModelForCausalLM # type: ignore
|
||||
from transformers import AutoTokenizer
|
||||
except:
|
||||
except Exception:
|
||||
raise Exception(
|
||||
"Importing torch, transformers, petals failed\nTry pip installing petals \npip install git+https://github.com/bigscience-workshop/petals"
|
||||
)
|
||||
|
@ -192,7 +193,7 @@ def completion(
|
|||
## RESPONSE OBJECT
|
||||
output_text = tokenizer.decode(outputs[0])
|
||||
|
||||
if len(output_text) > 0:
|
||||
if output_text is not None and len(output_text) > 0:
|
||||
model_response.choices[0].message.content = output_text # type: ignore
|
||||
|
||||
prompt_tokens = len(encoding.encode(prompt))
|
||||
|
|
|
@ -265,7 +265,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
## RESPONSE OBJECT
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
except Exception:
|
||||
raise PredibaseError(message=response.text, status_code=422)
|
||||
if "error" in completion_response:
|
||||
raise PredibaseError(
|
||||
|
@ -348,7 +348,7 @@ class PredibaseChatCompletion(BaseLLM):
|
|||
model_response["choices"][0]["message"].get("content", "")
|
||||
)
|
||||
) ##[TODO] use a model-specific tokenizer
|
||||
except:
|
||||
except Exception:
|
||||
# this should remain non blocking we should not block a response returning if calculating usage fails
|
||||
pass
|
||||
else:
|
||||
|
|
|
@ -5,7 +5,7 @@ import traceback
|
|||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from enum import Enum
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple
|
||||
from typing import Any, List, Mapping, MutableMapping, Optional, Sequence, Tuple, cast
|
||||
|
||||
from jinja2 import BaseLoader, Template, exceptions, meta
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
|
@ -26,11 +26,14 @@ from litellm.types.completion import (
|
|||
)
|
||||
from litellm.types.llms.anthropic import *
|
||||
from litellm.types.llms.bedrock import MessageBlock as BedrockMessageBlock
|
||||
from litellm.types.llms.ollama import OllamaVisionModelObject
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionAssistantToolCall,
|
||||
ChatCompletionFunctionMessage,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionTextObject,
|
||||
ChatCompletionToolCallFunctionChunk,
|
||||
ChatCompletionToolMessage,
|
||||
ChatCompletionUserMessage,
|
||||
|
@ -164,7 +167,9 @@ def convert_to_ollama_image(openai_image_url: str):
|
|||
|
||||
def ollama_pt(
|
||||
model, messages
|
||||
): # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
|
||||
) -> Union[
|
||||
str, OllamaVisionModelObject
|
||||
]: # https://github.com/ollama/ollama/blob/af4cf55884ac54b9e637cd71dadfe9b7a5685877/docs/modelfile.md#template
|
||||
if "instruct" in model:
|
||||
prompt = custom_prompt(
|
||||
role_dict={
|
||||
|
@ -438,7 +443,7 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
|
|||
def _is_system_in_template():
|
||||
try:
|
||||
# Try rendering the template with a system message
|
||||
response = template.render(
|
||||
template.render(
|
||||
messages=[{"role": "system", "content": "test"}],
|
||||
eos_token="<eos>",
|
||||
bos_token="<bos>",
|
||||
|
@ -446,10 +451,11 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
|
|||
return True
|
||||
|
||||
# This will be raised if Jinja attempts to render the system message and it can't
|
||||
except:
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
try:
|
||||
rendered_text = ""
|
||||
# Render the template with the provided values
|
||||
if _is_system_in_template():
|
||||
rendered_text = template.render(
|
||||
|
@ -460,8 +466,8 @@ def hf_chat_template(model: str, messages: list, chat_template: Optional[Any] =
|
|||
)
|
||||
else:
|
||||
# treat a system message as a user message, if system not in template
|
||||
try:
|
||||
reformatted_messages = []
|
||||
try:
|
||||
for message in messages:
|
||||
if message["role"] == "system":
|
||||
reformatted_messages.append(
|
||||
|
@ -556,30 +562,31 @@ def get_model_info(token, model):
|
|||
return None, None
|
||||
else:
|
||||
return None, None
|
||||
except Exception as e: # safely fail a prompt template request
|
||||
except Exception: # safely fail a prompt template request
|
||||
return None, None
|
||||
|
||||
|
||||
def format_prompt_togetherai(messages, prompt_format, chat_template):
|
||||
if prompt_format is None:
|
||||
return default_pt(messages)
|
||||
## OLD TOGETHER AI FLOW
|
||||
# def format_prompt_togetherai(messages, prompt_format, chat_template):
|
||||
# if prompt_format is None:
|
||||
# return default_pt(messages)
|
||||
|
||||
human_prompt, assistant_prompt = prompt_format.split("{prompt}")
|
||||
# human_prompt, assistant_prompt = prompt_format.split("{prompt}")
|
||||
|
||||
if chat_template is not None:
|
||||
prompt = hf_chat_template(
|
||||
model=None, messages=messages, chat_template=chat_template
|
||||
)
|
||||
elif prompt_format is not None:
|
||||
prompt = custom_prompt(
|
||||
role_dict={},
|
||||
messages=messages,
|
||||
initial_prompt_value=human_prompt,
|
||||
final_prompt_value=assistant_prompt,
|
||||
)
|
||||
else:
|
||||
prompt = default_pt(messages)
|
||||
return prompt
|
||||
# if chat_template is not None:
|
||||
# prompt = hf_chat_template(
|
||||
# model=None, messages=messages, chat_template=chat_template
|
||||
# )
|
||||
# elif prompt_format is not None:
|
||||
# prompt = custom_prompt(
|
||||
# role_dict={},
|
||||
# messages=messages,
|
||||
# initial_prompt_value=human_prompt,
|
||||
# final_prompt_value=assistant_prompt,
|
||||
# )
|
||||
# else:
|
||||
# prompt = default_pt(messages)
|
||||
# return prompt
|
||||
|
||||
|
||||
### IBM Granite
|
||||
|
@ -1063,7 +1070,7 @@ def convert_to_gemini_tool_call_invoke(
|
|||
else: # don't silently drop params. Make it clear to user what's happening.
|
||||
raise Exception(
|
||||
"function_call missing. Received tool call with 'type': 'function'. No function call in argument - {}".format(
|
||||
tool
|
||||
message
|
||||
)
|
||||
)
|
||||
return _parts_list
|
||||
|
@ -1216,12 +1223,14 @@ def convert_function_to_anthropic_tool_invoke(
|
|||
function_call: Union[dict, ChatCompletionToolCallFunctionChunk],
|
||||
) -> List[AnthropicMessagesToolUseParam]:
|
||||
try:
|
||||
_name = get_attribute_or_key(function_call, "name") or ""
|
||||
_arguments = get_attribute_or_key(function_call, "arguments")
|
||||
anthropic_tool_invoke = [
|
||||
AnthropicMessagesToolUseParam(
|
||||
type="tool_use",
|
||||
id=str(uuid.uuid4()),
|
||||
name=get_attribute_or_key(function_call, "name"),
|
||||
input=json.loads(get_attribute_or_key(function_call, "arguments")),
|
||||
name=_name,
|
||||
input=json.loads(_arguments) if _arguments else {},
|
||||
)
|
||||
]
|
||||
return anthropic_tool_invoke
|
||||
|
@ -1349,8 +1358,9 @@ def anthropic_messages_pt(
|
|||
):
|
||||
for m in user_message_types_block["content"]:
|
||||
if m.get("type", "") == "image_url":
|
||||
m = cast(ChatCompletionImageObject, m)
|
||||
image_chunk = convert_to_anthropic_image_obj(
|
||||
m["image_url"]["url"]
|
||||
openai_image_url=m["image_url"]["url"] # type: ignore
|
||||
)
|
||||
|
||||
_anthropic_content_element = AnthropicMessagesImageParam(
|
||||
|
@@ -1362,21 +1372,31 @@ def anthropic_messages_pt(
),
)

anthropic_content_element = add_cache_control_to_content(
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_element,
orignal_content_element=m,
orignal_content_element=dict(m),
)
user_content.append(anthropic_content_element)

if "cache_control" in _content_element:
_anthropic_content_element["cache_control"] = (
_content_element["cache_control"]
)
user_content.append(_anthropic_content_element)
elif m.get("type", "") == "text":
_anthropic_text_content_element = {
"type": "text",
"text": m["text"],
}
anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=m,
m = cast(ChatCompletionTextObject, m)
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=m["text"],
)
user_content.append(anthropic_content_element)
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=dict(m),
)
_content_element = cast(
AnthropicMessagesTextParam, _content_element
)

user_content.append(_content_element)
elif (
user_message_types_block["role"] == "tool"
or user_message_types_block["role"] == "function"
@@ -1390,12 +1410,17 @@ def anthropic_messages_pt(
"type": "text",
"text": user_message_types_block["content"],
}
anthropic_content_element = add_cache_control_to_content(
_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_content_text_element,
orignal_content_element=user_message_types_block,
orignal_content_element=dict(user_message_types_block),
)

user_content.append(anthropic_content_element)
if "cache_control" in _content_element:
_anthropic_content_text_element["cache_control"] = _content_element[
"cache_control"
]

user_content.append(_anthropic_content_text_element)

msg_i += 1
@@ -1417,11 +1442,14 @@ def anthropic_messages_pt(
anthropic_message = AnthropicMessagesTextParam(
type="text", text=m.get("text")
)
anthropic_message = add_cache_control_to_content(
_cached_message = add_cache_control_to_content(
anthropic_content_element=anthropic_message,
orignal_content_element=m,
orignal_content_element=dict(m),
)

assistant_content.append(
cast(AnthropicMessagesTextParam, _cached_message)
)
assistant_content.append(anthropic_message)
elif (
"content" in assistant_content_block
and isinstance(assistant_content_block["content"], str)
@@ -1430,16 +1458,22 @@ def anthropic_messages_pt(
] # don't pass empty text blocks. anthropic api raises errors.
):

_anthropic_text_content_element = {
"type": "text",
"text": assistant_content_block["content"],
}

anthropic_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=assistant_content_block,
_anthropic_text_content_element = AnthropicMessagesTextParam(
type="text",
text=assistant_content_block["content"],
)
assistant_content.append(anthropic_content_element)

_content_element = add_cache_control_to_content(
anthropic_content_element=_anthropic_text_content_element,
orignal_content_element=dict(assistant_content_block),
)

if "cache_control" in _content_element:
_anthropic_text_content_element["cache_control"] = _content_element[
"cache_control"
]

assistant_content.append(_anthropic_text_content_element)

assistant_tool_calls = assistant_content_block.get("tool_calls")
if (
@@ -1566,30 +1600,6 @@ def get_system_prompt(messages):
return system_prompt, messages


def convert_to_documents(
observations: Any,
) -> List[MutableMapping]:
"""Converts observations into a 'document' dict"""
documents: List[MutableMapping] = []
if isinstance(observations, str):
# strings are turned into a key/value pair and a key of 'output' is added.
observations = [{"output": observations}]
elif isinstance(observations, Mapping):
# single mappings are transformed into a list to simplify the rest of the code.
observations = [observations]
elif not isinstance(observations, Sequence):
# all other types are turned into a key/value pair within a list
observations = [{"output": observations}]

for doc in observations:
if not isinstance(doc, Mapping):
# types that aren't Mapping are turned into a key/value pair.
doc = {"output": doc}
documents.append(doc)

return documents


from litellm.types.llms.cohere import (
CallObject,
ChatHistory,
@@ -1943,7 +1953,7 @@ def amazon_titan_pt(
def _load_image_from_url(image_url):
try:
from PIL import Image
except:
except Exception:
raise Exception("image conversion failed please run `pip install Pillow`")
from io import BytesIO
@@ -2008,7 +2018,7 @@ def _gemini_vision_convert_messages(messages: list):
else:
try:
from PIL import Image
except:
except Exception:
raise Exception(
"gemini image conversion failed please run `pip install Pillow`"
)
@@ -2056,7 +2066,7 @@ def gemini_text_image_pt(messages: list):
"""
try:
import google.generativeai as genai # type: ignore
except:
except Exception:
raise Exception(
"Importing google.generativeai failed, please run 'pip install -q google-generativeai"
)
@@ -2331,7 +2341,7 @@ def _convert_to_bedrock_tool_call_result(
for content in content_list:
if content["type"] == "text":
content_str += content["text"]
name = message.get("name", "")
message.get("name", "")
id = str(message.get("tool_call_id", str(uuid.uuid4())))

tool_result_content_block = BedrockToolResultContentBlock(text=content_str)
@@ -2575,7 +2585,7 @@ def function_call_prompt(messages: list, functions: list):
message["content"] += f""" {function_prompt}"""
function_added_to_prompt = True

if function_added_to_prompt == False:
if function_added_to_prompt is False:
messages.append({"role": "system", "content": f"""{function_prompt}"""})

return messages
@@ -2692,11 +2702,6 @@ def prompt_factory(
)
elif custom_llm_provider == "anthropic_xml":
return anthropic_messages_pt_xml(messages=messages)
elif custom_llm_provider == "together_ai":
prompt_format, chat_template = get_model_info(token=api_key, model=model)
return format_prompt_togetherai(
messages=messages, prompt_format=prompt_format, chat_template=chat_template
)
elif custom_llm_provider == "gemini":
if (
model == "gemini-pro-vision"
@@ -2810,7 +2815,7 @@ def prompt_factory(
)
else:
return hf_chat_template(original_model_name, messages)
except Exception as e:
except Exception:
return default_pt(
messages=messages
) # default that covers Bloom, T-5, any non-chat tuned model (e.g. base Llama2)
@@ -61,7 +61,7 @@ async def async_convert_url_to_base64(url: str) -> str:
try:
response = await client.get(url, follow_redirects=True)
return _process_image_response(response, url)
except:
except Exception:
pass
raise Exception(
f"Error: Unable to fetch image from URL after 3 attempts. url={url}"
@ -297,7 +297,7 @@ def handle_prediction_response_streaming(prediction_url, api_token, print_verbos
|
|||
if "output" in response_data:
|
||||
try:
|
||||
output_string = "".join(response_data["output"])
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ReplicateError(
|
||||
status_code=422,
|
||||
message="Unable to parse response. Got={}".format(
|
||||
|
@ -344,7 +344,7 @@ async def async_handle_prediction_response_streaming(
|
|||
if "output" in response_data:
|
||||
try:
|
||||
output_string = "".join(response_data["output"])
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
raise ReplicateError(
|
||||
status_code=422,
|
||||
message="Unable to parse response. Got={}".format(
|
||||
|
@ -479,7 +479,7 @@ def completion(
|
|||
else:
|
||||
input_data = {"prompt": prompt, **optional_params}
|
||||
|
||||
if acompletion is not None and acompletion == True:
|
||||
if acompletion is not None and acompletion is True:
|
||||
return async_completion(
|
||||
model_response=model_response,
|
||||
model=model,
|
||||
|
@ -513,7 +513,7 @@ def completion(
|
|||
print_verbose(prediction_url)
|
||||
|
||||
# Handle the prediction response (streaming or non-streaming)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
print_verbose("streaming request")
|
||||
_response = handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
|
@ -571,7 +571,7 @@ async def async_completion(
|
|||
http_handler=http_handler,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
_response = async_handle_prediction_response_streaming(
|
||||
prediction_url, api_key, print_verbose
|
||||
)
|
||||
|
|
|
@ -8,7 +8,7 @@ import types
|
|||
from copy import deepcopy
|
||||
from enum import Enum
|
||||
from functools import partial
|
||||
from typing import Any, AsyncIterator, Callable, Iterator, List, Optional, Union
|
||||
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
@ -112,7 +112,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
):
|
||||
try:
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
## CREDENTIALS ##
|
||||
# pop aws_secret_access_key, aws_access_key_id, aws_session_token, aws_region_name from kwargs, since completion calls fail with them
|
||||
|
@ -123,7 +123,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
aws_role_name = optional_params.pop("aws_role_name", None)
|
||||
aws_session_name = optional_params.pop("aws_session_name", None)
|
||||
aws_profile_name = optional_params.pop("aws_profile_name", None)
|
||||
aws_bedrock_runtime_endpoint = optional_params.pop(
|
||||
optional_params.pop(
|
||||
"aws_bedrock_runtime_endpoint", None
|
||||
) # https://bedrock-runtime.{region_name}.amazonaws.com
|
||||
aws_web_identity_token = optional_params.pop("aws_web_identity_token", None)
|
||||
|
@ -175,7 +175,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
from botocore.auth import SigV4Auth
|
||||
from botocore.awsrequest import AWSRequest
|
||||
from botocore.credentials import Credentials
|
||||
except ImportError as e:
|
||||
except ImportError:
|
||||
raise ImportError("Missing boto3 to call bedrock. Run 'pip install boto3'.")
|
||||
|
||||
sigv4 = SigV4Auth(credentials, "sagemaker", aws_region_name)
|
||||
|
@ -244,7 +244,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
hf_model_name = (
|
||||
hf_model_name or model
|
||||
) # pass in hf model name for pulling it's prompt template - (e.g. `hf_model_name="meta-llama/Llama-2-7b-chat-hf` applies the llama2 chat template to the prompt)
|
||||
prompt = prompt_factory(model=hf_model_name, messages=messages)
|
||||
prompt: str = prompt_factory(model=hf_model_name, messages=messages) # type: ignore
|
||||
|
||||
return prompt
|
||||
|
||||
|
@ -256,10 +256,10 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
timeout: Optional[Union[float, httpx.Timeout]] = None,
|
||||
custom_prompt_dict={},
|
||||
hf_model_name=None,
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
acompletion: bool = False,
|
||||
|
@ -277,7 +277,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
|
||||
openai_like_chat_completions = DatabricksChatCompletion()
|
||||
inference_params["stream"] = True if stream is True else False
|
||||
_data = {
|
||||
_data: Dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
**inference_params,
|
||||
|
@ -310,7 +310,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
logger_fn=logger_fn,
|
||||
timeout=timeout,
|
||||
encoding=encoding,
|
||||
headers=prepared_request.headers,
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
custom_endpoint=True,
|
||||
custom_llm_provider="sagemaker_chat",
|
||||
streaming_decoder=custom_stream_decoder, # type: ignore
|
||||
|
@ -474,7 +474,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
try:
|
||||
sync_response = sync_handler.post(
|
||||
url=prepared_request.url,
|
||||
headers=prepared_request.headers,
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
json=_data,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
@ -559,7 +559,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
self,
|
||||
api_base: str,
|
||||
headers: dict,
|
||||
data: str,
|
||||
data: dict,
|
||||
logging_obj,
|
||||
client=None,
|
||||
):
|
||||
|
@ -598,7 +598,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
except httpx.HTTPStatusError as err:
|
||||
error_code = err.response.status_code
|
||||
raise SagemakerError(status_code=error_code, message=err.response.text)
|
||||
except httpx.TimeoutException as e:
|
||||
except httpx.TimeoutException:
|
||||
raise SagemakerError(status_code=408, message="Timeout error occurred.")
|
||||
except Exception as e:
|
||||
raise SagemakerError(status_code=500, message=str(e))
|
||||
|
@ -638,7 +638,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
make_call=partial(
|
||||
self.make_async_call,
|
||||
api_base=prepared_request.url,
|
||||
headers=prepared_request.headers,
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
data=data,
|
||||
logging_obj=logging_obj,
|
||||
),
|
||||
|
@ -716,7 +716,7 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
try:
|
||||
response = await async_handler.post(
|
||||
url=prepared_request.url,
|
||||
headers=prepared_request.headers,
|
||||
headers=prepared_request.headers, # type: ignore
|
||||
json=data,
|
||||
timeout=timeout,
|
||||
)
|
||||
|
@ -794,8 +794,8 @@ class SagemakerLLM(BaseAWSLLM):
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -1032,7 +1032,7 @@ class AWSEventStreamDecoder:
|
|||
yield self._chunk_parser_messages_api(chunk_data=_data)
|
||||
else:
|
||||
yield self._chunk_parser(chunk_data=_data)
|
||||
except json.JSONDecodeError as e:
|
||||
except json.JSONDecodeError:
|
||||
# Handle or log any unparseable data at the end
|
||||
verbose_logger.error(
|
||||
f"Warning: Unparseable JSON data remained: {accumulated_json}"
|
||||
|
|
|
@ -17,6 +17,7 @@ import requests # type: ignore
|
|||
import litellm
|
||||
from litellm import verbose_logger
|
||||
from litellm.litellm_core_utils.core_helpers import map_finish_reason
|
||||
from litellm.litellm_core_utils.litellm_logging import Logging as LiteLLMLogging
|
||||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler
|
||||
from litellm.types.llms.databricks import GenericStreamingChunk
|
||||
from litellm.utils import (
|
||||
|
@ -157,7 +158,7 @@ class MistralTextCompletionConfig:
|
|||
optional_params["top_p"] = value
|
||||
if param == "max_tokens" or param == "max_completion_tokens":
|
||||
optional_params["max_tokens"] = value
|
||||
if param == "stream" and value == True:
|
||||
if param == "stream" and value is True:
|
||||
optional_params["stream"] = value
|
||||
if param == "stop":
|
||||
optional_params["stop"] = value
|
||||
|
@ -249,7 +250,7 @@ class CodestralTextCompletion(BaseLLM):
|
|||
response: Union[requests.Response, httpx.Response],
|
||||
model_response: TextCompletionResponse,
|
||||
stream: bool,
|
||||
logging_obj: litellm.litellm_core_utils.litellm_logging.Logging,
|
||||
logging_obj: LiteLLMLogging,
|
||||
optional_params: dict,
|
||||
api_key: str,
|
||||
data: Union[dict, str],
|
||||
|
@ -273,7 +274,7 @@ class CodestralTextCompletion(BaseLLM):
|
|||
)
|
||||
try:
|
||||
completion_response = response.json()
|
||||
except:
|
||||
except Exception:
|
||||
raise TextCompletionCodestralError(message=response.text, status_code=422)
|
||||
|
||||
_original_choices = completion_response.get("choices", [])
|
||||
|
|
|
@ -176,7 +176,7 @@ class VertexAIConfig:
|
|||
if param == "top_p":
|
||||
optional_params["top_p"] = value
|
||||
if (
|
||||
param == "stream" and value == True
|
||||
param == "stream" and value is True
|
||||
): # sending stream = False, can cause it to get passed unchecked and raise issues
|
||||
optional_params["stream"] = value
|
||||
if param == "n":
|
||||
|
@ -1313,7 +1313,6 @@ class ModelResponseIterator:
|
|||
|
||||
text = ""
|
||||
tool_use: Optional[ChatCompletionToolCallChunk] = None
|
||||
is_finished = False
|
||||
finish_reason = ""
|
||||
usage: Optional[ChatCompletionUsageBlock] = None
|
||||
_candidates: Optional[List[Candidates]] = processed_chunk.get("candidates")
|
||||
|
|
|
@ -268,7 +268,7 @@ def completion(
|
|||
):
|
||||
try:
|
||||
import vertexai
|
||||
except:
|
||||
except Exception:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="""vertexai import failed please run `pip install -U google-cloud-aiplatform "anthropic[vertex]"`""",
|
||||
|
|
|
@ -5,7 +5,7 @@ import time
|
|||
import types
|
||||
import uuid
|
||||
from enum import Enum
|
||||
from typing import Any, Callable, List, Literal, Optional, Union
|
||||
from typing import Any, Callable, List, Literal, Optional, Union, cast
|
||||
|
||||
import httpx # type: ignore
|
||||
import requests # type: ignore
|
||||
|
@ -25,7 +25,12 @@ from litellm.types.files import (
|
|||
is_gemini_1_5_accepted_file_type,
|
||||
is_video_file_type,
|
||||
)
|
||||
from litellm.types.llms.openai import AllMessageValues, ChatCompletionAssistantMessage
|
||||
from litellm.types.llms.openai import (
|
||||
AllMessageValues,
|
||||
ChatCompletionAssistantMessage,
|
||||
ChatCompletionImageObject,
|
||||
ChatCompletionTextObject,
|
||||
)
|
||||
from litellm.types.llms.vertex_ai import *
|
||||
from litellm.utils import CustomStreamWrapper, ModelResponse, Usage
|
||||
|
||||
|
@ -150,30 +155,34 @@ def _gemini_convert_messages_with_history(
|
|||
while (
|
||||
msg_i < len(messages) and messages[msg_i]["role"] in user_message_types
|
||||
):
|
||||
if messages[msg_i]["content"] is not None and isinstance(
|
||||
messages[msg_i]["content"], list
|
||||
):
|
||||
_message_content = messages[msg_i].get("content")
|
||||
if _message_content is not None and isinstance(_message_content, list):
|
||||
_parts: List[PartType] = []
|
||||
for element in messages[msg_i]["content"]: # type: ignore
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text" and len(element["text"]) > 0: # type: ignore
|
||||
_part = PartType(text=element["text"]) # type: ignore
|
||||
for element in _message_content:
|
||||
if (
|
||||
element["type"] == "text"
|
||||
and "text" in element
|
||||
and len(element["text"]) > 0
|
||||
):
|
||||
element = cast(ChatCompletionTextObject, element)
|
||||
_part = PartType(text=element["text"])
|
||||
_parts.append(_part)
|
||||
elif element["type"] == "image_url":
|
||||
img_element: ChatCompletionImageObject = element # type: ignore
|
||||
element = cast(ChatCompletionImageObject, element)
|
||||
img_element = element
|
||||
if isinstance(img_element["image_url"], dict):
|
||||
image_url = img_element["image_url"]["url"]
|
||||
else:
|
||||
image_url = img_element["image_url"]
|
||||
_part = _process_gemini_image(image_url=image_url)
|
||||
_parts.append(_part) # type: ignore
|
||||
_parts.append(_part)
|
||||
user_content.extend(_parts)
|
||||
elif (
|
||||
messages[msg_i]["content"] is not None
|
||||
and isinstance(messages[msg_i]["content"], str)
|
||||
and len(messages[msg_i]["content"]) > 0 # type: ignore
|
||||
_message_content is not None
|
||||
and isinstance(_message_content, str)
|
||||
and len(_message_content) > 0
|
||||
):
|
||||
_part = PartType(text=messages[msg_i]["content"]) # type: ignore
|
||||
_part = PartType(text=_message_content)
|
||||
user_content.append(_part)
|
||||
|
||||
msg_i += 1
|
||||
|
@ -201,22 +210,21 @@ def _gemini_convert_messages_with_history(
|
|||
else:
|
||||
msg_dict = messages[msg_i] # type: ignore
|
||||
assistant_msg = ChatCompletionAssistantMessage(**msg_dict) # type: ignore
|
||||
if assistant_msg.get("content", None) is not None and isinstance(
|
||||
assistant_msg["content"], list
|
||||
):
|
||||
_message_content = assistant_msg.get("content", None)
|
||||
if _message_content is not None and isinstance(_message_content, list):
|
||||
_parts = []
|
||||
for element in assistant_msg["content"]:
|
||||
for element in _message_content:
|
||||
if isinstance(element, dict):
|
||||
if element["type"] == "text":
|
||||
_part = PartType(text=element["text"]) # type: ignore
|
||||
_part = PartType(text=element["text"])
|
||||
_parts.append(_part)
|
||||
assistant_content.extend(_parts)
|
||||
elif (
|
||||
assistant_msg.get("content", None) is not None
|
||||
and isinstance(assistant_msg["content"], str)
|
||||
and assistant_msg["content"]
|
||||
_message_content is not None
|
||||
and isinstance(_message_content, str)
|
||||
and _message_content
|
||||
):
|
||||
assistant_text = assistant_msg["content"] # either string or none
|
||||
assistant_text = _message_content # either string or none
|
||||
assistant_content.append(PartType(text=assistant_text)) # type: ignore
|
||||
|
||||
## HANDLE ASSISTANT FUNCTION CALL
|
||||
|
@ -256,7 +264,9 @@ def _gemini_convert_messages_with_history(
|
|||
raise e
|
||||
|
||||
|
||||
def _get_client_cache_key(model: str, vertex_project: str, vertex_location: str):
|
||||
def _get_client_cache_key(
|
||||
model: str, vertex_project: Optional[str], vertex_location: Optional[str]
|
||||
):
|
||||
_cache_key = f"{model}-{vertex_project}-{vertex_location}"
|
||||
return _cache_key
|
||||
|
||||
|
@ -294,7 +304,7 @@ def completion(
|
|||
"""
|
||||
try:
|
||||
import vertexai
|
||||
except:
|
||||
except Exception:
|
||||
raise VertexAIError(
|
||||
status_code=400,
|
||||
message="vertexai import failed please run `pip install google-cloud-aiplatform`. This is required for the 'vertex_ai/' route on LiteLLM",
|
||||
|
@ -339,6 +349,8 @@ def completion(
|
|||
_vertex_llm_model_object = _get_client_from_cache(client_cache_key=_cache_key)
|
||||
|
||||
if _vertex_llm_model_object is None:
|
||||
from google.auth.credentials import Credentials
|
||||
|
||||
if vertex_credentials is not None and isinstance(vertex_credentials, str):
|
||||
import google.oauth2.service_account
|
||||
|
||||
|
@ -356,7 +368,9 @@ def completion(
|
|||
f"VERTEX AI: creds={creds}; google application credentials: {os.getenv('GOOGLE_APPLICATION_CREDENTIALS')}"
|
||||
)
|
||||
vertexai.init(
|
||||
project=vertex_project, location=vertex_location, credentials=creds
|
||||
project=vertex_project,
|
||||
location=vertex_location,
|
||||
credentials=cast(Credentials, creds),
|
||||
)
|
||||
|
||||
## Load Config
|
||||
|
@ -391,7 +405,6 @@ def completion(
|
|||
|
||||
request_str = ""
|
||||
response_obj = None
|
||||
async_client = None
|
||||
instances = None
|
||||
client_options = {
|
||||
"api_endpoint": f"{vertex_location}-aiplatform.googleapis.com"
|
||||
|
@ -400,7 +413,7 @@ def completion(
|
|||
model in litellm.vertex_language_models
|
||||
or model in litellm.vertex_vision_models
|
||||
):
|
||||
llm_model = _vertex_llm_model_object or GenerativeModel(model)
|
||||
llm_model: Any = _vertex_llm_model_object or GenerativeModel(model)
|
||||
mode = "vision"
|
||||
request_str += f"llm_model = GenerativeModel({model})\n"
|
||||
elif model in litellm.vertex_chat_models:
|
||||
|
@ -459,7 +472,6 @@ def completion(
|
|||
"model_response": model_response,
|
||||
"encoding": encoding,
|
||||
"messages": messages,
|
||||
"request_str": request_str,
|
||||
"print_verbose": print_verbose,
|
||||
"client_options": client_options,
|
||||
"instances": instances,
|
||||
|
@ -474,6 +486,7 @@ def completion(
|
|||
|
||||
return async_completion(**data)
|
||||
|
||||
completion_response = None
|
||||
if mode == "vision":
|
||||
print_verbose("\nMaking VertexAI Gemini Pro / Pro Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
|
@ -529,7 +542,7 @@ def completion(
|
|||
# Check if it's a RepeatedComposite instance
|
||||
for key, val in function_call.args.items():
|
||||
if isinstance(
|
||||
val, proto.marshal.collections.repeated.RepeatedComposite
|
||||
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
|
||||
):
|
||||
# If so, convert to list
|
||||
args_dict[key] = [v for v in val]
|
||||
|
@ -560,9 +573,9 @@ def completion(
|
|||
optional_params["tools"] = tools
|
||||
elif mode == "chat":
|
||||
chat = llm_model.start_chat()
|
||||
request_str += f"chat = llm_model.start_chat()\n"
|
||||
request_str += "chat = llm_model.start_chat()\n"
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# NOTE: VertexAI does not accept stream=True as a param and raises an error,
|
||||
# we handle this by removing 'stream' from optional params and sending the request
|
||||
# after we get the response we add optional_params["stream"] = True, since main.py needs to know it's a streaming response to then transform it for the OpenAI format
|
||||
|
@ -597,7 +610,7 @@ def completion(
|
|||
)
|
||||
completion_response = chat.send_message(prompt, **optional_params).text
|
||||
elif mode == "text":
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
optional_params.pop(
|
||||
"stream", None
|
||||
) # See note above on handling streaming for vertex ai
|
||||
|
@ -632,6 +645,12 @@ def completion(
|
|||
"""
|
||||
Vertex AI Model Garden
|
||||
"""
|
||||
|
||||
if vertex_project is None or vertex_location is None:
|
||||
raise ValueError(
|
||||
"Vertex project and location are required for custom endpoint"
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
@ -661,13 +680,17 @@ def completion(
|
|||
and "\nOutput:\n" in completion_response
|
||||
):
|
||||
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
response = TextStreamer(completion_response)
|
||||
return response
|
||||
elif mode == "private":
|
||||
"""
|
||||
Vertex AI Model Garden deployed on private endpoint
|
||||
"""
|
||||
if instances is None:
|
||||
raise ValueError("instances are required for private endpoint")
|
||||
if llm_model is None:
|
||||
raise ValueError("Unable to pick client for private endpoint")
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
@ -686,7 +709,7 @@ def completion(
|
|||
and "\nOutput:\n" in completion_response
|
||||
):
|
||||
completion_response = completion_response.split("\nOutput:\n", 1)[1]
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
response = TextStreamer(completion_response)
|
||||
return response
|
||||
|
||||
|
@ -715,7 +738,7 @@ def completion(
|
|||
else:
|
||||
# init prompt tokens
|
||||
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
|
||||
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
|
||||
prompt_tokens, completion_tokens, _ = 0, 0, 0
|
||||
if response_obj is not None:
|
||||
if hasattr(response_obj, "usage_metadata") and hasattr(
|
||||
response_obj.usage_metadata, "prompt_token_count"
|
||||
|
@ -771,11 +794,13 @@ async def async_completion(
|
|||
try:
|
||||
import proto # type: ignore
|
||||
|
||||
response_obj = None
|
||||
completion_response = None
|
||||
if mode == "vision":
|
||||
print_verbose("\nMaking VertexAI Gemini Pro/Vision Call")
|
||||
print_verbose(f"\nProcessing input messages = {messages}")
|
||||
tools = optional_params.pop("tools", None)
|
||||
stream = optional_params.pop("stream", False)
|
||||
optional_params.pop("stream", False)
|
||||
|
||||
content = _gemini_convert_messages_with_history(messages=messages)
|
||||
|
||||
|
@ -817,7 +842,7 @@ async def async_completion(
|
|||
# Check if it's a RepeatedComposite instance
|
||||
for key, val in function_call.args.items():
|
||||
if isinstance(
|
||||
val, proto.marshal.collections.repeated.RepeatedComposite
|
||||
val, proto.marshal.collections.repeated.RepeatedComposite # type: ignore
|
||||
):
|
||||
# If so, convert to list
|
||||
args_dict[key] = [v for v in val]
|
||||
|
@ -880,6 +905,11 @@ async def async_completion(
|
|||
"""
|
||||
from google.cloud import aiplatform # type: ignore
|
||||
|
||||
if vertex_project is None or vertex_location is None:
|
||||
raise ValueError(
|
||||
"Vertex project and location are required for custom endpoint"
|
||||
)
|
||||
|
||||
## LOGGING
|
||||
logging_obj.pre_call(
|
||||
input=prompt,
|
||||
|
@ -953,7 +983,7 @@ async def async_completion(
|
|||
else:
|
||||
# init prompt tokens
|
||||
# this block attempts to get usage from response_obj if it exists, if not it uses the litellm token counter
|
||||
prompt_tokens, completion_tokens, total_tokens = 0, 0, 0
|
||||
prompt_tokens, completion_tokens, _ = 0, 0, 0
|
||||
if response_obj is not None and (
|
||||
hasattr(response_obj, "usage_metadata")
|
||||
and hasattr(response_obj.usage_metadata, "prompt_token_count")
|
||||
|
@ -1001,6 +1031,7 @@ async def async_streaming(
|
|||
"""
|
||||
Add support for async streaming calls for gemini-pro
|
||||
"""
|
||||
response: Any = None
|
||||
if mode == "vision":
|
||||
stream = optional_params.pop("stream")
|
||||
tools = optional_params.pop("tools", None)
|
||||
|
@ -1065,6 +1096,11 @@ async def async_streaming(
|
|||
elif mode == "custom":
|
||||
from google.cloud import aiplatform # type: ignore
|
||||
|
||||
if vertex_project is None or vertex_location is None:
|
||||
raise ValueError(
|
||||
"Vertex project and location are required for custom endpoint"
|
||||
)
|
||||
|
||||
stream = optional_params.pop("stream", None)
|
||||
|
||||
## LOGGING
|
||||
|
@ -1102,6 +1138,8 @@ async def async_streaming(
|
|||
response = TextStreamer(completion_response)
|
||||
|
||||
elif mode == "private":
|
||||
if instances is None:
|
||||
raise ValueError("Instances are required for private endpoint")
|
||||
stream = optional_params.pop("stream", None)
|
||||
_ = instances[0].pop("stream", None)
|
||||
request_str += f"llm_model.predict_async(instances={instances})\n"
|
||||
|
@ -1118,6 +1156,9 @@ async def async_streaming(
|
|||
if stream:
|
||||
response = TextStreamer(completion_response)
|
||||
|
||||
if response is None:
|
||||
raise ValueError("Unable to generate response")
|
||||
|
||||
logging_obj.post_call(input=prompt, api_key=None, original_response=response)
|
||||
|
||||
streamwrapper = CustomStreamWrapper(
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import json
|
||||
import os
|
||||
import types
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Any, Literal, Optional, Union, cast
|
||||
|
||||
import httpx
|
||||
from pydantic import BaseModel
|
||||
|
@ -53,7 +53,7 @@ class VertexEmbedding(VertexBase):
|
|||
gemini_api_key: Optional[str] = None,
|
||||
extra_headers: Optional[dict] = None,
|
||||
):
|
||||
if aembedding == True:
|
||||
if aembedding is True:
|
||||
return self.async_embedding(
|
||||
model=model,
|
||||
input=input,
|
||||
|
|
|
@ -45,8 +45,8 @@ def completion(
|
|||
print_verbose: Callable,
|
||||
encoding,
|
||||
logging_obj,
|
||||
optional_params: dict,
|
||||
custom_prompt_dict={},
|
||||
optional_params=None,
|
||||
litellm_params=None,
|
||||
logger_fn=None,
|
||||
):
|
||||
|
@ -83,7 +83,7 @@ def completion(
|
|||
)
|
||||
|
||||
## COMPLETION CALL
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
return iter(outputs)
|
||||
else:
|
||||
## LOGGING
|
||||
|
@ -144,9 +144,6 @@ def batch_completions(
|
|||
llm, SamplingParams = validate_environment(model=model)
|
||||
except Exception as e:
|
||||
error_str = str(e)
|
||||
if "data parallel group is already initialized" in error_str:
|
||||
pass
|
||||
else:
|
||||
raise VLLMError(status_code=0, message=error_str)
|
||||
sampling_params = SamplingParams(**optional_params)
|
||||
prompts = []
|
||||
|
|
litellm/main.py (134 changed lines)
|
@ -106,6 +106,7 @@ from .llms.prompt_templates.factory import (
|
|||
custom_prompt,
|
||||
function_call_prompt,
|
||||
map_system_message_pt,
|
||||
ollama_pt,
|
||||
prompt_factory,
|
||||
stringify_json_tool_call_content,
|
||||
)
|
||||
|
@ -150,7 +151,6 @@ from .types.utils import (
|
|||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
from litellm.utils import (
|
||||
Choices,
|
||||
CustomStreamWrapper,
|
||||
EmbeddingResponse,
|
||||
ImageResponse,
|
||||
Message,
|
||||
|
@ -159,8 +159,6 @@ from litellm.utils import (
|
|||
TextCompletionResponse,
|
||||
TextCompletionStreamWrapper,
|
||||
TranscriptionResponse,
|
||||
get_secret,
|
||||
read_config_args,
|
||||
)
|
||||
|
||||
####### ENVIRONMENT VARIABLES ###################
|
||||
|
@ -214,7 +212,7 @@ class LiteLLM:
|
|||
class Chat:
|
||||
def __init__(self, params, router_obj: Optional[Any]):
|
||||
self.params = params
|
||||
if self.params.get("acompletion", False) == True:
|
||||
if self.params.get("acompletion", False) is True:
|
||||
self.params.pop("acompletion")
|
||||
self.completions: Union[AsyncCompletions, Completions] = AsyncCompletions(
|
||||
self.params, router_obj=router_obj
|
||||
|
@ -837,10 +835,10 @@ def completion(
|
|||
model_response = ModelResponse()
|
||||
setattr(model_response, "usage", litellm.Usage())
|
||||
if (
|
||||
kwargs.get("azure", False) == True
|
||||
kwargs.get("azure", False) is True
|
||||
): # don't remove flag check, to remain backwards compatible for repos like Codium
|
||||
custom_llm_provider = "azure"
|
||||
if deployment_id != None: # azure llms
|
||||
if deployment_id is not None: # azure llms
|
||||
model = deployment_id
|
||||
custom_llm_provider = "azure"
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(
|
||||
|
@ -1156,7 +1154,7 @@ def completion(
|
|||
client=client, # pass AsyncAzureOpenAI, AzureOpenAI client
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) or acompletion is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -1278,7 +1276,7 @@ def completion(
|
|||
if (
|
||||
len(messages) > 0
|
||||
and "content" in messages[0]
|
||||
and type(messages[0]["content"]) == list
|
||||
and isinstance(messages[0]["content"], list)
|
||||
):
|
||||
# text-davinci-003 can accept a string or array, if it's an array, assume the array is set in messages[0]['content']
|
||||
# https://platform.openai.com/docs/api-reference/completions/create
|
||||
|
@ -1304,16 +1302,16 @@ def completion(
|
|||
)
|
||||
|
||||
if (
|
||||
optional_params.get("stream", False) == False
|
||||
and acompletion == False
|
||||
and text_completion == False
|
||||
optional_params.get("stream", False) is False
|
||||
and acompletion is False
|
||||
and text_completion is False
|
||||
):
|
||||
# convert to chat completion response
|
||||
_response = litellm.OpenAITextCompletionConfig().convert_to_chat_model_response_object(
|
||||
response_object=_response, model_response_object=model_response
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) or acompletion is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -1519,7 +1517,7 @@ def completion(
|
|||
acompletion=acompletion,
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) == True:
|
||||
if optional_params.get("stream", False) is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -1566,7 +1564,7 @@ def completion(
|
|||
custom_prompt_dict=custom_prompt_dict,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
|
@ -1575,7 +1573,7 @@ def completion(
|
|||
original_response=model_response,
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) or acompletion is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -1654,7 +1652,7 @@ def completion(
|
|||
timeout=timeout,
|
||||
client=client,
|
||||
)
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) or acompletion is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -1691,7 +1689,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
response,
|
||||
|
@ -1700,7 +1698,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) or acompletion is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -1740,7 +1738,7 @@ def completion(
|
|||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
model_response,
|
||||
|
@ -1788,7 +1786,7 @@ def completion(
|
|||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
model_response,
|
||||
|
@ -1836,7 +1834,7 @@ def completion(
|
|||
logging_obj=logging, # model call logging done inside the class as we make need to modify I/O to fit aleph alpha's requirements
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
model_response,
|
||||
|
@ -1875,7 +1873,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
model_response,
|
||||
|
@ -1916,7 +1914,7 @@ def completion(
|
|||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and optional_params["stream"] is True
|
||||
and acompletion is False
|
||||
):
|
||||
# don't try to access stream object,
|
||||
|
@ -1943,7 +1941,7 @@ def completion(
|
|||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
model_response,
|
||||
|
@ -2095,7 +2093,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
# fake palm streaming
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# fake streaming for palm
|
||||
resp_string = model_response["choices"][0]["message"]["content"]
|
||||
response = CustomStreamWrapper(
|
||||
|
@ -2390,7 +2388,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
model_response,
|
||||
|
@ -2527,7 +2525,7 @@ def completion(
|
|||
)
|
||||
if (
|
||||
"stream" in optional_params
|
||||
and optional_params["stream"] == True
|
||||
and optional_params["stream"] is True
|
||||
and not isinstance(response, CustomStreamWrapper)
|
||||
):
|
||||
# don't try to access stream object,
|
||||
|
@ -2563,7 +2561,7 @@ def completion(
|
|||
)
|
||||
|
||||
if (
|
||||
"stream" in optional_params and optional_params["stream"] == True
|
||||
"stream" in optional_params and optional_params["stream"] is True
|
||||
): ## [BETA]
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
|
@ -2587,38 +2585,38 @@ def completion(
|
|||
if model in custom_prompt_dict:
|
||||
# check if the model has a registered custom prompt
|
||||
model_prompt_details = custom_prompt_dict[model]
|
||||
prompt = custom_prompt(
|
||||
ollama_prompt = custom_prompt(
|
||||
role_dict=model_prompt_details["roles"],
|
||||
initial_prompt_value=model_prompt_details["initial_prompt_value"],
|
||||
final_prompt_value=model_prompt_details["final_prompt_value"],
|
||||
messages=messages,
|
||||
)
|
||||
else:
|
||||
prompt = prompt_factory(
|
||||
model=model,
|
||||
messages=messages,
|
||||
custom_llm_provider=custom_llm_provider,
|
||||
)
|
||||
if isinstance(prompt, dict):
|
||||
modified_prompt = ollama_pt(model=model, messages=messages)
|
||||
if isinstance(modified_prompt, dict):
|
||||
# for multimode models - ollama/llava prompt_factory returns a dict {
|
||||
# "prompt": prompt,
|
||||
# "images": images
|
||||
# }
|
||||
prompt, images = prompt["prompt"], prompt["images"]
|
||||
ollama_prompt, images = (
|
||||
modified_prompt["prompt"],
|
||||
modified_prompt["images"],
|
||||
)
|
||||
optional_params["images"] = images
|
||||
|
||||
else:
|
||||
ollama_prompt = modified_prompt
|
||||
## LOGGING
|
||||
generator = ollama.get_ollama_response(
|
||||
api_base=api_base,
|
||||
model=model,
|
||||
prompt=prompt,
|
||||
prompt=ollama_prompt,
|
||||
optional_params=optional_params,
|
||||
logging_obj=logging,
|
||||
acompletion=acompletion,
|
||||
model_response=model_response,
|
||||
encoding=encoding,
|
||||
)
|
||||
if acompletion is True or optional_params.get("stream", False) == True:
|
||||
if acompletion is True or optional_params.get("stream", False) is True:
|
||||
return generator
|
||||
|
||||
response = generator
|
||||
|
@ -2701,7 +2699,7 @@ def completion(
|
|||
api_key=api_key,
|
||||
logging_obj=logging,
|
||||
)
|
||||
if "stream" in optional_params and optional_params["stream"] == True:
|
||||
if "stream" in optional_params and optional_params["stream"] is True:
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
response,
|
||||
|
@ -2710,7 +2708,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
|
||||
if optional_params.get("stream", False) or acompletion == True:
|
||||
if optional_params.get("stream", False) or acompletion is True:
|
||||
## LOGGING
|
||||
logging.post_call(
|
||||
input=messages,
|
||||
|
@ -2743,7 +2741,7 @@ def completion(
|
|||
logging_obj=logging,
|
||||
)
|
||||
if inspect.isgenerator(model_response) or (
|
||||
"stream" in optional_params and optional_params["stream"] == True
|
||||
"stream" in optional_params and optional_params["stream"] is True
|
||||
):
|
||||
# don't try to access stream object,
|
||||
response = CustomStreamWrapper(
|
||||
|
@ -2771,7 +2769,7 @@ def completion(
|
|||
encoding=encoding,
|
||||
logging_obj=logging,
|
||||
)
|
||||
if stream == True: ## [BETA]
|
||||
if stream is True: ## [BETA]
|
||||
# Fake streaming for petals
|
||||
resp_string = model_response["choices"][0]["message"]["content"]
|
||||
response = CustomStreamWrapper(
|
||||
|
@ -2786,7 +2784,7 @@ def completion(
|
|||
import requests
|
||||
|
||||
url = litellm.api_base or api_base or ""
|
||||
if url == None or url == "":
|
||||
if url is None or url == "":
|
||||
raise ValueError(
|
||||
"api_base not set. Set api_base or litellm.api_base for custom endpoints"
|
||||
)
|
||||
|
@ -3145,10 +3143,10 @@ def batch_completion_models(*args, **kwargs):
|
|||
try:
|
||||
result = future.result()
|
||||
return result
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# if model 1 fails, continue with response from model 2, model3
|
||||
print_verbose(
|
||||
f"\n\ngot an exception, ignoring, removing from futures"
|
||||
"\n\ngot an exception, ignoring, removing from futures"
|
||||
)
|
||||
print_verbose(futures)
|
||||
new_futures = {}
|
||||
|
@ -3189,9 +3187,6 @@ def batch_completion_models_all_responses(*args, **kwargs):
|
|||
import concurrent.futures
|
||||
|
||||
# ANSI escape codes for colored output
|
||||
GREEN = "\033[92m"
|
||||
RED = "\033[91m"
|
||||
RESET = "\033[0m"
|
||||
|
||||
if "model" in kwargs:
|
||||
kwargs.pop("model")
|
||||
|
@ -3520,7 +3515,7 @@ def embedding(
|
|||
|
||||
if api_base is None:
|
||||
raise ValueError(
|
||||
f"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
|
||||
"No API Base provided for Azure OpenAI LLM provider. Set 'AZURE_API_BASE' in .env"
|
||||
)
|
||||
|
||||
## EMBEDDING CALL
|
||||
|
@ -4106,7 +4101,6 @@ def text_completion(
|
|||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
global print_verbose
|
||||
import copy
|
||||
|
||||
"""
|
||||
|
@ -4136,7 +4130,7 @@ def text_completion(
|
|||
Your example of how to use this function goes here.
|
||||
"""
|
||||
if "engine" in kwargs:
|
||||
if model == None:
|
||||
if model is None:
|
||||
# only use engine when model not passed
|
||||
model = kwargs["engine"]
|
||||
kwargs.pop("engine")
|
||||
|
@ -4189,18 +4183,18 @@ def text_completion(
|
|||
|
||||
if custom_llm_provider == "huggingface":
|
||||
# if echo == True, for TGI llms we need to set top_n_tokens to 3
|
||||
if echo == True:
|
||||
if echo is True:
|
||||
# for tgi llms
|
||||
if "top_n_tokens" not in kwargs:
|
||||
kwargs["top_n_tokens"] = 3
|
||||
|
||||
# processing prompt - users can pass raw tokens to OpenAI Completion()
|
||||
if type(prompt) == list:
|
||||
if isinstance(prompt, list):
|
||||
import concurrent.futures
|
||||
|
||||
tokenizer = tiktoken.encoding_for_model("text-davinci-003")
|
||||
## if it's a 2d list - each element in the list is a text_completion() request
|
||||
if len(prompt) > 0 and type(prompt[0]) == list:
|
||||
if len(prompt) > 0 and isinstance(prompt[0], list):
|
||||
responses = [None for x in prompt] # init responses
|
||||
|
||||
def process_prompt(i, individual_prompt):
|
||||
|
@ -4299,7 +4293,7 @@ def text_completion(
|
|||
raw_response = response._hidden_params.get("original_response", None)
|
||||
transformed_logprobs = litellm.utils.transform_logprobs(raw_response)
|
||||
except Exception as e:
|
||||
print_verbose(f"LiteLLM non blocking exception: {e}")
|
||||
verbose_logger.exception(f"LiteLLM non blocking exception: {e}")
|
||||
|
||||
if isinstance(response, TextCompletionResponse):
|
||||
return response
|
||||
|
@ -4813,12 +4807,12 @@ def transcription(
|
|||
Allows router to load balance between them
|
||||
"""
|
||||
atranscription = kwargs.get("atranscription", False)
|
||||
litellm_call_id = kwargs.get("litellm_call_id", None)
|
||||
logger_fn = kwargs.get("logger_fn", None)
|
||||
proxy_server_request = kwargs.get("proxy_server_request", None)
|
||||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", {})
|
||||
tags = kwargs.pop("tags", [])
|
||||
kwargs.get("litellm_call_id", None)
|
||||
kwargs.get("logger_fn", None)
|
||||
kwargs.get("proxy_server_request", None)
|
||||
kwargs.get("model_info", None)
|
||||
kwargs.get("metadata", {})
|
||||
kwargs.pop("tags", [])
|
||||
|
||||
drop_params = kwargs.get("drop_params", None)
|
||||
client: Optional[
|
||||
|
@ -4996,7 +4990,7 @@ def speech(
|
|||
model_info = kwargs.get("model_info", None)
|
||||
metadata = kwargs.get("metadata", {})
|
||||
model, custom_llm_provider, dynamic_api_key, api_base = get_llm_provider(model=model, custom_llm_provider=custom_llm_provider, api_base=api_base) # type: ignore
|
||||
tags = kwargs.pop("tags", [])
|
||||
kwargs.pop("tags", [])
|
||||
|
||||
optional_params = {}
|
||||
if response_format is not None:
|
||||
|
@ -5345,12 +5339,12 @@ def print_verbose(print_statement):
|
|||
verbose_logger.debug(print_statement)
|
||||
if litellm.set_verbose:
|
||||
print(print_statement) # noqa
|
||||
except:
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def config_completion(**kwargs):
|
||||
if litellm.config_path != None:
|
||||
if litellm.config_path is not None:
|
||||
config_args = read_config_args(litellm.config_path)
|
||||
# overwrite any args passed in with config args
|
||||
return completion(**kwargs, **config_args)
|
||||
|
@ -5408,16 +5402,18 @@ def stream_chunk_builder_text_completion(chunks: list, messages: Optional[List]
|
|||
response["choices"][0]["text"] = combined_content
|
||||
|
||||
if len(combined_content) > 0:
|
||||
completion_output = combined_content
|
||||
pass
|
||||
else:
|
||||
completion_output = ""
|
||||
pass
|
||||
# # Update usage information if needed
|
||||
try:
|
||||
response["usage"]["prompt_tokens"] = token_counter(
|
||||
model=model, messages=messages
|
||||
)
|
||||
except: # don't allow this failing to block a complete streaming response from being returned
|
||||
print_verbose(f"token_counter failed, assuming prompt tokens is 0")
|
||||
except (
|
||||
Exception
|
||||
): # don't allow this failing to block a complete streaming response from being returned
|
||||
print_verbose("token_counter failed, assuming prompt tokens is 0")
|
||||
response["usage"]["prompt_tokens"] = 0
|
||||
response["usage"]["completion_tokens"] = token_counter(
|
||||
model=model,
|
||||
|
|
|
@ -128,14 +128,14 @@ class LiteLLMBase(BaseModel):
|
|||
def json(self, **kwargs): # type: ignore
|
||||
try:
|
||||
return self.model_dump(**kwargs) # noqa
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# if using pydantic v1
|
||||
return self.dict(**kwargs)
|
||||
|
||||
def fields_set(self):
|
||||
try:
|
||||
return self.model_fields_set # noqa
|
||||
except:
|
||||
except Exception:
|
||||
# if using pydantic v1
|
||||
return self.__fields_set__
|
||||
|
||||
|
|
|
@@ -1,202 +0,0 @@
|
|||
def add_new_model():
|
||||
import streamlit as st
|
||||
import json, requests, uuid
|
||||
|
||||
model_name = st.text_input(
|
||||
"Model Name - user-facing model name", placeholder="gpt-3.5-turbo"
|
||||
)
|
||||
st.subheader("LiteLLM Params")
|
||||
litellm_model_name = st.text_input(
|
||||
"Model", placeholder="azure/gpt-35-turbo-us-east"
|
||||
)
|
||||
litellm_api_key = st.text_input("API Key")
|
||||
litellm_api_base = st.text_input(
|
||||
"API Base",
|
||||
placeholder="https://my-endpoint.openai.azure.com",
|
||||
)
|
||||
litellm_api_version = st.text_input("API Version", placeholder="2023-07-01-preview")
|
||||
litellm_params = json.loads(
|
||||
st.text_area(
|
||||
"Additional Litellm Params (JSON dictionary). [See all possible inputs](https://github.com/BerriAI/litellm/blob/3f15d7230fe8e7492c95a752963e7fbdcaf7bf98/litellm/main.py#L293)",
|
||||
value={},
|
||||
)
|
||||
)
|
||||
st.subheader("Model Info")
|
||||
mode_options = ("completion", "embedding", "image generation")
|
||||
mode_selected = st.selectbox("Mode", mode_options)
|
||||
model_info = json.loads(
|
||||
st.text_area(
|
||||
"Additional Model Info (JSON dictionary)",
|
||||
value={},
|
||||
)
|
||||
)
|
||||
|
||||
if st.button("Submit"):
|
||||
try:
|
||||
model_info = {
|
||||
"model_name": model_name,
|
||||
"litellm_params": {
|
||||
"model": litellm_model_name,
|
||||
"api_key": litellm_api_key,
|
||||
"api_base": litellm_api_base,
|
||||
"api_version": litellm_api_version,
|
||||
},
|
||||
"model_info": {
|
||||
"id": str(uuid.uuid4()),
|
||||
"mode": mode_selected,
|
||||
},
|
||||
}
|
||||
# Make the POST request to the specified URL
|
||||
complete_url = ""
|
||||
if st.session_state["api_url"].endswith("/"):
|
||||
complete_url = f"{st.session_state['api_url']}model/new"
|
||||
else:
|
||||
complete_url = f"{st.session_state['api_url']}/model/new"
|
||||
|
||||
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
|
||||
response = requests.post(complete_url, json=model_info, headers=headers)
|
||||
|
||||
if response.status_code == 200:
|
||||
st.success("Model added successfully!")
|
||||
else:
|
||||
st.error(f"Failed to add model. Status code: {response.status_code}")
|
||||
|
||||
st.success("Form submitted successfully!")
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
|
||||
def list_models():
|
||||
import streamlit as st
|
||||
import requests
|
||||
|
||||
# Check if the necessary configuration is available
|
||||
if (
|
||||
st.session_state.get("api_url", None) is not None
|
||||
and st.session_state.get("proxy_key", None) is not None
|
||||
):
|
||||
# Make the GET request
|
||||
try:
|
||||
complete_url = ""
|
||||
if isinstance(st.session_state["api_url"], str) and st.session_state[
|
||||
"api_url"
|
||||
].endswith("/"):
|
||||
complete_url = f"{st.session_state['api_url']}models"
|
||||
else:
|
||||
complete_url = f"{st.session_state['api_url']}/models"
|
||||
response = requests.get(
|
||||
complete_url,
|
||||
headers={"Authorization": f"Bearer {st.session_state['proxy_key']}"},
|
||||
)
|
||||
# Check if the request was successful
|
||||
if response.status_code == 200:
|
||||
models = response.json()
|
||||
st.write(models) # or st.json(models) to pretty print the JSON
|
||||
else:
|
||||
st.error(f"Failed to get models. Status code: {response.status_code}")
|
||||
except Exception as e:
|
||||
st.error(f"An error occurred while requesting models: {e}")
|
||||
else:
|
||||
st.warning(
|
||||
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
|
||||
)
|
||||
|
||||
|
||||
def create_key():
|
||||
import streamlit as st
|
||||
import json, requests, uuid
|
||||
|
||||
if (
|
||||
st.session_state.get("api_url", None) is not None
|
||||
and st.session_state.get("proxy_key", None) is not None
|
||||
):
|
||||
duration = st.text_input("Duration - Can be in (h,m,s)", placeholder="1h")
|
||||
|
||||
models = st.text_input("Models it can access (separated by comma)", value="")
|
||||
models = models.split(",") if models else []
|
||||
|
||||
additional_params = json.loads(
|
||||
st.text_area(
|
||||
"Additional Key Params (JSON dictionary). [See all possible inputs](https://litellm-api.up.railway.app/#/key%20management/generate_key_fn_key_generate_post)",
|
||||
value={},
|
||||
)
|
||||
)
|
||||
|
||||
if st.button("Submit"):
|
||||
try:
|
||||
key_post_body = {
|
||||
"duration": duration,
|
||||
"models": models,
|
||||
**additional_params,
|
||||
}
|
||||
# Make the POST request to the specified URL
|
||||
complete_url = ""
|
||||
if st.session_state["api_url"].endswith("/"):
|
||||
complete_url = f"{st.session_state['api_url']}key/generate"
|
||||
else:
|
||||
complete_url = f"{st.session_state['api_url']}/key/generate"
|
||||
|
||||
headers = {"Authorization": f"Bearer {st.session_state['proxy_key']}"}
|
||||
response = requests.post(
|
||||
complete_url, json=key_post_body, headers=headers
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
st.success(f"Key added successfully! - {response.json()}")
|
||||
else:
|
||||
st.error(f"Failed to add Key. Status code: {response.status_code}")
|
||||
|
||||
st.success("Form submitted successfully!")
|
||||
except Exception as e:
|
||||
raise e
|
||||
else:
|
||||
st.warning(
|
||||
f"Please configure the Proxy Endpoint and Proxy Key on the Proxy Setup page. Currently set Proxy Endpoint: {st.session_state.get('api_url', None)} and Proxy Key: {st.session_state.get('proxy_key', None)}"
|
||||
)
|
||||
|
||||
|
||||
def streamlit_ui():
|
||||
import streamlit as st
|
||||
|
||||
st.header("Admin Configuration")
|
||||
|
||||
# Add a navigation sidebar
|
||||
st.sidebar.title("Navigation")
|
||||
page = st.sidebar.radio(
|
||||
"Go to", ("Proxy Setup", "Add Models", "List Models", "Create Key")
|
||||
)
|
||||
|
||||
# Initialize session state variables if not already present
|
||||
if "api_url" not in st.session_state:
|
||||
st.session_state["api_url"] = None
|
||||
if "proxy_key" not in st.session_state:
|
||||
st.session_state["proxy_key"] = None
|
||||
|
||||
# Display different pages based on navigation selection
|
||||
if page == "Proxy Setup":
|
||||
# Use text inputs with intermediary variables
|
||||
input_api_url = st.text_input(
|
||||
"Proxy Endpoint",
|
||||
value=st.session_state.get("api_url", ""),
|
||||
placeholder="http://0.0.0.0:8000",
|
||||
)
|
||||
input_proxy_key = st.text_input(
|
||||
"Proxy Key",
|
||||
value=st.session_state.get("proxy_key", ""),
|
||||
placeholder="sk-...",
|
||||
)
|
||||
# When the "Save" button is clicked, update the session state
|
||||
if st.button("Save"):
|
||||
st.session_state["api_url"] = input_api_url
|
||||
st.session_state["proxy_key"] = input_proxy_key
|
||||
st.success("Configuration saved!")
|
||||
elif page == "Add Models":
|
||||
add_new_model()
|
||||
elif page == "List Models":
|
||||
list_models()
|
||||
elif page == "Create Key":
|
||||
create_key()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
streamlit_ui()
|
|
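For reference, the create_key() form above is a thin wrapper over the proxy's /key/generate endpoint. A minimal sketch of the same request without Streamlit, assuming a locally running proxy at http://0.0.0.0:8000 and a placeholder key (both values, and the model name, are illustrative):

import requests

api_url = "http://0.0.0.0:8000"  # illustrative proxy endpoint
proxy_key = "sk-1234"  # illustrative proxy/master key

response = requests.post(
    f"{api_url}/key/generate",
    json={"duration": "1h", "models": ["gpt-3.5-turbo"]},  # hypothetical model name
    headers={"Authorization": f"Bearer {proxy_key}"},
)
print(response.status_code, response.json())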
@@ -69,7 +69,7 @@ async def get_global_activity(
    try:
        if prisma_client is None:
            raise ValueError(
-                f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
+                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )

        sql_query = """
@@ -132,7 +132,7 @@ def common_checks(
    # 6. [OPTIONAL] If 'enforce_user_param' enabled - did developer pass in 'user' param for openai endpoints
    if (
        general_settings.get("enforce_user_param", None) is not None
-        and general_settings["enforce_user_param"] == True
+        and general_settings["enforce_user_param"] is True
    ):
        if is_llm_api_route(route=route) and "user" not in request_body:
            raise Exception(
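The change from "== True" to "is True" here (and in several hunks below) corresponds to ruff's E712 rule, which flags equality comparisons against boolean literals. A minimal before/after sketch:

enforce_user_param = True

# flagged by ruff (E712): equality comparison against a boolean literal
if enforce_user_param == True:  # noqa: E712 - kept only for contrast
    pass

# preferred form, as applied throughout this PR
if enforce_user_param is True:
    pass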
@@ -557,7 +557,7 @@ async def get_team_object(
        )

        return _response
-    except Exception as e:
+    except Exception:
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )
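This hunk, and the matching ones below for get_org_object, LicenseCheck, and user_api_key_auth, drop exception variables that were bound but never read, which ruff reports as an unused local variable (F841). A small sketch of the pattern, using a hypothetical helper:

def get_team_or_raise(team_id: str, db: dict) -> dict:
    # hypothetical helper, shown only to illustrate the lint fix
    try:
        return db[team_id]
    except Exception:  # no "as e": the bound variable was never read
        raise Exception(
            f"Team doesn't exist in db. Team={team_id}. Create team via `/team/new` call."
        )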
@@ -664,7 +664,7 @@ async def get_org_object(
            raise Exception

        return response
-    except Exception as e:
+    except Exception:
        raise Exception(
            f"Organization doesn't exist in db. Organization={org_id}. Create organization via `/organization/new` call."
        )
@@ -98,7 +98,7 @@ class LicenseCheck:
            elif self._verify(license_str=self.license_str) is True:
                return True
            return False
-        except Exception as e:
+        except Exception:
            return False

    def verify_license_without_api_request(self, public_key, license_key):
@@ -112,7 +112,6 @@ async def user_api_key_auth(
    ),
) -> UserAPIKeyAuth:
    from litellm.proxy.proxy_server import (
-        custom_db_client,
        general_settings,
        jwt_handler,
        litellm_proxy_admin_name,
@@ -476,7 +475,7 @@ async def user_api_key_auth(
            )

        if route == "/user/auth":
-            if general_settings.get("allow_user_auth", False) == True:
+            if general_settings.get("allow_user_auth", False) is True:
                return UserAPIKeyAuth()
            else:
                raise HTTPException(
@@ -597,7 +596,7 @@ async def user_api_key_auth(
            ## VALIDATE MASTER KEY ##
            try:
                assert isinstance(master_key, str)
-            except Exception as e:
+            except Exception:
                raise HTTPException(
                    status_code=500,
                    detail={
@@ -648,7 +647,7 @@ async def user_api_key_auth(
            )

        if (
-            prisma_client is None and custom_db_client is None
+            prisma_client is None
        ):  # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
            raise Exception("No connected db.")

@@ -722,9 +721,9 @@ async def user_api_key_auth(

                if config != {}:
                    model_list = config.get("model_list", [])
-                    llm_model_list = model_list
+                    new_model_list = model_list
                    verbose_proxy_logger.debug(
-                        f"\n new llm router model list {llm_model_list}"
+                        f"\n new llm router model list {new_model_list}"
                    )
                if (
                    len(valid_token.models) == 0
@@ -2,6 +2,7 @@ import sys
from typing import Any, Dict, List, Optional, get_args

import litellm
+from litellm import get_secret, get_secret_str
from litellm._logging import verbose_proxy_logger
from litellm.proxy._types import CommonProxyErrors, LiteLLMPromptInjectionParams
from litellm.proxy.utils import get_instance_fn
@@ -59,9 +60,15 @@ def initialize_callbacks_on_proxy(
                    presidio_logging_only
                )  # validate boolean given

-            params = {
+            _presidio_params = {}
+            if "presidio" in callback_specific_params and isinstance(
+                callback_specific_params["presidio"], dict
+            ):
+                _presidio_params = callback_specific_params["presidio"]
+
+            params: Dict[str, Any] = {
                "logging_only": presidio_logging_only,
-                **callback_specific_params.get("presidio", {}),
+                **_presidio_params,
            }
            pii_masking_object = _OPTIONAL_PresidioPIIMasking(**params)
            imported_list.append(pii_masking_object)
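The reshaped presidio block above only merges per-callback options into params when callback_specific_params carries a dict under the "presidio" key. A rough standalone sketch of that merge, with illustrative values:

from typing import Any, Dict

callback_specific_params: Dict[str, Any] = {"presidio": {"language": "en"}}  # illustrative
presidio_logging_only = True

_presidio_params = {}
if "presidio" in callback_specific_params and isinstance(
    callback_specific_params["presidio"], dict
):
    _presidio_params = callback_specific_params["presidio"]

params: Dict[str, Any] = {
    "logging_only": presidio_logging_only,
    **_presidio_params,
}
print(params)  # {'logging_only': True, 'language': 'en'}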
@@ -70,7 +77,7 @@ def initialize_callbacks_on_proxy(
                _ENTERPRISE_LlamaGuard,
            )

-            if premium_user != True:
+            if premium_user is not True:
                raise Exception(
                    "Trying to use Llama Guard"
                    + CommonProxyErrors.not_premium_user.value
@@ -83,7 +90,7 @@ def initialize_callbacks_on_proxy(
                _ENTERPRISE_SecretDetection,
            )

-            if premium_user != True:
+            if premium_user is not True:
                raise Exception(
                    "Trying to use secret hiding"
                    + CommonProxyErrors.not_premium_user.value
@@ -96,7 +103,7 @@ def initialize_callbacks_on_proxy(
                _ENTERPRISE_OpenAI_Moderation,
            )

-            if premium_user != True:
+            if premium_user is not True:
                raise Exception(
                    "Trying to use OpenAI Moderations Check"
                    + CommonProxyErrors.not_premium_user.value
@@ -126,7 +133,7 @@ def initialize_callbacks_on_proxy(
                _ENTERPRISE_GoogleTextModeration,
            )

-            if premium_user != True:
+            if premium_user is not True:
                raise Exception(
                    "Trying to use Google Text Moderation"
                    + CommonProxyErrors.not_premium_user.value
@@ -137,7 +144,7 @@ def initialize_callbacks_on_proxy(
        elif isinstance(callback, str) and callback == "llmguard_moderations":
            from enterprise.enterprise_hooks.llm_guard import _ENTERPRISE_LLMGuard

-            if premium_user != True:
+            if premium_user is not True:
                raise Exception(
                    "Trying to use Llm Guard"
                    + CommonProxyErrors.not_premium_user.value
@@ -150,7 +157,7 @@ def initialize_callbacks_on_proxy(
                _ENTERPRISE_BlockedUserList,
            )

-            if premium_user != True:
+            if premium_user is not True:
                raise Exception(
                    "Trying to use ENTERPRISE BlockedUser"
                    + CommonProxyErrors.not_premium_user.value
@@ -165,7 +172,7 @@ def initialize_callbacks_on_proxy(
                _ENTERPRISE_BannedKeywords,
            )

-            if premium_user != True:
+            if premium_user is not True:
                raise Exception(
                    "Trying to use ENTERPRISE BannedKeyword"
                    + CommonProxyErrors.not_premium_user.value
@@ -212,7 +219,7 @@ def initialize_callbacks_on_proxy(
                    and isinstance(v, str)
                    and v.startswith("os.environ/")
                ):
-                    azure_content_safety_params[k] = litellm.get_secret(v)
+                    azure_content_safety_params[k] = get_secret(v)

            azure_content_safety_obj = _PROXY_AzureContentSafety(
                **azure_content_safety_params,
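Swapping litellm.get_secret(v) for the directly imported get_secret(v) keeps the same behaviour: config values written as "os.environ/<VAR_NAME>" are resolved from the environment, as the surrounding startswith("os.environ/") check shows. A small usage sketch (the variable name and value are made up):

import os

from litellm import get_secret

os.environ["AZURE_CONTENT_SAFETY_KEY"] = "dummy-value"  # hypothetical env var

resolved = get_secret("os.environ/AZURE_CONTENT_SAFETY_KEY")
print(resolved)  # dummy-value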
@@ -6,13 +6,14 @@ import tracemalloc
from fastapi import APIRouter

import litellm
+from litellm import get_secret, get_secret_str
from litellm._logging import verbose_proxy_logger

router = APIRouter()

if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
    try:
-        import objgraph
+        import objgraph  # type: ignore

        print("growth of objects")  # noqa
        objgraph.show_growth()
@@ -21,8 +22,10 @@ if os.environ.get("LITELLM_PROFILE", "false").lower() == "true":
        roots = objgraph.get_leaking_objects()
        print("\n\nLeaking objects")  # noqa
        objgraph.show_most_common_types(objects=roots)
-    except:
-        pass
+    except ImportError:
+        raise ImportError(
+            "objgraph not found. Please install objgraph to use this feature."
+        )

    tracemalloc.start(10)
@@ -57,15 +60,20 @@ async def memory_usage_in_mem_cache():
        user_api_key_cache,
    )

-    num_items_in_user_api_key_cache = len(
-        user_api_key_cache.in_memory_cache.cache_dict
-    ) + len(user_api_key_cache.in_memory_cache.ttl_dict)
+    if llm_router is None:
+        num_items_in_llm_router_cache = 0
+    else:
+        num_items_in_llm_router_cache = len(
+            llm_router.cache.in_memory_cache.cache_dict
+        ) + len(llm_router.cache.in_memory_cache.ttl_dict)
+
+    num_items_in_user_api_key_cache = len(
+        user_api_key_cache.in_memory_cache.cache_dict
+    ) + len(user_api_key_cache.in_memory_cache.ttl_dict)

    num_items_in_proxy_logging_obj_cache = len(
-        proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict
-    ) + len(proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict)
+        proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict
+    ) + len(proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict)

    return {
        "num_items_in_user_api_key_cache": num_items_in_user_api_key_cache,
@@ -89,13 +97,20 @@ async def memory_usage_in_mem_cache_items():
        user_api_key_cache,
    )

+    if llm_router is None:
+        llm_router_in_memory_cache_dict = {}
+        llm_router_in_memory_ttl_dict = {}
+    else:
+        llm_router_in_memory_cache_dict = llm_router.cache.in_memory_cache.cache_dict
+        llm_router_in_memory_ttl_dict = llm_router.cache.in_memory_cache.ttl_dict
+
    return {
        "user_api_key_cache": user_api_key_cache.in_memory_cache.cache_dict,
        "user_api_key_ttl": user_api_key_cache.in_memory_cache.ttl_dict,
-        "llm_router_cache": llm_router.cache.in_memory_cache.cache_dict,
-        "llm_router_ttl": llm_router.cache.in_memory_cache.ttl_dict,
-        "proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.in_memory_cache.cache_dict,
-        "proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.in_memory_cache.ttl_dict,
+        "llm_router_cache": llm_router_in_memory_cache_dict,
+        "llm_router_ttl": llm_router_in_memory_ttl_dict,
+        "proxy_logging_obj_cache": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.cache_dict,
+        "proxy_logging_obj_ttl": proxy_logging_obj.internal_usage_cache.dual_cache.in_memory_cache.ttl_dict,
    }
@@ -104,9 +119,18 @@ async def get_otel_spans():
    from litellm.integrations.opentelemetry import OpenTelemetry
    from litellm.proxy.proxy_server import open_telemetry_logger

-    open_telemetry_logger: OpenTelemetry = open_telemetry_logger
+    if open_telemetry_logger is None:
+        return {
+            "otel_spans": [],
+            "spans_grouped_by_parent": {},
+            "most_recent_parent": None,
+        }
+
    otel_exporter = open_telemetry_logger.OTEL_EXPORTER
-    recorded_spans = otel_exporter.get_finished_spans()
+    if hasattr(otel_exporter, "get_finished_spans"):
+        recorded_spans = otel_exporter.get_finished_spans()  # type: ignore
+    else:
+        recorded_spans = []

    print("Spans: ", recorded_spans)  # noqa
@@ -137,11 +161,13 @@
# Helper functions for debugging
def init_verbose_loggers():
    try:
-        worker_config = litellm.get_secret("WORKER_CONFIG")
+        worker_config = get_secret_str("WORKER_CONFIG")
        # if not, assume it's a json string
+        if worker_config is None:
+            return
        if os.path.isfile(worker_config):
            return
        # if not, assume it's a json string
-        _settings = json.loads(os.getenv("WORKER_CONFIG"))
+        _settings = json.loads(worker_config)
        if not isinstance(_settings, dict):
            return
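init_verbose_loggers now reads WORKER_CONFIG through get_secret_str and treats it either as a path to a config file or as a JSON object. A rough sketch of the JSON branch, with a purely hypothetical payload:

import json
import os

# hypothetical payload - the real keys depend on how the proxy worker is launched
os.environ["WORKER_CONFIG"] = '{"debug": false, "detailed_debug": false}'

worker_config = os.environ.get("WORKER_CONFIG")
if worker_config is not None and not os.path.isfile(worker_config):
    _settings = json.loads(worker_config)
    assert isinstance(_settings, dict)
    print(_settings)  # {'debug': False, 'detailed_debug': False}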
@@ -162,7 +188,7 @@ def init_verbose_loggers():
                level=logging.INFO
            )  # set router logs to info
            verbose_proxy_logger.setLevel(level=logging.INFO)  # set proxy logs to info
-        if detailed_debug == True:
+        if detailed_debug is True:
            import logging

            from litellm._logging import (
@@ -178,10 +204,10 @@ def init_verbose_loggers():
            verbose_proxy_logger.setLevel(
                level=logging.DEBUG
            )  # set proxy logs to debug
-        elif debug == False and detailed_debug == False:
+        elif debug is False and detailed_debug is False:
            # users can control proxy debugging using env variable = 'LITELLM_LOG'
            litellm_log_setting = os.environ.get("LITELLM_LOG", "")
-            if litellm_log_setting != None:
+            if litellm_log_setting is not None:
                if litellm_log_setting.upper() == "INFO":
                    import logging
@@ -213,4 +239,6 @@ def init_verbose_loggers():
                        level=logging.DEBUG
                    )  # set proxy logs to debug
    except Exception as e:
-        verbose_logger.warning(f"Failed to init verbose loggers: {str(e)}")
+        import logging
+
+        logging.warning(f"Failed to init verbose loggers: {str(e)}")
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue